Files
X-Financial/server/tests/test_attachment_association_jobs.py

281 lines
10 KiB
Python
Raw Normal View History

from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, date, datetime
from decimal import Decimal
from fastapi.testclient import TestClient
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from app.api.deps import CurrentUserContext, get_db
from app.api.v1.endpoints import attachment_association_jobs as attachment_jobs_endpoint
from app.core.config import get_settings
from app.main import create_app
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead
from app.services.attachment_association_jobs import clear_attachment_association_jobs_for_tests
from app.services.expense_claim_attachment_storage import ExpenseClaimAttachmentStorage
from app.services.ocr import OcrService
from app.services.receipt_folder import ReceiptFolderService
from app.test_helpers.db import build_in_memory_session_factory
def build_client(monkeypatch) -> tuple[TestClient, object]:
    """Create a FastAPI test client backed by a fresh in-memory database.

    Two things are wired to the same in-memory session factory:
    the ``get_db`` dependency (one session per request, closed afterwards) and
    ``get_session_factory`` in the attachment-association-jobs endpoint module
    (patched via ``monkeypatch``, so it is undone after the test). The latter
    presumably lets work the endpoint starts outside the request scope see the
    same database — confirm against the endpoint implementation.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.

    Returns:
        ``(client, session_factory)``: the test client and the factory tests
        use to seed and inspect data.
    """
    factory = build_in_memory_session_factory()

    def _db_override() -> Generator[Session, None, None]:
        session = factory()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = _db_override
    monkeypatch.setattr(attachment_jobs_endpoint, "get_session_factory", lambda: factory)
    return TestClient(application), factory
def seed_travel_claim(db: Session) -> ExpenseClaim:
    """Persist one employee and one draft travel claim holding a single blank item.

    The claim (id ``claim-bg-association``, number ``BX-20260220-001``) is in
    ``draft`` status with a zero amount, zero invoices and no submission time.
    Its only item is a zero-amount ``train_ticket`` line with no invoice linked.

    Args:
        db: Session used to insert the rows; it is committed before returning.

    Returns:
        The persisted ``ExpenseClaim``.
    """
    zero = Decimal("0.00")

    employee = Employee(
        id="emp-bg-association",
        employee_no="E10001",
        name="张三",
        email="zhangsan@example.com",
        position="实施顾问",
        grade="P4",
    )

    claim = ExpenseClaim(
        id="claim-bg-association",
        claim_no="BX-20260220-001",
        employee_id=employee.id,
        employee_name=employee.name,
        department_id="dept-delivery",
        department_name="交付部",
        project_code=None,
        expense_type="travel",
        reason="辅助国网仿生产服务器部署,武汉往返上海",
        location="上海",
        amount=zero,
        currency="CNY",
        invoice_count=0,
        occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
        submitted_at=None,
        status="draft",
        approval_stage="待提交",
        risk_flags_json=[],
    )

    # The placeholder line item carries no amount and no invoice.
    claim.items = [
        ExpenseClaimItem(
            id="item-bg-association-1",
            claim_id=claim.id,
            item_date=date(2026, 2, 20),
            item_type="train_ticket",
            item_reason="武汉至上海高铁",
            item_location="上海",
            item_amount=zero,
            invoice_id=None,
        )
    ]

    db.add_all([employee, claim])
    db.commit()
    return claim
