Files
X-Financial/server/tests/test_reimbursement_endpoints.py

419 lines
15 KiB
Python

from __future__ import annotations
import base64
from collections.abc import Generator
from datetime import UTC, date, datetime
from decimal import Decimal
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.models.role import Role
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead
from app.services.expense_claims import ExpenseClaimService
from app.services.ocr import OcrService
def build_session_factory() -> sessionmaker[Session]:
    """Create a session factory bound to a fresh in-memory SQLite database.

    The engine is configured with ``StaticPool`` and ``check_same_thread``
    disabled, and every table registered on ``Base.metadata`` is created
    before the factory is handed back.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)

    factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return factory
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build a ``TestClient`` whose ``get_db`` dependency is overridden for tests.

    Returns:
        The client and the session factory backing the override, so callers
        can seed data through the same factory the application draws from.
    """
    session_factory = build_session_factory()
    application = create_app()

    def _provide_session() -> Generator[Session, None, None]:
        # Leaving the ``with`` block closes the session once the consumer is done.
        with session_factory() as session:
            yield session

    application.dependency_overrides[get_db] = _provide_session
    return TestClient(application), session_factory
def seed_claim(db: Session) -> tuple[ExpenseClaim, ExpenseClaimItem]:
    """Insert a manager, a role, an employee and one draft claim with a single item.

    Args:
        db: Open session the fixtures are added to and committed through.

    Returns:
        The persisted ``(claim, item)`` pair.
    """
    # Values shared by the claim header and its only line item.
    shared_reason = "办公用品采购"
    shared_location = "深圳南山"
    shared_amount = Decimal("88.00")

    user_role = Role(
        id="role-user",
        role_code="user",
        name="员工",
        description="普通员工",
    )
    line_manager = Employee(
        id="mgr-1",
        employee_no="E20001",
        name="李总",
        email="manager@example.com",
        position="市场总监",
        grade="P7",
    )
    claimant = Employee(
        id="emp-1",
        employee_no="E10001",
        name="张三",
        email="zhangsan@example.com",
        position="招商主管",
        grade="P4",
        manager=line_manager,
        roles=[user_role],
    )

    draft_claim = ExpenseClaim(
        id="claim-attachment-1",
        claim_no="EXP-202605-101",
        employee_id=claimant.id,
        employee_name="张三",
        department_id="dept-1",
        department_name="市场部",
        project_code=None,
        expense_type="office",
        reason=shared_reason,
        location=shared_location,
        amount=shared_amount,
        currency="CNY",
        invoice_count=0,
        occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
        submitted_at=None,
        status="draft",
        approval_stage="待提交",
        risk_flags_json=[],
    )
    office_item = ExpenseClaimItem(
        id="item-attachment-1",
        claim_id=draft_claim.id,
        item_date=date(2026, 5, 13),
        item_type="office",
        item_reason=shared_reason,
        item_location=shared_location,
        item_amount=shared_amount,
        invoice_id=None,
    )
    draft_claim.items = [office_item]

    # Same insertion order as before: manager, role, employee, claim.
    for entity in (line_manager, user_role, claimant, draft_claim):
        db.add(entity)
    db.commit()

    return draft_claim, office_item
