Files
X-Financial/server/src/app/api/v1/endpoints/ocr.py

68 lines
2.5 KiB
Python
Raw Normal View History

from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from app.api.deps import CurrentUserContext, get_current_user, get_db
from app.schemas.common import ErrorResponse
from app.schemas.ocr import OcrRecognizeBatchRead
from app.services.ocr import OcrService
from app.services.receipt_folder import ReceiptFolderService
# Router for this module's endpoints; every route registered on it is
# served under the "/ocr" path prefix.
router = APIRouter(prefix="/ocr")
@router.post(
    "/recognize",
    response_model=OcrRecognizeBatchRead,
    summary="识别票据或图片 OCR",
    description="使用 PaddleOCR mobile 模型对上传的图片或 PDF 执行 OCR并返回结构化文本摘要。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "未上传文件或文件参数非法。",
        },
        status.HTTP_401_UNAUTHORIZED: {
            "model": ErrorResponse,
            "description": "未提供当前登录用户。",
        },
        status.HTTP_503_SERVICE_UNAVAILABLE: {
            "model": ErrorResponse,
            "description": "OCR 运行时不可用或执行失败。",
        },
    },
)
async def recognize_ocr_documents(
    files: Annotated[list[UploadFile], File(description="待识别的票据图片或 PDF。")],
    current_user: Annotated[CurrentUserContext, Depends(get_current_user)],
    db: Annotated[Session, Depends(get_db)],
    receipt_ids: Annotated[
        list[str] | None,
        Form(description="可选,来源于票据夹的持久化票据 ID。"),
    ] = None,
) -> OcrRecognizeBatchRead:
    """Run OCR over the uploaded files and persist the batch result.

    Each upload is read fully into memory as a
    ``(filename, content bytes, content type)`` tuple; a missing filename
    falls back to ``"upload.bin"``. The tuples are passed to
    ``OcrService(db).recognize_files`` on a worker thread, and the OCR result
    is then handed to ``ReceiptFolderService().persist_ocr_batch`` together
    with the current user and the optional receipt IDs.

    Args:
        files: Uploaded images or PDFs to recognize.
        current_user: Context of the authenticated caller.
        db: Database session given to ``OcrService``.
        receipt_ids: Optional receipt IDs forwarded to the persistence step;
            ``None`` (or an empty list) is forwarded as an empty list.

    Returns:
        Whatever ``ReceiptFolderService.persist_ocr_batch`` returns.

    Raises:
        HTTPException: 400 when any step raises ``ValueError``; 503 when any
            step raises ``RuntimeError``. The original message becomes the
            response ``detail``.
    """
    try:
        documents = [
            (
                str(upload.filename or "upload.bin"),
                await upload.read(),
                upload.content_type,
            )
            for upload in files
        ]

        def _run_ocr():
            # The service is constructed inside the worker thread, not on
            # the event loop.
            return OcrService(db).recognize_files(documents)

        ocr_result = await run_in_threadpool(_run_ocr)
        return ReceiptFolderService().persist_ocr_batch(
            files=documents,
            result=ocr_result,
            current_user=current_user,
            receipt_ids=receipt_ids or [],
        )
    except (ValueError, RuntimeError) as exc:
        # ValueError is treated as a client-side problem; RuntimeError as the
        # OCR backend being unavailable or failing.
        if isinstance(exc, ValueError):
            failure_status = status.HTTP_400_BAD_REQUEST
        else:
            failure_status = status.HTTP_503_SERVICE_UNAVAILABLE
        raise HTTPException(status_code=failure_status, detail=str(exc)) from exc