"""API tests for the OCR recognize endpoint and the receipt-folder routes."""

from __future__ import annotations

from collections.abc import Generator

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.core.config import get_settings
from app.db.base import Base
from app.main import create_app
from app.schemas.ocr import (
    OcrRecognizeBatchRead,
    OcrRecognizeDocumentRead,
    OcrRecognizeFieldRead,
    OcrRecognizeLineRead,
)
from app.services.ocr import OcrService


def build_client() -> TestClient:
    """Return a ``TestClient`` for a new app wired to in-memory SQLite.

    All tables registered on ``Base.metadata`` are created on the engine, and
    the app's ``get_db`` dependency is overridden so that each use yields a
    session bound to that engine and closes it afterwards.
    """
    # StaticPool + check_same_thread=False: one shared connection, usable
    # from whichever thread ends up serving the request.
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    make_session = sessionmaker(bind=memory_engine, autocommit=False, autoflush=False)

    def _db_override() -> Generator[Session, None, None]:
        session = make_session()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = _db_override
    return TestClient(application)


def _install_ocr_stub(monkeypatch, tmp_path, stub) -> None:
    """Redirect storage under ``tmp_path`` and stub ``OcrService.recognize_files``.

    The settings cache is cleared right after the environment variable is
    set, presumably so the new ``STORAGE_ROOT_DIR`` is picked up by the app.
    """
    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
    get_settings.cache_clear()
    monkeypatch.setattr(OcrService, "recognize_files", stub)


