test(backend): update and add service tests

Updated tests:
- test_ontology_service.py: update ontology service tests

New tests:
- test_expense_claim_service.py: add expense claim service tests
- test_reimbursement_endpoints.py: add reimbursement endpoint tests
This commit is contained in:
caoxiaozhu
2026-05-13 06:46:24 +00:00
parent 6317fc0ccd
commit 13df8fc9dc
3 changed files with 608 additions and 0 deletions

View File

@@ -0,0 +1,280 @@
from __future__ import annotations
from datetime import UTC, date, datetime
from decimal import Decimal
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import CurrentUserContext
from app.db.base import Base
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead
from app.schemas.reimbursement import ExpenseClaimItemCreate, ExpenseClaimItemUpdate
from app.services.expense_claims import ExpenseClaimService
from app.services.ocr import OcrService
def build_claim(*, expense_type: str, location: str) -> ExpenseClaim:
claim = ExpenseClaim(
id="claim-1",
claim_no="EXP-202605-001",
employee_id="emp-1",
employee_name="张三",
department_id="dept-1",
department_name="市场部",
project_code=None,
expense_type=expense_type,
reason="费用报销",
location=location,
amount=Decimal("88.00"),
currency="CNY",
invoice_count=1,
occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
submitted_at=None,
status="draft",
approval_stage="待提交",
risk_flags_json=[],
)
claim.items = [
ExpenseClaimItem(
id="item-1",
claim_id="claim-1",
item_date=date(2026, 5, 13),
item_type=expense_type,
item_reason="费用报销",
item_location=location,
item_amount=Decimal("88.00"),
invoice_id="invoice-1",
)
]
return claim
def build_session() -> Session:
engine = create_engine(
"sqlite+pysqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
return session_factory()
def test_validate_claim_for_submission_allows_office_claim_without_location() -> None:
service = ExpenseClaimService.__new__(ExpenseClaimService)
claim = build_claim(expense_type="office", location="待补充")
issues = service._validate_claim_for_submission(claim)
assert "业务地点未完善" not in issues
assert not any("缺少地点" in item for item in issues)
def test_validate_claim_for_submission_still_requires_location_for_travel_claim() -> None:
service = ExpenseClaimService.__new__(ExpenseClaimService)
claim = build_claim(expense_type="travel", location="待补充")
issues = service._validate_claim_for_submission(claim)
assert "业务地点未完善" in issues
assert any("缺少地点" in item for item in issues)
def test_resolve_expense_type_maps_office_supplies_review_value_to_office() -> None:
expense_type = ExpenseClaimService._resolve_expense_type(
[],
context_json={
"review_form_values": {
"expense_type": "办公用品"
}
},
)
assert expense_type == "office"
def test_create_claim_item_adds_blank_draft_row_without_forcing_attachment() -> None:
current_user = CurrentUserContext(
username="emp-1",
name="张三",
role_codes=[],
is_admin=False,
)
with build_session() as db:
claim = build_claim(expense_type="office", location="深圳南山")
db.add(claim)
db.commit()
service = ExpenseClaimService(db)
updated = service.create_claim_item(
claim_id=claim.id,
payload=ExpenseClaimItemCreate(),
current_user=current_user,
)
assert updated is not None
assert len(updated.items) == 2
assert updated.amount == Decimal("88.00")
assert updated.invoice_count == 1
new_item = next(item for item in updated.items if item.id != "item-1")
assert new_item.item_type == "office"
assert new_item.item_reason == ""
assert new_item.item_location == ""
assert new_item.item_amount == Decimal("0.00")
assert new_item.invoice_id is None
def test_update_claim_item_reanalyzes_existing_attachment(monkeypatch, tmp_path) -> None:
current_user = CurrentUserContext(
username="emp-1",
name="张三",
role_codes=[],
is_admin=False,
)
def fake_recognize(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
return OcrRecognizeBatchRead(
total_file_count=1,
success_count=1,
documents=[
OcrRecognizeDocumentRead(
filename="office-note.png",
media_type="image/png",
text="办公用品发票 金额88元 2026-05-13",
summary="识别到办公用品发票,金额 88 元。",
avg_score=0.98,
line_count=1,
page_count=1,
warnings=[],
)
],
)
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)
with build_session() as db:
claim = build_claim(expense_type="office", location="深圳南山")
claim.invoice_count = 0
claim.items[0].invoice_id = None
claim.items[0].item_reason = "办公用品采购"
db.add(claim)
db.commit()
service = ExpenseClaimService(db)
service.upload_claim_item_attachment(
claim_id=claim.id,
item_id=claim.items[0].id,
filename="office-note.png",
content=b"fake-image-bytes",
media_type="image/png",
current_user=current_user,
)
uploaded_meta = service.get_claim_item_attachment_meta(
claim_id=claim.id,
item_id=claim.items[0].id,
current_user=current_user,
)
assert uploaded_meta is not None
assert uploaded_meta["analysis"]["severity"] == "pass"
updated = service.update_claim_item(
claim_id=claim.id,
item_id=claim.items[0].id,
payload=ExpenseClaimItemUpdate(
item_type="transport",
item_reason="打车报销",
),
current_user=current_user,
)
assert updated is not None
assert any(flag.get("source") == "attachment_analysis" for flag in updated.risk_flags_json)
refreshed_meta = service.get_claim_item_attachment_meta(
claim_id=claim.id,
item_id=claim.items[0].id,
current_user=current_user,
)
assert refreshed_meta is not None
assert refreshed_meta["analysis"]["severity"] == "medium"
assert any("用途字段" in point for point in refreshed_meta["analysis"]["points"])
def test_delete_claim_item_removes_row_and_attachment_files(monkeypatch, tmp_path) -> None:
current_user = CurrentUserContext(
username="emp-1",
name="张三",
role_codes=[],
is_admin=False,
)
def fake_recognize(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
return OcrRecognizeBatchRead(
total_file_count=1,
success_count=1,
documents=[
OcrRecognizeDocumentRead(
filename="office-note.png",
media_type="image/png",
text="办公用品发票 金额88元 2026-05-13",
summary="识别到办公用品发票,金额 88 元。",
avg_score=0.98,
line_count=1,
page_count=1,
warnings=[],
)
],
)
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)
with build_session() as db:
claim = build_claim(expense_type="office", location="深圳南山")
claim.invoice_count = 0
claim.items[0].invoice_id = None
claim.items[0].item_reason = "办公用品采购"
db.add(claim)
db.commit()
service = ExpenseClaimService(db)
upload_payload = service.upload_claim_item_attachment(
claim_id=claim.id,
item_id=claim.items[0].id,
filename="office-note.png",
content=b"fake-image-bytes",
media_type="image/png",
current_user=current_user,
)
assert upload_payload is not None
attachment_root = tmp_path / claim.id / claim.items[0].id
assert attachment_root.exists()
delete_payload = service.delete_claim_item(
claim_id=claim.id,
item_id=claim.items[0].id,
current_user=current_user,
)
assert delete_payload is not None
assert delete_payload["claim_id"] == claim.id
refreshed_claim = service.get_claim(claim.id, current_user)
assert refreshed_claim is not None
assert refreshed_claim.items == []
assert refreshed_claim.amount == Decimal("0.00")
assert refreshed_claim.invoice_count == 0
assert not attachment_root.exists()

View File

@@ -333,6 +333,24 @@ def test_semantic_ontology_service_extracts_day_before_yesterday_from_client_loc
assert result.time_range.end_date == "2026-05-11"
def test_semantic_ontology_service_maps_office_supplies_to_office_expense_type() -> None:
session_factory = build_session_factory()
with session_factory() as db:
result = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我买了办公用品和文具花了88元帮我报销",
user_id="pytest",
)
)
assert result.scenario == "expense"
assert result.intent == "draft"
assert any(
item.type == "expense_type" and item.normalized_value == "office"
for item in result.entities
)
def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None:
session_factory = build_session_factory()
with session_factory() as db:

