diff --git a/server/tests/test_ocr_endpoints.py b/server/tests/test_ocr_endpoints.py
new file mode 100644
index 0000000..0a1b0b7
--- /dev/null
+++ b/server/tests/test_ocr_endpoints.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+from collections.abc import Generator
+
+from fastapi.testclient import TestClient
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.pool import StaticPool
+
+from app.api.deps import get_db
+from app.db.base import Base
+from app.main import create_app
+from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
+from app.services.ocr import OcrService
+
+
+def build_client() -> TestClient:  # TestClient over a fresh app backed by in-memory SQLite
+    engine = create_engine(
+        "sqlite+pysqlite:///:memory:",
+        connect_args={"check_same_thread": False},  # allow use outside the creating thread
+        poolclass=StaticPool,  # single shared connection, so every session sees the same DB
+    )
+    Base.metadata.create_all(bind=engine)
+    session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
+    app = create_app()
+
+    def override_db() -> Generator[Session, None, None]:  # yields a session, always closes it
+        db = session_factory()
+        try:
+            yield db
+        finally:
+            db.close()
+
+    app.dependency_overrides[get_db] = override_db  # replace get_db with the test sessions
+    return TestClient(app)
+
+
+def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None:
+    def fake_recognize(  # stand-in for OcrService.recognize_files; returns a canned batch
+        self,
+        files: list[tuple[str, bytes, str | None]],
+    ) -> OcrRecognizeBatchRead:
+        assert files[0][0] == "invoice.png"  # checks what the endpoint forwarded for the upload
+        return OcrRecognizeBatchRead(
+            engine="paddleocr_mobile",
+            model="PP-OCRv5_mobile",
+            total_file_count=1,
+            success_count=1,
+            documents=[
+                OcrRecognizeDocumentRead(
+                    filename="invoice.png",
+                    media_type="image/png",
+                    text="发票金额 100 元",
+                    summary="发票金额 100 元",
+                    avg_score=0.98,
+                    line_count=1,
+                    page_count=1,
+                    lines=[
+                        OcrRecognizeLineRead(
+                            text="发票金额 100 元",
+                            score=0.98,
+                            box=[[1, 2], [10, 2], [10, 8], [1, 8]],
+                            page_index=0,
+                        )
+                    ],
+                )
+            ],
+        )
+
+    monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)  # stub the service method
+    client = build_client()
+
+    response = client.post(
+        "/api/v1/ocr/recognize",
+        headers={"x-auth-username": "pytest", "x-auth-name": "Py Test"},
+        files=[("files", ("invoice.png", b"fake-image", "image/png"))],
+    )
+
+    assert response.status_code == 200
+    payload = response.json()
+    assert payload["engine"] == "paddleocr_mobile"
+    assert payload["success_count"] == 1
+    assert payload["documents"][0]["filename"] == "invoice.png"
+    assert payload["documents"][0]["summary"] == "发票金额 100 元"
diff --git a/server/tests/test_ocr_service.py b/server/tests/test_ocr_service.py
new file mode 100644
index 0000000..37cba1f
--- /dev/null
+++ b/server/tests/test_ocr_service.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+import stat
+from pathlib import Path
+
+from app.core.config import get_settings
+from app.services.ocr import OcrService
+
+
+def test_ocr_service_uses_worker_runtime_and_keeps_unsupported_files_as_warnings(
+    monkeypatch,
+    tmp_path: Path,
+) -> None:
+    fake_python = tmp_path / "fake-ocr-python.py"  # exported below as OCR_PYTHON_BIN
+    fake_python.write_text(  # prints "__OCR_JSON__=<json>": one canned document per --input arg
+        """#!/usr/bin/env python3
+import json
+import sys
+
+inputs = []
+for index, arg in enumerate(sys.argv):
+    if arg == "--input" and index + 1 < len(sys.argv):
+        input_path = sys.argv[index + 1]
+        inputs.append(
+            {
+                "input_path": input_path,
+                "engine": "paddleocr_mobile",
+                "model": "PP-OCRv5_mobile",
+                "text": "发票金额 100 元",
+                "summary": "发票金额 100 元",
+                "avg_score": 0.98,
+                "line_count": 1,
+                "page_count": 1,
+                "warnings": [],
+                "lines": [
+                    {
+                        "text": "发票金额 100 元",
+                        "score": 0.98,
+                        "box": [[1, 2], [10, 2], [10, 8], [1, 8]],
+                        "page_index": 0,
+                    }
+                ],
+            }
+        )
+
+payload = {
+    "engine": "paddleocr_mobile",
+    "model": "PP-OCRv5_mobile",
+    "documents": inputs,
+}
+print("__OCR_JSON__=" + json.dumps(payload, ensure_ascii=False))
+""",
+        encoding="utf-8",
+    )
+    fake_python.chmod(fake_python.stat().st_mode | stat.S_IEXEC)  # add the owner-execute bit
+
+    monkeypatch.setenv("OCR_PYTHON_BIN", str(fake_python))
+    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
+    get_settings.cache_clear()  # presumably forces settings to be rebuilt from the patched env
+    try:
+        result = OcrService().recognize_files(
+            [
+                ("invoice.png", b"fake-image", "image/png"),
+                ("notes.txt", b"plain-text", "text/plain"),
+            ]
+        )
+    finally:
+        get_settings.cache_clear()  # presumably keeps the patched settings out of other tests
+
+    assert result.engine == "paddleocr_mobile"
+    assert result.model == "PP-OCRv5_mobile"
+    assert result.total_file_count == 2
+    assert result.success_count == 1  # only one of the two uploads counts as recognized
+    assert len(result.documents) == 2
+
+    recognized = next(item for item in result.documents if item.filename == "invoice.png")
+    assert recognized.summary == "发票金额 100 元"
+    assert recognized.line_count == 1
+    assert recognized.lines[0].text == "发票金额 100 元"
+
+    skipped = next(item for item in result.documents if item.filename == "notes.txt")
+    assert skipped.line_count == 0  # the text file yields no OCR lines, only a warning
+    assert skipped.warnings == ["当前仅支持图片和 PDF 文件进行 OCR。"]
diff --git a/server/tests/test_openapi_schema.py b/server/tests/test_openapi_schema.py
index e990cf9..8a7e002 100644
--- a/server/tests/test_openapi_schema.py
+++ b/server/tests/test_openapi_schema.py
@@ -10,6 +10,7 @@ def test_openapi_schema_includes_documented_backend_routes() -> None:
     assert schema["info"]["title"] == get_settings().app_name
     assert any(tag["name"] == "agent-assets" for tag in schema["tags"])
     assert any(tag["name"] == "knowledge" for tag in schema["tags"])