def test_claim_item_attachment_upload_preview_and_delete(monkeypatch, tmp_path) -> None:
    """Walk one item attachment through upload, meta, preview, download and delete.

    OCR is replaced by a canned single-document result and the attachment
    storage root is redirected to ``tmp_path``. After deletion the meta
    endpoint is expected to answer 404.
    """

    def stub_recognize(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        # The first file handed to OCR must carry the uploaded filename.
        assert files[0][0] == "office-note.png"
        recognized = OcrRecognizeDocumentRead(
            filename="office-note.png",
            media_type="image/png",
            text="办公用品发票 金额88元 2026-05-13",
            summary="识别到办公用品发票,金额 88 元。",
            avg_score=0.98,
            line_count=1,
            page_count=1,
            warnings=[],
        )
        return OcrRecognizeBatchRead(
            total_file_count=1,
            success_count=1,
            documents=[recognized],
        )

    monkeypatch.setattr(OcrService, "recognize_files", stub_recognize)
    monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)

    client, session_factory = build_client()
    with session_factory() as db:
        claim, item = seed_claim(db)
        claim_id, item_id = claim.id, item.id

    auth_headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
    file_bytes = b"fake-image-bytes"
    attachment_url = f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment"
    meta_url = f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/meta"
    preview_url = f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview"

    # Upload.
    upload_response = client.post(
        attachment_url,
        headers=auth_headers,
        files=[("file", ("office-note.png", file_bytes, "image/png"))],
    )
    assert upload_response.status_code == 200
    upload_payload = upload_response.json()
    attachment = upload_payload["attachment"]
    assert attachment["file_name"] == "office-note.png"
    assert attachment["analysis"]["label"] == "AI提示符合条件"
    assert attachment["document_info"]["document_type"] == "office_invoice"
    assert attachment["requirement_check"]["matches"] is True
    assert upload_payload["invoice_id"]

    # Metadata.
    meta_response = client.get(meta_url, headers=auth_headers)
    assert meta_response.status_code == 200
    meta_payload = meta_response.json()
    assert meta_payload["media_type"] == "image/png"
    assert meta_payload["preview_kind"] == "image"
    assert meta_payload["preview_url"].endswith(f"/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview")
    assert meta_payload["analysis"]["headline"]
    assert meta_payload["document_info"]["fields"][0]["label"] == "金额"

    # Preview and raw content both return the uploaded bytes.
    preview_response = client.get(preview_url, headers=auth_headers)
    assert preview_response.status_code == 200
    assert preview_response.content == file_bytes

    content_response = client.get(attachment_url, headers=auth_headers)
    assert content_response.status_code == 200
    assert content_response.content == file_bytes

    # Delete, then confirm the metadata is gone.
    delete_response = client.delete(attachment_url, headers=auth_headers)
    assert delete_response.status_code == 200
    assert delete_response.json()["invoice_id"] is None

    deleted_meta_response = client.get(meta_url, headers=auth_headers)
    assert deleted_meta_response.status_code == 404
def test_claim_item_attachment_upload_flags_purpose_and_amount_mismatch(monkeypatch, tmp_path) -> None:
    """Upload an attachment whose stubbed OCR text disagrees with the seeded claim.

    The OCR stub reports a ride-hailing invoice for 120 yuan while the seeded
    item is an ``office`` expense of 88.00. The test expects a ``high``
    severity analysis, points mentioning the amount field and the attachment
    type requirement, and a failed requirement check.
    """

    def stub_recognize(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        recognized = OcrRecognizeDocumentRead(
            filename="taxi-note.png",
            media_type="image/png",
            text="滴滴出行电子发票 金额120元 2026-05-13",
            summary="识别到交通出行发票,金额 120 元。",
            avg_score=0.97,
            line_count=1,
            page_count=1,
            warnings=[],
        )
        return OcrRecognizeBatchRead(
            total_file_count=1,
            success_count=1,
            documents=[recognized],
        )

    monkeypatch.setattr(OcrService, "recognize_files", stub_recognize)
    monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)

    client, session_factory = build_client()
    with session_factory() as db:
        claim, item = seed_claim(db)
        claim_id, item_id = claim.id, item.id

    auth_headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
    upload_response = client.post(
        f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
        headers=auth_headers,
        files=[("file", ("taxi-note.png", b"fake-image-bytes", "image/png"))],
    )
    assert upload_response.status_code == 200

    attachment = upload_response.json()["attachment"]
    analysis = attachment["analysis"]
    points = analysis["points"]
    assert analysis["severity"] == "high"
    assert any("金额字段" in point for point in points)
    assert any("附件类型要求" in point for point in points)
    assert attachment["requirement_check"]["matches"] is False
def test_claim_item_attachment_upload_flags_non_invoice_image_as_high_risk(monkeypatch, tmp_path) -> None:
    """Upload an image for which the stubbed OCR returns no text at all.

    The canned document has empty ``text`` and ``summary``, a zero score and
    zero lines. The test expects a ``high`` severity analysis with at least
    one point mentioning the attachment content.
    """

    def stub_recognize(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        blank_document = OcrRecognizeDocumentRead(
            filename="random-image.png",
            media_type="image/png",
            text="",
            summary="",
            avg_score=0.0,
            line_count=0,
            page_count=1,
            warnings=[],
        )
        return OcrRecognizeBatchRead(
            total_file_count=1,
            success_count=1,
            documents=[blank_document],
        )

    monkeypatch.setattr(OcrService, "recognize_files", stub_recognize)
    monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)

    client, session_factory = build_client()
    with session_factory() as db:
        claim, item = seed_claim(db)
        claim_id, item_id = claim.id, item.id

    auth_headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
    upload_response = client.post(
        f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
        headers=auth_headers,
        files=[("file", ("random-image.png", b"fake-image-bytes", "image/png"))],
    )
    assert upload_response.status_code == 200

    analysis = upload_response.json()["attachment"]["analysis"]
    assert analysis["severity"] == "high"
    assert any("附件内容" in point for point in analysis["points"])
