"""API tests for expense-claim item attachments (upload, preview, delete).

OCR and attachment storage are monkeypatched, so the tests touch only an
in-memory SQLite database and pytest's per-test ``tmp_path``.
"""

from __future__ import annotations

from collections.abc import Generator
from datetime import UTC, date, datetime
from decimal import Decimal
from pathlib import Path

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.models.role import Role
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead
from app.services.expense_claims import ExpenseClaimService
from app.services.ocr import OcrService

# Request headers sent on every API call. "emp-1" is the employee id seeded by
# ``seed_claim``; presumably the app's auth dependency maps this header to the
# current user (not visible here). Treat as read-only.
AUTH_HEADERS = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}


def build_session_factory() -> sessionmaker[Session]:
    """Create a fresh in-memory SQLite schema and return a session factory for it.

    ``StaticPool`` keeps one shared connection, so every session sees the same
    in-memory database; ``check_same_thread=False`` lets that connection be
    used from threads other than the one that created it.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)


def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build a TestClient whose ``get_db`` dependency uses a fresh in-memory DB.

    Returns:
        The client and the session factory, so tests can seed rows directly
        into the same database the API reads from.
    """
    session_factory = build_session_factory()
    app = create_app()

    def override_db() -> Generator[Session, None, None]:
        db = session_factory()
        try:
            yield db
        finally:
            db.close()

    app.dependency_overrides[get_db] = override_db
    return TestClient(app), session_factory


def seed_claim(db: Session) -> tuple[ExpenseClaim, ExpenseClaimItem]:
    """Insert a draft 88.00 CNY office-supplies claim with one item and commit.

    Also inserts the claimant (``emp-1``, with the ``user`` role) and their
    manager (``mgr-1``).

    Returns:
        The committed claim and its single item.
    """
    manager = Employee(
        id="mgr-1",
        employee_no="E20001",
        name="李总",
        email="manager@example.com",
        position="市场总监",
        grade="P7",
    )
    role = Role(
        id="role-user",
        role_code="user",
        name="员工",
        description="普通员工",
    )
    employee = Employee(
        id="emp-1",
        employee_no="E10001",
        name="张三",
        email="zhangsan@example.com",
        position="招商主管",
        grade="P4",
        manager=manager,
        roles=[role],
    )
    claim = ExpenseClaim(
        id="claim-attachment-1",
        claim_no="EXP-202605-101",
        employee_id=employee.id,
        employee_name="张三",
        department_id="dept-1",
        department_name="市场部",
        project_code=None,
        expense_type="office",
        reason="办公用品采购",
        location="深圳南山",
        amount=Decimal("88.00"),
        currency="CNY",
        invoice_count=0,
        occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
        submitted_at=None,
        status="draft",
        approval_stage="待提交",
        risk_flags_json=[],
    )
    item = ExpenseClaimItem(
        id="item-attachment-1",
        claim_id=claim.id,
        item_date=date(2026, 5, 13),
        item_type="office",
        item_reason="办公用品采购",
        item_location="深圳南山",
        item_amount=Decimal("88.00"),
        invoice_id=None,
    )
    claim.items = [item]
    db.add(manager)
    db.add(role)
    db.add(employee)
    db.add(claim)
    db.commit()
    return claim, item


def _item_url(claim_id: str, item_id: str) -> str:
    """Return the API path of a single claim item."""
    return f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}"


def _png_document(
    *,
    filename: str,
    text: str,
    summary: str,
    avg_score: float,
    line_count: int,
) -> OcrRecognizeDocumentRead:
    """Build a single-page PNG OCR result with no warnings."""
    return OcrRecognizeDocumentRead(
        filename=filename,
        media_type="image/png",
        text=text,
        summary=summary,
        avg_score=avg_score,
        line_count=line_count,
        page_count=1,
        warnings=[],
    )


def _office_invoice_document() -> OcrRecognizeDocumentRead:
    """Return an OCR result consistent with the seeded item (office, 88 CNY, 2026-05-13)."""
    return _png_document(
        filename="office-note.png",
        text="办公用品发票 金额88元 2026-05-13",
        summary="识别到办公用品发票,金额 88 元。",
        avg_score=0.98,
        line_count=1,
    )