def save_train_receipt(
    *,
    service: ReceiptFolderService,
    current_user: CurrentUserContext,
    filename: str,
    route: str,
    trip_date: str,
) -> str:
    """Store a fake railway e-ticket PDF in the user's receipt folder.

    The file content is a placeholder byte string. The receipt is saved
    together with a hand-built OCR result classified as a ``train_ticket`` in
    the ``travel`` scene. Its date and route fields are taken from the
    arguments, and its amount is fixed at 354.

    Args:
        service: Receipt-folder service used to persist the receipt.
        current_user: Owner of the receipt.
        filename: Name under which the fake PDF is stored.
        route: Route string placed in the OCR text, summary and ``route`` field.
        trip_date: Date string placed in the OCR text and ``date`` field.

    Returns:
        The id of the saved receipt.
    """
    extracted_fields = [
        OcrRecognizeFieldRead(key="date", label="列车出发时间", value=trip_date),
        OcrRecognizeFieldRead(key="route", label="行程", value=route),
        OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
    ]
    ocr_document = OcrRecognizeDocumentRead(
        filename=filename,
        media_type="application/pdf",
        text=f"电子发票(铁路电子客票) {route} {trip_date} 票价 354 元",
        summary=f"铁路电子客票,{route},票价 354 元。",
        avg_score=0.96,
        line_count=1,
        page_count=1,
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        document_fields=extracted_fields,
    )
    fake_pdf_bytes = f"fake-pdf-{filename}".encode("utf-8")

    saved = service.save_receipt(
        filename=filename,
        content=fake_pdf_bytes,
        media_type="application/pdf",
        current_user=current_user,
        document=ocr_document,
    )
    return saved.id
