"""Endpoint tests for the reimbursement attachment-association job API."""

from __future__ import annotations

from collections.abc import Generator
from datetime import UTC, date, datetime
from decimal import Decimal

from fastapi.testclient import TestClient
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from app.api.deps import CurrentUserContext, get_db
from app.api.v1.endpoints import attachment_association_jobs as attachment_jobs_endpoint
from app.core.config import get_settings
from app.main import create_app
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead
from app.services.attachment_association_jobs import clear_attachment_association_jobs_for_tests
from app.services.expense_claim_attachment_storage import ExpenseClaimAttachmentStorage
from app.services.ocr import OcrService
from app.services.receipt_folder import ReceiptFolderService
from app.test_helpers.db import build_in_memory_session_factory

# Base URL of the endpoints exercised by every test in this module.
_JOBS_URL = "/api/v1/reimbursements/attachment-association-jobs"


def _test_user_context() -> CurrentUserContext:
    """Return the user context under which receipts are saved in these tests."""
    return CurrentUserContext(
        username="zhangsan@example.com",
        name="张三",
        role_codes=["user"],
        is_admin=False,
        employee_no="E10001",
    )


def _test_user_headers() -> dict[str, str]:
    """Return the ``x-auth-*`` request headers sent with every API call."""
    return {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E10001",
        "x-auth-role-codes": "user",
    }


def _submit_job_and_fetch_status(client: TestClient, body: dict[str, object]) -> dict[str, object]:
    """Create an association job, then read back its status payload.

    Asserts that creation answers 202 and the status lookup answers 200,
    and returns the decoded JSON body of the status response.
    """
    auth = _test_user_headers()
    created = client.post(_JOBS_URL, headers=auth, json=body)
    assert created.status_code == 202
    job_id = created.json()["job_id"]

    polled = client.get(f"{_JOBS_URL}/{job_id}", headers=auth)
    assert polled.status_code == 200
    return polled.json()


def build_client(monkeypatch) -> tuple[TestClient, object]:
    """Create a test client whose DB access goes through an in-memory factory.

    The ``get_db`` dependency is overridden and the endpoint module's
    ``get_session_factory`` is patched so both use the same session factory.

    Returns:
        A ``(client, session_factory)`` pair.
    """
    session_factory = build_in_memory_session_factory()
    application = create_app()

    def override_db() -> Generator[Session, None, None]:
        # One session per dependency resolution, always closed afterwards.
        session = session_factory()
        try:
            yield session
        finally:
            session.close()

    application.dependency_overrides[get_db] = override_db
    monkeypatch.setattr(
        attachment_jobs_endpoint,
        "get_session_factory",
        lambda: session_factory,
    )
    return TestClient(application), session_factory


def seed_travel_claim(db: Session) -> ExpenseClaim:
    """Persist one employee and one draft travel claim holding a single item.

    The item has no ``invoice_id`` yet. The session is committed before the
    claim is returned.
    """
    traveller = Employee(
        id="emp-bg-association",
        employee_no="E10001",
        name="张三",
        email="zhangsan@example.com",
        position="实施顾问",
        grade="P4",
    )
    draft = ExpenseClaim(
        id="claim-bg-association",
        claim_no="BX-20260220-001",
        employee_id=traveller.id,
        employee_name=traveller.name,
        department_id="dept-delivery",
        department_name="交付部",
        project_code=None,
        expense_type="travel",
        reason="辅助国网仿生产服务器部署,武汉往返上海",
        location="上海",
        amount=Decimal("0.00"),
        currency="CNY",
        invoice_count=0,
        occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
        submitted_at=None,
        status="draft",
        approval_stage="待提交",
        risk_flags_json=[],
    )
    outbound_leg = ExpenseClaimItem(
        id="item-bg-association-1",
        claim_id=draft.id,
        item_date=date(2026, 2, 20),
        item_type="train_ticket",
        item_reason="武汉至上海高铁",
        item_location="上海",
        item_amount=Decimal("0.00"),
        invoice_id=None,
    )
    draft.items = [outbound_leg]

    db.add(traveller)
    db.add(draft)
    db.commit()
    return draft


def save_train_receipt(
    *,
    service: ReceiptFolderService,
    current_user: CurrentUserContext,
    filename: str,
    route: str,
    trip_date: str,
) -> str:
    """Save a fake train-ticket PDF receipt and return the new receipt's id.

    The OCR document attached to the receipt is built from ``route`` and
    ``trip_date``; the ticket amount is fixed at 354.
    """
    fields = [
        OcrRecognizeFieldRead(key="date", label="列车出发时间", value=trip_date),
        OcrRecognizeFieldRead(key="route", label="行程", value=route),
        OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
    ]
    recognized = OcrRecognizeDocumentRead(
        filename=filename,
        media_type="application/pdf",
        text=f"电子发票(铁路电子客票) {route} {trip_date} 票价 354 元",
        summary=f"铁路电子客票,{route},票价 354 元。",
        avg_score=0.96,
        line_count=1,
        page_count=1,
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        document_fields=fields,
    )
    saved = service.save_receipt(
        filename=filename,
        content=f"fake-pdf-{filename}".encode("utf-8"),
        media_type="application/pdf",
        current_user=current_user,
        document=recognized,
    )
    return saved.id