def test_ocr_recognize_endpoint_returns_structured_payload(
    monkeypatch,
    tmp_path,
) -> None:
    """Walk one invoice through recognize, list, re-upload, edit and delete.

    The OCR service is replaced by a stub returning a fixed VAT-invoice
    result, so the assertions cover the HTTP layer and the receipt-folder
    endpoints rather than real recognition.
    """

    def stub_recognize_files(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        assert files[0][0] == "invoice.png"
        invoice_fields = [
            OcrRecognizeFieldRead(key="amount", label="金额", value="100元"),
            OcrRecognizeFieldRead(key="date", label="日期", value="2026-05-13"),
            OcrRecognizeFieldRead(key="invoice_number", label="票据号码", value="12345678"),
        ]
        invoice_lines = [
            OcrRecognizeLineRead(
                text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
                score=0.98,
                box=[[1, 2], [10, 2], [10, 8], [1, 8]],
                page_index=0,
            )
        ]
        invoice_document = OcrRecognizeDocumentRead(
            filename="invoice.png",
            media_type="image/png",
            text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
            summary="增值税电子发票,金额 100 元。",
            avg_score=0.98,
            line_count=1,
            page_count=1,
            document_type="vat_invoice",
            document_type_label="增值税发票",
            scene_code="other",
            scene_label="通用发票",
            document_fields=invoice_fields,
            lines=invoice_lines,
        )
        return OcrRecognizeBatchRead(
            engine="paddleocr_mobile",
            model="PP-OCRv5_mobile",
            total_file_count=1,
            success_count=1,
            documents=[invoice_document],
        )

    _install_ocr_stub(monkeypatch, tmp_path, stub_recognize_files)

    try:
        client = build_client()
        auth_headers = {"x-auth-username": "pytest", "x-auth-name": "Py Test"}

        def post_invoice(**extra):
            # Every upload in this test sends the same single fake PNG.
            return client.post(
                "/api/v1/ocr/recognize",
                headers=auth_headers,
                files=[("files", ("invoice.png", b"fake-image", "image/png"))],
                **extra,
            )

        # --- first upload: structured payload plus a new unlinked receipt ---
        response = post_invoice()
        assert response.status_code == 200
        payload = response.json()
        document = payload["documents"][0]
        assert payload["engine"] == "paddleocr_mobile"
        assert payload["success_count"] == 1
        assert document["filename"] == "invoice.png"
        assert document["summary"] == "增值税电子发票,金额 100 元。"
        assert document["document_type"] == "vat_invoice"
        assert document["document_type_label"] == "增值税发票"
        assert document["document_fields"][0]["label"] == "金额"
        assert document["receipt_id"]
        assert document["receipt_status"] == "unlinked"
        assert document["receipt_preview_url"].endswith(f"/receipt-folder/{document['receipt_id']}/preview")
        assert document["receipt_source_url"].endswith(f"/receipt-folder/{document['receipt_id']}/source")
        receipt_id = document["receipt_id"]

        # --- the receipt shows up in the unlinked listing ---
        list_response = client.get("/api/v1/receipt-folder?status=unlinked", headers=auth_headers)
        assert list_response.status_code == 200
        receipt_list = list_response.json()
        assert len(receipt_list) == 1
        listed_receipt = receipt_list[0]
        assert listed_receipt["id"] == receipt_id
        assert listed_receipt["amount"] == "100元"

        # --- re-recognition against an explicit receipt id keeps that id ---
        repeated_response = post_invoice(data={"receipt_ids": receipt_id})
        assert repeated_response.status_code == 200
        repeated_document = repeated_response.json()["documents"][0]
        assert repeated_document["receipt_id"] == receipt_id

        # --- uploading the same file again maps to the same receipt + warning ---
        duplicate_response = post_invoice()
        assert duplicate_response.status_code == 200
        duplicate_document = duplicate_response.json()["documents"][0]
        assert duplicate_document["receipt_id"] == receipt_id
        assert duplicate_document["receipt_status"] == "unlinked"
        assert any("重复上传" in warning for warning in duplicate_document["warnings"])

        all_receipts_response = client.get("/api/v1/receipt-folder?status=all", headers=auth_headers)
        assert all_receipts_response.status_code == 200
        assert len(all_receipts_response.json()) == 1

        # --- detail, edit, preview and delete all address the same resource ---
        receipt_url = f"/api/v1/receipt-folder/{receipt_id}"

        detail_response = client.get(receipt_url, headers=auth_headers)
        assert detail_response.status_code == 200
        detail_payload = detail_response.json()
        assert detail_payload["file_name"] == "invoice.png"
        assert detail_payload["fields"][0]["label"] == "金额"

        patch_body = {
            "document_type_label": "电子发票",
            "amount": "108元",
            "fields": [{"key": "amount", "label": "金额", "value": "108元"}],
        }
        update_response = client.patch(receipt_url, headers=auth_headers, json=patch_body)
        assert update_response.status_code == 200
        updated_payload = update_response.json()
        assert updated_payload["document_type_label"] == "电子发票"
        assert updated_payload["amount"] == "108元"
        assert updated_payload["edit_logs"]
        new_amount = updated_payload["amount"]
        latest_changes = updated_payload["edit_logs"][0]["changes"]
        assert any(change["after"] == new_amount for change in latest_changes)

        preview_response = client.get(f"/api/v1/receipt-folder/{receipt_id}/preview", headers=auth_headers)
        assert preview_response.status_code == 200
        assert preview_response.content == b"fake-image"

        delete_response = client.delete(receipt_url, headers=auth_headers)
        assert delete_response.status_code == 200
        assert delete_response.json()["receipt_id"] == receipt_id

        deleted_response = client.get(receipt_url, headers=auth_headers)
        assert deleted_response.status_code == 404
    finally:
        # Drop cached settings again once the test is done with them.
        get_settings.cache_clear()


def test_ocr_recognize_endpoint_returns_receipt_enriched_train_fields(
    monkeypatch,
    tmp_path,
) -> None:
    """Check the extra train-ticket fields present in the recognize response.

    The stub supplies only date, trip number, route and amount fields; the
    assertions expect the response to also carry ID number, carriage, seat
    and fare fields.
    """

    def stub_recognize_files(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        ticket_fields = [
            OcrRecognizeFieldRead(key="date", label="列车出发时间", value="2026-02-20 07:55"),
            OcrRecognizeFieldRead(key="trip_no", label="车次/航班", value="G458"),
            OcrRecognizeFieldRead(key="route", label="行程", value="武汉-上海"),
            OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
        ]
        ticket_document = OcrRecognizeDocumentRead(
            filename="2月20_武汉-上海.png",
            media_type="image/png",
            text=(
                ":26429165800002785705\n"
                "G458\n"
                "Wuhan\n"
                "Shanghaihongqiao\n"
                "2026 02 20 07:55\n"
                "06 01B\n"
                ": 354.00\n"
                "4201061987****1615\n"
                ":6580061086021391007342026\n"
                "12306 95306"
            ),
            summary="Wuhan Shanghaihongqiao G458 354.00",
            avg_score=0.92,
            line_count=0,
            page_count=1,
            document_type="train_ticket",
            document_type_label="火车/高铁票",
            scene_code="travel",
            scene_label="差旅票据",
            document_fields=ticket_fields,
        )
        return OcrRecognizeBatchRead(
            engine="paddleocr_mobile",
            model="PP-OCRv5_mobile",
            total_file_count=1,
            success_count=1,
            documents=[ticket_document],
        )

    _install_ocr_stub(monkeypatch, tmp_path, stub_recognize_files)

    try:
        client = build_client()
        response = client.post(
            "/api/v1/ocr/recognize",
            headers={"x-auth-username": "pytest", "x-auth-name": "Py Test"},
            files=[("files", ("2月20_武汉-上海.png", b"fake-image", "image/png"))],
        )
    finally:
        # Drop cached settings before asserting on the captured response.
        get_settings.cache_clear()

    assert response.status_code == 200
    document = response.json()["documents"][0]
    value_by_label = {
        entry["label"]: entry["value"]
        for entry in document["document_fields"]
    }

    assert document["receipt_id"]
    assert value_by_label["身份证号"] == "4201061987****1615"
    assert value_by_label["车厢"] == "06车"
    assert value_by_label["座位号"] == "01B"
    assert value_by_label["票价"] == "354.00元"