def _prepare_claim(
    monkeypatch,
    tmp_path: Path,
    document: OcrRecognizeDocumentRead,
) -> tuple[TestClient, str, str]:
    """Stub OCR and attachment storage, build the client, and seed the claim.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
        tmp_path: directory used as the attachment storage root.
        document: the single OCR result the stubbed service returns.

    Returns:
        ``(client, claim_id, item_id)`` for the seeded claim.
    """

    def fake_recognize(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        # The uploaded file should reach OCR under its original filename.
        assert files[0][0] == document.filename
        return OcrRecognizeBatchRead(
            total_file_count=1,
            success_count=1,
            documents=[document],
        )

    monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
    # Keep stored attachment files inside pytest's per-test temporary directory.
    monkeypatch.setattr(
        ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path
    )

    client, session_factory = build_client()
    with session_factory() as db:
        claim, item = seed_claim(db)
        # Read ids while the session is open; the instances detach on close.
        claim_id = claim.id
        item_id = item.id
    return client, claim_id, item_id


def _upload_attachment(
    client: TestClient,
    claim_id: str,
    item_id: str,
    filename: str,
    content: bytes = b"fake-image-bytes",
):
    """POST ``content`` as a PNG attachment for the item and return the response."""
    return client.post(
        f"{_item_url(claim_id, item_id)}/attachment",
        headers=AUTH_HEADERS,
        files=[("file", (filename, content, "image/png"))],
    )


def test_claim_item_attachment_upload_preview_and_delete(monkeypatch, tmp_path) -> None:
    """Upload an attachment, read back its metadata and bytes, then delete it."""
    client, claim_id, item_id = _prepare_claim(
        monkeypatch, tmp_path, _office_invoice_document()
    )
    attachment_url = f"{_item_url(claim_id, item_id)}/attachment"
    file_bytes = b"fake-image-bytes"

    upload_response = _upload_attachment(
        client, claim_id, item_id, "office-note.png", file_bytes
    )
    assert upload_response.status_code == 200
    upload_payload = upload_response.json()
    assert upload_payload["attachment"]["file_name"] == "office-note.png"
    assert upload_payload["attachment"]["analysis"]["label"] == "AI提示符合条件"
    assert upload_payload["invoice_id"]

    meta_response = client.get(f"{attachment_url}/meta", headers=AUTH_HEADERS)
    assert meta_response.status_code == 200
    meta_payload = meta_response.json()
    assert meta_payload["media_type"] == "image/png"
    assert meta_payload["analysis"]["headline"]

    content_response = client.get(attachment_url, headers=AUTH_HEADERS)
    assert content_response.status_code == 200
    assert content_response.content == file_bytes

    delete_response = client.delete(attachment_url, headers=AUTH_HEADERS)
    assert delete_response.status_code == 200
    assert delete_response.json()["invoice_id"] is None

    deleted_meta_response = client.get(f"{attachment_url}/meta", headers=AUTH_HEADERS)
    assert deleted_meta_response.status_code == 404


def test_claim_item_attachment_upload_flags_purpose_and_amount_mismatch(
    monkeypatch, tmp_path
) -> None:
    """Flag a 120-yuan ride-hailing invoice uploaded against the 88.00 office item.

    The analysis must be high severity and mention both the amount field
    (金额字段) and the purpose field (用途字段).
    """
    taxi_document = _png_document(
        filename="taxi-note.png",
        text="滴滴出行电子发票 金额120元 2026-05-13",
        summary="识别到交通出行发票,金额 120 元。",
        avg_score=0.97,
        line_count=1,
    )
    client, claim_id, item_id = _prepare_claim(monkeypatch, tmp_path, taxi_document)

    upload_response = _upload_attachment(client, claim_id, item_id, "taxi-note.png")
    assert upload_response.status_code == 200
    analysis = upload_response.json()["attachment"]["analysis"]
    assert analysis["severity"] == "high"
    assert any("金额字段" in point for point in analysis["points"])
    assert any("用途字段" in point for point in analysis["points"])


def test_claim_item_attachment_upload_flags_non_invoice_image_as_high_risk(
    monkeypatch, tmp_path
) -> None:
    """Treat an image with no recognized text as high risk.

    The analysis must include a point about the attachment content (附件内容).
    """
    empty_document = _png_document(
        filename="random-image.png",
        text="",
        summary="",
        avg_score=0.0,
        line_count=0,
    )
    client, claim_id, item_id = _prepare_claim(monkeypatch, tmp_path, empty_document)

    upload_response = _upload_attachment(client, claim_id, item_id, "random-image.png")
    assert upload_response.status_code == 200
    analysis = upload_response.json()["attachment"]["analysis"]
    assert analysis["severity"] == "high"
    assert any("附件内容" in point for point in analysis["points"])


def test_claim_item_delete_removes_item_and_attachment(monkeypatch, tmp_path) -> None:
    """Delete a claim item; its stored attachment and claim line must disappear.

    The claim detail must still expose the employee profile fields afterwards.
    """
    client, claim_id, item_id = _prepare_claim(
        monkeypatch, tmp_path, _office_invoice_document()
    )
    item_url = _item_url(claim_id, item_id)
    # The upload is expected to create this directory under the storage root.
    attachment_dir = tmp_path / claim_id / item_id

    upload_response = _upload_attachment(client, claim_id, item_id, "office-note.png")
    assert upload_response.status_code == 200
    assert attachment_dir.exists()

    delete_response = client.delete(item_url, headers=AUTH_HEADERS)
    assert delete_response.status_code == 200
    assert delete_response.json()["item_id"] == item_id
    assert not attachment_dir.exists()

    detail_response = client.get(
        f"/api/v1/reimbursements/claims/{claim_id}",
        headers=AUTH_HEADERS,
    )
    assert detail_response.status_code == 200
    detail_payload = detail_response.json()
    assert detail_payload["items"] == []
    assert detail_payload["invoice_count"] == 0
    assert detail_payload["employee_position"] == "招商主管"
    assert detail_payload["employee_grade"] == "P4"
    assert detail_payload["manager_name"] == "李总"
    assert detail_payload["role_labels"] == ["员工"]

    deleted_meta_response = client.get(
        f"{item_url}/attachment/meta", headers=AUTH_HEADERS
    )
    assert deleted_meta_response.status_code == 404