def fake_ocr_recognize(
    self,
    files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
    """Stand-in for ``OcrService.recognize_files`` used via monkeypatch.

    Only the first ``(filename, content, media_type)`` tuple is looked at, and
    its bytes are never inspected. The result always describes a single
    Wuhan→Shanghai train ticket dated 2026-02-20 for 354 yuan, reported under
    the first file's name. The media type falls back to ``application/pdf``
    when the caller passed none.

    Args:
        self: Ignored; present so the function can replace a method.
        files: Uploaded files; must contain at least one entry.

    Returns:
        A one-document, all-success OCR batch.
    """
    first_file = files[0]
    source_name = first_file[0]
    source_media_type = first_file[2] or "application/pdf"

    ticket_fields = [
        OcrRecognizeFieldRead(key="date", label="列车出发时间", value="2026-02-20"),
        OcrRecognizeFieldRead(key="route", label="行程", value="武汉-上海"),
        OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
    ]
    ticket = OcrRecognizeDocumentRead(
        filename=source_name,
        media_type=source_media_type,
        text="电子发票(铁路电子客票) 武汉 上海 2026-02-20 票价 354 元",
        summary="铁路电子客票,武汉至上海,票价 354 元。",
        avg_score=0.96,
        line_count=1,
        page_count=1,
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        document_fields=ticket_fields,
    )
    return OcrRecognizeBatchRead(
        total_file_count=1,
        success_count=1,
        documents=[ticket],
    )
def test_attachment_association_job_links_receipts_after_conversation_exit(monkeypatch, tmp_path) -> None:
    """An association job attaches previously uploaded receipts to the user's draft claim.

    Seeds one draft travel claim and two saved train-ticket receipts, then submits
    an association job. Checks that:
    the job reports success against that claim with two uploads;
    two claim items end up with an invoice id;
    both receipts are listed as ``linked`` to the claim.
    """
    # Storage root comes from settings, which are cached, so the cache is
    # cleared after changing the env var (and again in ``finally``).
    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
    get_settings.cache_clear()
    clear_attachment_association_jobs_for_tests()
    monkeypatch.setattr(OcrService, "recognize_files", fake_ocr_recognize)
    monkeypatch.setattr(ExpenseClaimAttachmentStorage, "root", lambda self: tmp_path / "attachments")
    try:
        client, session_factory = build_client(monkeypatch)
        current_user = CurrentUserContext(
            username="zhangsan@example.com",
            name="张三",
            role_codes=["user"],
            is_admin=False,
            employee_no="E10001",
        )
        with session_factory() as db:
            seed_travel_claim(db)

        receipt_service = ReceiptFolderService()
        trips = (
            ("2月20 武汉-上海.pdf", "武汉-上海", "2026-02-20"),
            ("2月23 上海-武汉.pdf", "上海-武汉", "2026-02-23"),
        )
        receipt_ids = [
            save_train_receipt(
                service=receipt_service,
                current_user=current_user,
                filename=receipt_name,
                route=receipt_route,
                trip_date=receipt_date,
            )
            for receipt_name, receipt_route, receipt_date in trips
        ]

        jobs_url = "/api/v1/reimbursements/attachment-association-jobs"
        headers = {
            "x-auth-username": "zhangsan@example.com",
            "x-auth-name": "Zhang San",
            "x-auth-employee-no": "E10001",
            "x-auth-role-codes": "user",
        }

        create_response = client.post(
            jobs_url,
            headers=headers,
            json={
                "receipt_ids": receipt_ids,
                "prompt": "请帮我处理已上传的附件。",
                "conversation_id": "inline-test",
            },
        )
        assert create_response.status_code == 202
        job_id = create_response.json()["job_id"]

        # The job is polled exactly once. This assumes it has already finished,
        # presumably because TestClient runs background work before returning
        # (confirm).
        status_response = client.get(f"{jobs_url}/{job_id}", headers=headers)
        assert status_response.status_code == 200
        job = status_response.json()
        assert job["status"] == "succeeded"
        assert job["claim_id"] == "claim-bg-association"
        assert job["claim_no"] == "BX-20260220-001"
        assert job["uploaded_count"] == 2

        with session_factory() as db:
            stored_claim = db.scalar(
                select(ExpenseClaim)
                .options(selectinload(ExpenseClaim.items))
                .where(ExpenseClaim.id == "claim-bg-association")
            )
            assert stored_claim is not None
            items_with_invoice = [line for line in stored_claim.items if line.invoice_id]
            assert len(items_with_invoice) == 2

        linked = receipt_service.list_receipts(current_user=current_user, status_filter="linked")
        assert {receipt.id for receipt in linked} == set(receipt_ids)
        assert {receipt.linked_claim_id for receipt in linked} == {"claim-bg-association"}
    finally:
        clear_attachment_association_jobs_for_tests()
        get_settings.cache_clear()
def test_attachment_association_job_fails_without_editable_claim(monkeypatch, tmp_path) -> None:
    """An association job fails with a clear message when no draft claim exists.

    No claim is seeded. After a single receipt is submitted, the job status must
    be ``failed`` and the message must say no auto-associable draft was found.

    OCR (``OcrService.recognize_files``) and the attachment storage root are
    stubbed the same way as in the success-path test. This keeps the test
    hermetic even if the job touches OCR or attachment storage before it
    discovers that no draft claim exists.
    """
    # Storage root comes from settings, which are cached, so the cache is
    # cleared after changing the env var (and again in ``finally``).
    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
    get_settings.cache_clear()
    clear_attachment_association_jobs_for_tests()
    # Isolation: never call the real OCR backend or write attachments outside tmp_path.
    monkeypatch.setattr(OcrService, "recognize_files", fake_ocr_recognize)
    monkeypatch.setattr(ExpenseClaimAttachmentStorage, "root", lambda self: tmp_path / "attachments")
    try:
        client, _session_factory = build_client(monkeypatch)
        current_user = CurrentUserContext(
            username="zhangsan@example.com",
            name="张三",
            role_codes=["user"],
            is_admin=False,
            employee_no="E10001",
        )
        receipt_id = save_train_receipt(
            service=ReceiptFolderService(),
            current_user=current_user,
            filename="2月20 武汉-上海.pdf",
            route="武汉-上海",
            trip_date="2026-02-20",
        )

        jobs_url = "/api/v1/reimbursements/attachment-association-jobs"
        headers = {
            "x-auth-username": "zhangsan@example.com",
            "x-auth-name": "Zhang San",
            "x-auth-employee-no": "E10001",
            "x-auth-role-codes": "user",
        }

        response = client.post(
            jobs_url,
            headers=headers,
            json={"receipt_ids": [receipt_id], "conversation_id": "inline-empty"},
        )
        # The job is accepted up front; the missing claim is reported via its status.
        assert response.status_code == 202

        status_response = client.get(f"{jobs_url}/{response.json()['job_id']}", headers=headers)
        assert status_response.status_code == 200
        payload = status_response.json()
        assert payload["status"] == "failed"
        assert "没有找到可自动关联的报销草稿" in payload["message"]
    finally:
        clear_attachment_association_jobs_for_tests()
        get_settings.cache_clear()