diff --git a/server/tests/test_expense_claim_service.py b/server/tests/test_expense_claim_service.py new file mode 100644 index 0000000..0f84ec5 --- /dev/null +++ b/server/tests/test_expense_claim_service.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from datetime import UTC, date, datetime +from decimal import Decimal + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.api.deps import CurrentUserContext +from app.db.base import Base +from app.models.financial_record import ExpenseClaim, ExpenseClaimItem +from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead +from app.schemas.reimbursement import ExpenseClaimItemCreate, ExpenseClaimItemUpdate +from app.services.expense_claims import ExpenseClaimService +from app.services.ocr import OcrService + + +def build_claim(*, expense_type: str, location: str) -> ExpenseClaim: + claim = ExpenseClaim( + id="claim-1", + claim_no="EXP-202605-001", + employee_id="emp-1", + employee_name="张三", + department_id="dept-1", + department_name="市场部", + project_code=None, + expense_type=expense_type, + reason="费用报销", + location=location, + amount=Decimal("88.00"), + currency="CNY", + invoice_count=1, + occurred_at=datetime(2026, 5, 13, tzinfo=UTC), + submitted_at=None, + status="draft", + approval_stage="待提交", + risk_flags_json=[], + ) + claim.items = [ + ExpenseClaimItem( + id="item-1", + claim_id="claim-1", + item_date=date(2026, 5, 13), + item_type=expense_type, + item_reason="费用报销", + item_location=location, + item_amount=Decimal("88.00"), + invoice_id="invoice-1", + ) + ] + return claim + + +def build_session() -> Session: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False) + return session_factory() + + +def test_validate_claim_for_submission_allows_office_claim_without_location() -> None: + service = ExpenseClaimService.__new__(ExpenseClaimService) + claim = build_claim(expense_type="office", location="待补充") + + issues = service._validate_claim_for_submission(claim) + + assert "业务地点未完善" not in issues + assert not any("缺少地点" in item for item in issues) + + +def test_validate_claim_for_submission_still_requires_location_for_travel_claim() -> None: + service = ExpenseClaimService.__new__(ExpenseClaimService) + claim = build_claim(expense_type="travel", location="待补充") + + issues = service._validate_claim_for_submission(claim) + + assert "业务地点未完善" in issues + assert any("缺少地点" in item for item in issues) + + +def test_resolve_expense_type_maps_office_supplies_review_value_to_office() -> None: + expense_type = ExpenseClaimService._resolve_expense_type( + [], + context_json={ + "review_form_values": { + "expense_type": "办公用品" + } + }, + ) + + assert expense_type == "office" + + +def test_create_claim_item_adds_blank_draft_row_without_forcing_attachment() -> None: + current_user = CurrentUserContext( + username="emp-1", + name="张三", + role_codes=[], + is_admin=False, + ) + + with build_session() as db: + claim = build_claim(expense_type="office", location="深圳南山") + db.add(claim) + db.commit() + + service = ExpenseClaimService(db) + updated = service.create_claim_item( + claim_id=claim.id, + payload=ExpenseClaimItemCreate(), + current_user=current_user, + ) + + assert updated is not None + assert len(updated.items) == 2 + assert updated.amount == Decimal("88.00") + assert updated.invoice_count == 1 + + new_item = next(item for item in updated.items if item.id != "item-1") + assert new_item.item_type == "office" + assert new_item.item_reason == "" + assert new_item.item_location == "" + assert new_item.item_amount == Decimal("0.00") + assert new_item.invoice_id is None + + +def test_update_claim_item_reanalyzes_existing_attachment(monkeypatch, tmp_path) -> None: + current_user = CurrentUserContext( + username="emp-1", + name="张三", + role_codes=[], + is_admin=False, + ) + + def fake_recognize( + self, + files: list[tuple[str, bytes, str | None]], + ) -> OcrRecognizeBatchRead: + return OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="office-note.png", + media_type="image/png", + text="办公用品发票 金额88元 2026-05-13", + summary="识别到办公用品发票,金额 88 元。", + avg_score=0.98, + line_count=1, + page_count=1, + warnings=[], + ) + ], + ) + + monkeypatch.setattr(OcrService, "recognize_files", fake_recognize) + monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path) + + with build_session() as db: + claim = build_claim(expense_type="office", location="深圳南山") + claim.invoice_count = 0 + claim.items[0].invoice_id = None + claim.items[0].item_reason = "办公用品采购" + db.add(claim) + db.commit() + + service = ExpenseClaimService(db) + service.upload_claim_item_attachment( + claim_id=claim.id, + item_id=claim.items[0].id, + filename="office-note.png", + content=b"fake-image-bytes", + media_type="image/png", + current_user=current_user, + ) + + uploaded_meta = service.get_claim_item_attachment_meta( + claim_id=claim.id, + item_id=claim.items[0].id, + current_user=current_user, + ) + assert uploaded_meta is not None + assert uploaded_meta["analysis"]["severity"] == "pass" + + updated = service.update_claim_item( + claim_id=claim.id, + item_id=claim.items[0].id, + payload=ExpenseClaimItemUpdate( + item_type="transport", + item_reason="打车报销", + ), + current_user=current_user, + ) + + assert updated is not None + assert any(flag.get("source") == "attachment_analysis" for flag in updated.risk_flags_json) + + refreshed_meta = service.get_claim_item_attachment_meta( + claim_id=claim.id, + item_id=claim.items[0].id, + current_user=current_user, + ) + assert refreshed_meta is not None + assert refreshed_meta["analysis"]["severity"] == "medium" + assert any("用途字段" in point for point in refreshed_meta["analysis"]["points"]) + + +def test_delete_claim_item_removes_row_and_attachment_files(monkeypatch, tmp_path) -> None: + current_user = CurrentUserContext( + username="emp-1", + name="张三", + role_codes=[], + is_admin=False, + ) + + def fake_recognize( + self, + files: list[tuple[str, bytes, str | None]], + ) -> OcrRecognizeBatchRead: + return OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="office-note.png", + media_type="image/png", + text="办公用品发票 金额88元 2026-05-13", + summary="识别到办公用品发票,金额 88 元。", + avg_score=0.98, + line_count=1, + page_count=1, + warnings=[], + ) + ], + ) + + monkeypatch.setattr(OcrService, "recognize_files", fake_recognize) + monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path) + + with build_session() as db: + claim = build_claim(expense_type="office", location="深圳南山") + claim.invoice_count = 0 + claim.items[0].invoice_id = None + claim.items[0].item_reason = "办公用品采购" + db.add(claim) + db.commit() + + service = ExpenseClaimService(db) + upload_payload = service.upload_claim_item_attachment( + claim_id=claim.id, + item_id=claim.items[0].id, + filename="office-note.png", + content=b"fake-image-bytes", + media_type="image/png", + current_user=current_user, + ) + + assert upload_payload is not None + attachment_root = tmp_path / claim.id / claim.items[0].id + assert attachment_root.exists() + + delete_payload = service.delete_claim_item( + claim_id=claim.id, + item_id=claim.items[0].id, + current_user=current_user, + ) + + assert delete_payload is not None + assert delete_payload["claim_id"] == claim.id + refreshed_claim = service.get_claim(claim.id, current_user) + assert refreshed_claim is not None + assert refreshed_claim.items == [] + assert refreshed_claim.amount == Decimal("0.00") + assert refreshed_claim.invoice_count == 0 + assert not attachment_root.exists() diff --git a/server/tests/test_ontology_service.py b/server/tests/test_ontology_service.py index ab45ee8..93eb52c 100644 --- a/server/tests/test_ontology_service.py +++ b/server/tests/test_ontology_service.py @@ -333,6 +333,24 @@ def test_semantic_ontology_service_extracts_day_before_yesterday_from_client_loc assert result.time_range.end_date == "2026-05-11" +def test_semantic_ontology_service_maps_office_supplies_to_office_expense_type() -> None: + session_factory = build_session_factory() + with session_factory() as db: + result = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我买了办公用品和文具,花了88元,帮我报销", + user_id="pytest", + ) + ) + + assert result.scenario == "expense" + assert result.intent == "draft" + assert any( + item.type == "expense_type" and item.normalized_value == "office" + for item in result.entities + ) + + def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None: session_factory = build_session_factory() with session_factory() as db: diff --git a/server/tests/test_reimbursement_endpoints.py b/server/tests/test_reimbursement_endpoints.py new file mode 100644 index 0000000..edbd831 --- /dev/null +++ b/server/tests/test_reimbursement_endpoints.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +from collections.abc import Generator +from datetime import UTC, date, datetime +from decimal import Decimal + +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_db +from app.db.base import Base +from app.main import create_app +from app.models.financial_record import ExpenseClaim, ExpenseClaimItem +from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead +from app.services.expense_claims import ExpenseClaimService +from app.services.ocr import OcrService + + +def build_session_factory() -> sessionmaker[Session]: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + return sessionmaker(bind=engine, autoflush=False, autocommit=False) + + +def build_client() -> tuple[TestClient, sessionmaker[Session]]: + session_factory = build_session_factory() + app = create_app() + + def override_db() -> Generator[Session, None, None]: + db = session_factory() + try: + yield db + finally: + db.close() + + app.dependency_overrides[get_db] = override_db + return TestClient(app), session_factory + + +def seed_claim(db: Session) -> tuple[ExpenseClaim, ExpenseClaimItem]: + claim = ExpenseClaim( + id="claim-attachment-1", + claim_no="EXP-202605-101", + employee_id="emp-1", + employee_name="张三", + department_id="dept-1", + department_name="市场部", + project_code=None, + expense_type="office", + reason="办公用品采购", + location="深圳南山", + amount=Decimal("88.00"), + currency="CNY", + invoice_count=0, + occurred_at=datetime(2026, 5, 13, tzinfo=UTC), + submitted_at=None, + status="draft", + approval_stage="待提交", + risk_flags_json=[], + ) + item = ExpenseClaimItem( + id="item-attachment-1", + claim_id=claim.id, + item_date=date(2026, 5, 13), + item_type="office", + item_reason="办公用品采购", + item_location="深圳南山", + item_amount=Decimal("88.00"), + invoice_id=None, + ) + claim.items = [item] + db.add(claim) + db.commit() + return claim, item + + +def test_claim_item_attachment_upload_preview_and_delete(monkeypatch, tmp_path) -> None: + def fake_recognize( + self, + files: list[tuple[str, bytes, str | None]], + ) -> OcrRecognizeBatchRead: + assert files[0][0] == "office-note.png" + return OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="office-note.png", + media_type="image/png", + text="办公用品发票 金额88元 2026-05-13", + summary="识别到办公用品发票,金额 88 元。", + avg_score=0.98, + line_count=1, + page_count=1, + warnings=[], + ) + ], + ) + + monkeypatch.setattr(OcrService, "recognize_files", fake_recognize) + monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path) + + client, session_factory = build_client() + with session_factory() as db: + claim, item = seed_claim(db) + claim_id = claim.id + item_id = item.id + + headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"} + file_bytes = b"fake-image-bytes" + + upload_response = client.post( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", + headers=headers, + files=[("file", ("office-note.png", file_bytes, "image/png"))], + ) + + assert upload_response.status_code == 200 + upload_payload = upload_response.json() + assert upload_payload["attachment"]["file_name"] == "office-note.png" + assert upload_payload["attachment"]["analysis"]["label"] == "AI提示符合条件" + assert upload_payload["invoice_id"] + + meta_response = client.get( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/meta", + headers=headers, + ) + assert meta_response.status_code == 200 + meta_payload = meta_response.json() + assert meta_payload["media_type"] == "image/png" + assert meta_payload["analysis"]["headline"] + + content_response = client.get( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", + headers=headers, + ) + assert content_response.status_code == 200 + assert content_response.content == file_bytes + + delete_response = client.delete( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", + headers=headers, + ) + assert delete_response.status_code == 200 + assert delete_response.json()["invoice_id"] is None + + deleted_meta_response = client.get( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/meta", + headers=headers, + ) + assert deleted_meta_response.status_code == 404 + + +def test_claim_item_attachment_upload_flags_purpose_and_amount_mismatch(monkeypatch, tmp_path) -> None: + def fake_recognize( + self, + files: list[tuple[str, bytes, str | None]], + ) -> OcrRecognizeBatchRead: + return OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="taxi-note.png", + media_type="image/png", + text="滴滴出行电子发票 金额120元 2026-05-13", + summary="识别到交通出行发票,金额 120 元。", + avg_score=0.97, + line_count=1, + page_count=1, + warnings=[], + ) + ], + ) + + monkeypatch.setattr(OcrService, "recognize_files", fake_recognize) + monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path) + + client, session_factory = build_client() + with session_factory() as db: + claim, item = seed_claim(db) + claim_id = claim.id + item_id = item.id + + headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"} + upload_response = client.post( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", + headers=headers, + files=[("file", ("taxi-note.png", b"fake-image-bytes", "image/png"))], + ) + + assert upload_response.status_code == 200 + analysis = upload_response.json()["attachment"]["analysis"] + assert analysis["severity"] == "high" + assert any("金额字段" in point for point in analysis["points"]) + assert any("用途字段" in point for point in analysis["points"]) + + +def test_claim_item_attachment_upload_flags_non_invoice_image_as_high_risk(monkeypatch, tmp_path) -> None: + def fake_recognize( + self, + files: list[tuple[str, bytes, str | None]], + ) -> OcrRecognizeBatchRead: + return OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="random-image.png", + media_type="image/png", + text="", + summary="", + avg_score=0.0, + line_count=0, + page_count=1, + warnings=[], + ) + ], + ) + + monkeypatch.setattr(OcrService, "recognize_files", fake_recognize) + monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path) + + client, session_factory = build_client() + with session_factory() as db: + claim, item = seed_claim(db) + claim_id = claim.id + item_id = item.id + + headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"} + upload_response = client.post( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", + headers=headers, + files=[("file", ("random-image.png", b"fake-image-bytes", "image/png"))], + ) + + assert upload_response.status_code == 200 + analysis = upload_response.json()["attachment"]["analysis"] + assert analysis["severity"] == "high" + assert any("附件内容" in point for point in analysis["points"]) + + +def test_claim_item_delete_removes_item_and_attachment(monkeypatch, tmp_path) -> None: + def fake_recognize( + self, + files: list[tuple[str, bytes, str | None]], + ) -> OcrRecognizeBatchRead: + return OcrRecognizeBatchRead( + total_file_count=1, + success_count=1, + documents=[ + OcrRecognizeDocumentRead( + filename="office-note.png", + media_type="image/png", + text="办公用品发票 金额88元 2026-05-13", + summary="识别到办公用品发票,金额 88 元。", + avg_score=0.98, + line_count=1, + page_count=1, + warnings=[], + ) + ], + ) + + monkeypatch.setattr(OcrService, "recognize_files", fake_recognize) + monkeypatch.setattr(ExpenseClaimService, "_get_attachment_storage_root", lambda self: tmp_path) + + client, session_factory = build_client() + with session_factory() as db: + claim, item = seed_claim(db) + claim_id = claim.id + item_id = item.id + + headers = {"x-auth-username": "emp-1", "x-auth-name": "Zhang San"} + upload_response = client.post( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment", + headers=headers, + files=[("file", ("office-note.png", b"fake-image-bytes", "image/png"))], + ) + assert upload_response.status_code == 200 + assert (tmp_path / claim_id / item_id).exists() + + delete_response = client.delete( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}", + headers=headers, + ) + assert delete_response.status_code == 200 + assert delete_response.json()["item_id"] == item_id + assert not (tmp_path / claim_id / item_id).exists() + + detail_response = client.get( + f"/api/v1/reimbursements/claims/{claim_id}", + headers=headers, + ) + assert detail_response.status_code == 200 + detail_payload = detail_response.json() + assert detail_payload["items"] == [] + assert detail_payload["invoice_count"] == 0 + + deleted_meta_response = client.get( + f"/api/v1/reimbursements/claims/{claim_id}/items/{item_id}/attachment/meta", + headers=headers, + ) + assert deleted_meta_response.status_code == 404