View File

@@ -0,0 +1,310 @@
from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, date, datetime
from decimal import Decimal
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead
from app.services.expense_claims import ExpenseClaimService
from app.services.ocr import OcrService
def build_session_factory() -> sessionmaker[Session]:
engine = create_engine(
"sqlite+pysqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
return sessionmaker(bind=engine, autoflush=False, autocommit=False)
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
session_factory = build_session_factory()
app = create_app()
def override_db() -> Generator[Session, None, None]:
db = session_factory()
try:
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_db
return TestClient(app), session_factory
def seed_claim(db: Session) -> tuple[ExpenseClaim, ExpenseClaimItem]:
claim = ExpenseClaim(
id="claim-attachment-1",
claim_no="EXP-202605-101",
employee_id="emp-1",
employee_name="张三",
department_id="dept-1",
department_name="市场部",
project_code=None,
expense_type="office",
reason="办公用品采购",
location="深圳南山",
amount=Decimal("88.00"),
currency="CNY",
invoice_count=0,
occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
submitted_at=None,
status="draft",
approval_stage="待提交",
risk_flags_json=[],
)
item = ExpenseClaimItem(
id="item-attachment-1",
claim_id=claim.id,
item_date=date(2026, 5, 13),
item_type="office",
item_reason="办公用品采购",
item_location="深圳南山",
item_amount=Decimal("88.00"),
invoice_id=None,
)
claim.items = [item]
db.add(claim)
db.commit()
return claim, item
def test_claim_item_attachment_upload_preview_and_delete(monkeypatch, tmp_path) -> None:
def fake_recognize(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
assert files[0][0] == "office-note.png"
return OcrRecognizeBatchRead(
total_file_count=1,
success_count=1,
documents=[
OcrRecognizeDocumentRead(
filename="office-note.png",
media_type="image/png",
text="办公用品发票 金额88元 2026-05-13",
summary="识别到办公用品发票,金额 88 元。",
avg_score=0.98,
line_count=1,
page_count=1,
warnings=[],
)
],
)
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)
client, session_factory = build_client()
with session_factory() as db:
claim, item = seed_claim(db)
claim_id = claim.id
item_id = item.id
headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
file_bytes = b"fake-image-bytes"
upload_response = client.post(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
files=[("file", ("office-note.png", file_bytes, "image/png"))],
)
assert upload_response.status_code == 200
upload_payload = upload_response.json()
assert upload_payload["attachment"]["file_name"] == "office-note.png"
assert upload_payload["attachment"]["analysis"]["label"] == "AI提示符合条件"
assert upload_payload["invoice_id"]
meta_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/meta",
headers=headers,
)
assert meta_response.status_code == 200
meta_payload = meta_response.json()
assert meta_payload["media_type"] == "image/png"
assert meta_payload["analysis"]["headline"]
content_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
)
assert content_response.status_code == 200
assert content_response.content == file_bytes
delete_response = client.delete(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
)
assert delete_response.status_code == 200
assert delete_response.json()["invoice_id"] is None
deleted_meta_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/meta",
headers=headers,
)
assert deleted_meta_response.status_code == 404
def test_claim_item_attachment_upload_flags_purpose_and_amount_mismatch(monkeypatch, tmp_path) -> None:
def fake_recognize(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
return OcrRecognizeBatchRead(
total_file_count=1,
success_count=1,
documents=[
OcrRecognizeDocumentRead(
filename="taxi-note.png",
media_type="image/png",
text="滴滴出行电子发票 金额120元 2026-05-13",
summary="识别到交通出行发票,金额 120 元。",
avg_score=0.97,
line_count=1,
page_count=1,
warnings=[],
)
],
)
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)
client, session_factory = build_client()
with session_factory() as db:
claim, item = seed_claim(db)
claim_id = claim.id
item_id = item.id
headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
upload_response = client.post(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
files=[("file", ("taxi-note.png", b"fake-image-bytes", "image/png"))],
)
assert upload_response.status_code == 200
analysis = upload_response.json()["attachment"]["analysis"]
assert analysis["severity"] == "high"
assert any("金额字段" in point for point in analysis["points"])
assert any("用途字段" in point for point in analysis["points"])
def test_claim_item_attachment_upload_flags_non_invoice_image_as_high_risk(monkeypatch, tmp_path) -> None:
def fake_recognize(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
return OcrRecognizeBatchRead(
total_file_count=1,
success_count=1,
documents=[
OcrRecognizeDocumentRead(
filename="random-image.png",
media_type="image/png",
text="",
summary="",
avg_score=0.0,
line_count=0,
page_count=1,
warnings=[],
)
],
)
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)
client, session_factory = build_client()
with session_factory() as db:
claim, item = seed_claim(db)
claim_id = claim.id
item_id = item.id
headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
upload_response = client.post(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
files=[("file", ("random-image.png", b"fake-image-bytes", "image/png"))],
)
assert upload_response.status_code == 200
analysis = upload_response.json()["attachment"]["analysis"]
assert analysis["severity"] == "high"
assert any("附件内容" in point for point in analysis["points"])
def test_claim_item_delete_removes_item_and_attachment(monkeypatch, tmp_path) -> None:
def fake_recognize(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
return OcrRecognizeBatchRead(
total_file_count=1,
success_count=1,
documents=[
OcrRecognizeDocumentRead(
filename="office-note.png",
media_type="image/png",
text="办公用品发票 金额88元 2026-05-13",
summary="识别到办公用品发票,金额 88 元。",
avg_score=0.98,
line_count=1,
page_count=1,
warnings=[],
)
],
)
monkeypatch.setattr(OcrService, "recognize_files", fake_recognize)
monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path)
client, session_factory = build_client()
with session_factory() as db:
claim, item = seed_claim(db)
claim_id = claim.id
item_id = item.id
headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"}
upload_response = client.post(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment",
headers=headers,
files=[("file", ("office-note.png", b"fake-image-bytes", "image/png"))],
)
assert upload_response.status_code == 200
assert (tmp_path / claim_id / item_id).exists()
delete_response = client.delete(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}",
headers=headers,
)
assert delete_response.status_code == 200
assert delete_response.json()["item_id"] == item_id
assert not (tmp_path / claim_id / item_id).exists()
detail_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}",
headers=headers,
)
assert detail_response.status_code == 200
detail_payload = detail_response.json()
assert detail_payload["items"] == []
assert detail_payload["invoice_count"] == 0
deleted_meta_response = client.get(
f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/meta",
headers=headers,
)
assert deleted_meta_response.status_code == 404