+    assert any(tag["name"] == "ocr" for tag in schema["tags"])
     assert any(tag["name"] == "ontology" for tag in schema["tags"])
     assert any(tag["name"] == "orchestrator" for tag in schema["tags"])
 
@@ -27,6 +28,10 @@ def test_openapi_schema_includes_documented_backend_routes() -> None:
     assert knowledge_callback_post["summary"] == "接收 ONLYOFFICE 回调"
     assert "application/json" in knowledge_callback_post["requestBody"]["content"]
 
+    ocr_post = schema["paths"]["/api/v1/ocr/recognize"]["post"]
+    assert ocr_post["summary"] == "识别票据或图片 OCR"
+    assert "multipart/form-data" in ocr_post["requestBody"]["content"]  # file-upload endpoint
+
     ontology_parse_post = schema["paths"]["/api/v1/ontology/parse"]["post"]
     assert ontology_parse_post["summary"] == "解析自然语言为语义本体"
     assert "application/json" in ontology_parse_post["requestBody"]["content"]
diff --git a/server/tests/test_orchestrator_service.py b/server/tests/test_orchestrator_service.py
index 14d440f..af227ec 100644
--- a/server/tests/test_orchestrator_service.py
+++ b/server/tests/test_orchestrator_service.py
@@ -3,13 +3,14 @@ from __future__ import annotations
 from collections.abc import Generator
 
 from fastapi.testclient import TestClient
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, select
 from sqlalchemy.orm import Session, sessionmaker
 from sqlalchemy.pool import StaticPool
 
 from app.api.deps import get_db
 from app.db.base import Base
 from app.main import create_app
+from app.models.financial_record import ExpenseClaim
 from app.services.agent_assets import AgentAssetService
 
 
@@ -142,7 +143,7 @@ def test_orchestrator_approval_required_returns_confirmation_result() -> None:
 
 
 def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
-    client, _ = build_client()
+    client, session_factory = build_client()
 
     response = client.post(
         "/api/v1/orchestrator/run",
@@ -159,8 +160,22 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
     assert payload["selected_agent"] == "user_agent"
     assert payload["status"] == "succeeded"
     assert payload["result"]["draft_payload"]["confirmation_required"] is True
+    assert payload["result"]["draft_payload"]["claim_id"]
+    assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-")
+    assert payload["result"]["draft_payload"]["status"] == "draft"
     assert payload["result"]["suggested_actions"]
 
+    with session_factory() as db:  # confirm the draft claim was persisted, not just echoed back
+        claim = db.scalar(
+            select(ExpenseClaim).where(
+                ExpenseClaim.id == payload["result"]["draft_payload"]["claim_id"]
+            )
+        )
+        assert claim is not None
+        assert claim.claim_no == payload["result"]["draft_payload"]["claim_no"]
+        assert claim.status == "draft"
+        assert claim.items  # must be non-empty
+
 
 def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
     client, _ = build_client()