def test_claim_item_pdf_attachment_preview_returns_generated_image(monkeypatch, tmp_path) -> None:
    """Upload a PDF and check that the preview endpoint serves the OCR-supplied image.

    The OCR stub returns a document whose ``preview_data_url`` is a base64
    PNG data URL. The preview response is expected to have an ``image/png``
    content type and a body equal to the decoded preview bytes.
    """
    preview_bytes = b"fake-preview-png"
    preview_data_url = f"data:image/png;base64,{base64.b64encode(preview_bytes).decode('ascii')}"

    def stub_recognize(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        pdf_document = OcrRecognizeDocumentRead(
            filename="invoice.pdf",
            media_type="application/pdf",
            text="滴滴出行电子发票 金额13.4元",
            summary="识别到交通票据,金额 13.4 元。",
            avg_score=0.96,
            line_count=1,
            page_count=1,
            document_type="taxi_receipt",
            document_type_label="出租车/网约车票据",
            scene_code="transport",
            scene_label="交通票据",
            preview_kind="image",
            preview_data_url=preview_data_url,
            warnings=[],
        )
        return OcrRecognizeBatchRead(
            total_file_count=1,
            success_count=1,
            documents=[pdf_document],
        )

    monkeypatch.setattr(OcrService, "recognize_files", stub_recognize)
    monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)

    client, session_factory = build_client()
    with session_factory() as db:
        claim, item = seed_claim(db)
        claim_id, item_id = claim.id, item.id

    auth_headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
    upload_response = client.post(
        f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
        headers=auth_headers,
        files=[("file", ("invoice.pdf", b"%PDF-1.4 fake", "application/pdf"))],
    )
    assert upload_response.status_code == 200

    meta_payload = upload_response.json()["attachment"]
    assert meta_payload["preview_kind"] == "image"
    assert meta_payload["preview_url"].endswith(f"/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview")

    preview_response = client.get(
        f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/preview",
        headers=auth_headers,
    )
    assert preview_response.status_code == 200
    assert preview_response.headers["content-type"].startswith("image/png")
    assert preview_response.content == preview_bytes
def test_claim_item_delete_removes_item_and_attachment(monkeypatch, tmp_path) -> None:
    """Delete a claim item that has an attachment and check everything is cleaned up.

    After the upload a ``<claim_id>/<item_id>`` path exists under ``tmp_path``.
    After the item is deleted that path is gone, the claim detail lists no
    items and a zero invoice count, and the attachment meta endpoint answers
    404. The detail payload is also checked for the seeded employee's
    position, grade, manager name and role labels.
    """

    def stub_recognize(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        recognized = OcrRecognizeDocumentRead(
            filename="office-note.png",
            media_type="image/png",
            text="办公用品发票 金额88元 2026-05-13",
            summary="识别到办公用品发票,金额 88 元。",
            avg_score=0.98,
            line_count=1,
            page_count=1,
            warnings=[],
        )
        return OcrRecognizeBatchRead(
            total_file_count=1,
            success_count=1,
            documents=[recognized],
        )

    monkeypatch.setattr(OcrService, "recognize_files", stub_recognize)
    monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)

    client, session_factory = build_client()
    with session_factory() as db:
        claim, item = seed_claim(db)
        claim_id, item_id = claim.id, item.id

    auth_headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
    stored_path = tmp_path / claim_id / item_id

    upload_response = client.post(
        f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
        headers=auth_headers,
        files=[("file", ("office-note.png", b"fake-image-bytes", "image/png"))],
    )
    assert upload_response.status_code == 200
    assert stored_path.exists()

    delete_response = client.delete(
        f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}",
        headers=auth_headers,
    )
    assert delete_response.status_code == 200
    assert delete_response.json()["item_id"] == item_id
    assert not stored_path.exists()

    detail_response = client.get(
        f"/api/v1/reimbursements/claims/{claim_id}",
        headers=auth_headers,
    )
    assert detail_response.status_code == 200
    detail_payload = detail_response.json()
    assert detail_payload["items"] == []
    assert detail_payload["invoice_count"] == 0
    assert detail_payload["employee_position"] == "招商主管"
    assert detail_payload["employee_grade"] == "P4"
    assert detail_payload["manager_name"] == "李总"
    assert detail_payload["role_labels"] == ["员工"]

    deleted_meta_response = client.get(
        f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/meta",
        headers=auth_headers,
    )
    assert deleted_meta_response.status_code == 404