def fake_ocr_recognize(
    self,
    files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
    """Replacement for ``OcrService.recognize_files`` used by the tests.

    Only the first entry of ``files`` is inspected; the result is always a
    single recognized train-ticket document.
    """
    first_upload = files[0]
    recognized = OcrRecognizeDocumentRead(
        filename=first_upload[0],
        # Fall back to PDF when the upload carries no media type.
        media_type=first_upload[2] or "application/pdf",
        text="电子发票(铁路电子客票) 武汉 上海 2026-02-20 票价 354 元",
        summary="铁路电子客票,武汉至上海,票价 354 元。",
        avg_score=0.96,
        line_count=1,
        page_count=1,
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        document_fields=[
            OcrRecognizeFieldRead(key="date", label="列车出发时间", value="2026-02-20"),
            OcrRecognizeFieldRead(key="route", label="行程", value="武汉-上海"),
            OcrRecognizeFieldRead(key="amount", label="金额", value="354元"),
        ],
    )
    return OcrRecognizeBatchRead(
        total_file_count=1,
        success_count=1,
        documents=[recognized],
    )


def test_attachment_association_job_links_receipts_after_conversation_exit(monkeypatch, tmp_path) -> None:
    """Two saved receipts are linked to the seeded draft claim by one job."""
    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
    get_settings.cache_clear()
    clear_attachment_association_jobs_for_tests()
    monkeypatch.setattr(OcrService, "recognize_files", fake_ocr_recognize)
    monkeypatch.setattr(ExpenseClaimAttachmentStorage, "root", lambda self: tmp_path / "attachments")
    try:
        client, session_factory = build_client(monkeypatch)
        user = _test_user_context()
        with session_factory() as db:
            seed_travel_claim(db)

        folder = ReceiptFolderService()
        trips = (
            ("2月20 武汉-上海.pdf", "武汉-上海", "2026-02-20"),
            ("2月23 上海-武汉.pdf", "上海-武汉", "2026-02-23"),
        )
        receipt_ids = [
            save_train_receipt(
                service=folder,
                current_user=user,
                filename=pdf_name,
                route=leg,
                trip_date=day,
            )
            for pdf_name, leg, day in trips
        ]

        payload = _submit_job_and_fetch_status(
            client,
            {
                "receipt_ids": receipt_ids,
                "prompt": "请帮我处理已上传的附件。",
                "conversation_id": "inline-test",
            },
        )
        assert payload["status"] == "succeeded"
        assert payload["claim_id"] == "claim-bg-association"
        assert payload["claim_no"] == "BX-20260220-001"
        assert payload["uploaded_count"] == 2

        with session_factory() as db:
            query = (
                select(ExpenseClaim)
                .options(selectinload(ExpenseClaim.items))
                .where(ExpenseClaim.id == "claim-bg-association")
            )
            stored_claim = db.scalar(query)
            assert stored_claim is not None
            with_invoice = [entry for entry in stored_claim.items if entry.invoice_id]
            assert len(with_invoice) == 2

        linked = folder.list_receipts(current_user=user, status_filter="linked")
        assert {receipt.id for receipt in linked} == set(receipt_ids)
        assert {receipt.linked_claim_id for receipt in linked} == {"claim-bg-association"}
    finally:
        clear_attachment_association_jobs_for_tests()
        get_settings.cache_clear()


def test_attachment_association_job_fails_without_editable_claim(monkeypatch, tmp_path) -> None:
    """With no claim seeded, the job is reported as failed with a message."""
    monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
    get_settings.cache_clear()
    clear_attachment_association_jobs_for_tests()
    try:
        client, _session_factory = build_client(monkeypatch)
        user = _test_user_context()
        receipt_id = save_train_receipt(
            service=ReceiptFolderService(),
            current_user=user,
            filename="2月20 武汉-上海.pdf",
            route="武汉-上海",
            trip_date="2026-02-20",
        )

        payload = _submit_job_and_fetch_status(
            client,
            {"receipt_ids": [receipt_id], "conversation_id": "inline-empty"},
        )
        assert payload["status"] == "failed"
        assert "没有找到可自动关联的报销草稿" in payload["message"]
    finally:
        clear_attachment_association_jobs_for_tests()
        get_settings.cache_clear()