"""Tests for ``ExpenseClaimService``.

Covers submission validation, drafting claims from ontology runs, claim-number
generation, claim-item editing with attachments, and claim-list visibility.
Every database-backed test runs against a fresh in-memory SQLite database
created by ``build_session``.
"""

from __future__ import annotations

from datetime import UTC, date, datetime
from decimal import Decimal

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import CurrentUserContext
from app.db.base import Base
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.ontology import OntologyParseRequest
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead
from app.schemas.reimbursement import ExpenseClaimItemCreate, ExpenseClaimItemUpdate
from app.services.expense_claims import ExpenseClaimService
from app.services.ontology import SemanticOntologyService
from app.services.ocr import OcrService


def build_claim(*, expense_type: str, location: str) -> ExpenseClaim:
    """Build an unsaved draft claim ``claim-1`` holding one 88.00 CNY item ``item-1``.

    The claim belongs to employee ``emp-1``. Its single item carries invoice
    ``invoice-1``. ``expense_type`` and ``location`` are applied to both the
    claim and the item.
    """
    claim = ExpenseClaim(
        id="claim-1",
        claim_no="EXP-202605-001",
        employee_id="emp-1",
        employee_name="张三",
        department_id="dept-1",
        department_name="市场部",
        project_code=None,
        expense_type=expense_type,
        reason="费用报销",
        location=location,
        amount=Decimal("88.00"),
        currency="CNY",
        invoice_count=1,
        occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
        submitted_at=None,
        status="draft",
        approval_stage="待提交",
        risk_flags_json=[],
    )
    claim.items = [
        ExpenseClaimItem(
            id="item-1",
            claim_id="claim-1",
            item_date=date(2026, 5, 13),
            item_type=expense_type,
            item_reason="费用报销",
            item_location=location,
            item_amount=Decimal("88.00"),
            invoice_id="invoice-1",
        )
    ]
    return claim


def build_session() -> Session:
    """Return a session on a fresh in-memory SQLite database with all tables created.

    ``StaticPool`` reuses a single connection, so every checkout sees the same
    in-memory database rather than a new empty one.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
    return session_factory()


def _employee_user() -> CurrentUserContext:
    """Return a non-admin request context for ``emp-1``, the owner of ``build_claim``'s claim."""
    return CurrentUserContext(
        username="emp-1",
        name="张三",
        role_codes=[],
        is_admin=False,
    )


def _fake_recognize_office_invoice(
    self,
    files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
    """Stand-in for ``OcrService.recognize_files`` that always reports one office invoice.

    The input files are ignored; the result always describes ``office-note.png``.
    """
    return OcrRecognizeBatchRead(
        total_file_count=1,
        success_count=1,
        documents=[
            OcrRecognizeDocumentRead(
                filename="office-note.png",
                media_type="image/png",
                text="办公用品发票 金额88元 2026-05-13",
                summary="识别到办公用品发票,金额 88 元。",
                avg_score=0.98,
                line_count=1,
                page_count=1,
                warnings=[],
            )
        ],
    )


def _stub_ocr_and_attachment_storage(monkeypatch, tmp_path) -> None:
    """Swap OCR for the office-invoice stub and redirect attachment storage to ``tmp_path``."""
    monkeypatch.setattr(OcrService, "recognize_files", _fake_recognize_office_invoice)
    monkeypatch.setattr(
        ExpenseClaimService,
        "_get_attachment_storage_root",
        lambda self: tmp_path,
    )


def _build_unattached_office_claim() -> ExpenseClaim:
    """Build ``build_claim``'s office claim with its item's invoice removed.

    The single item keeps a purchase reason. The invoice count is reset to 0,
    so an attachment can be uploaded onto the item.
    """
    claim = build_claim(expense_type="office", location="深圳南山")
    claim.invoice_count = 0
    claim.items[0].invoice_id = None
    claim.items[0].item_reason = "办公用品采购"
    return claim


def _build_existing_transport_draft(
    *,
    employee_id: str,
    employee_name: str,
    claim_no: str,
    department_name: str,
    location: str,
    invoice_id: str,
) -> ExpenseClaim:
    """Build an unsaved 20.00 CNY transport draft with one item carrying ``invoice_id``."""
    claim = ExpenseClaim(
        claim_no=claim_no,
        employee_id=employee_id,
        employee_name=employee_name,
        department_name=department_name,
        project_code=None,
        expense_type="transport",
        reason="原有交通报销",
        location=location,
        amount=Decimal("20.00"),
        currency="CNY",
        invoice_count=1,
        occurred_at=datetime(2026, 5, 13, tzinfo=UTC),
        status="draft",
        approval_stage="待提交",
        risk_flags_json=[],
    )
    claim.items = [
        ExpenseClaimItem(
            claim_id=claim.id,
            item_date=date(2026, 5, 13),
            item_type="transport",
            item_reason="原有交通报销",
            item_location=location,
            item_amount=Decimal("20.00"),
            invoice_id=invoice_id,
        )
    ]
    return claim


def test_validate_claim_for_submission_allows_office_claim_without_location() -> None:
    """Office claims whose location is still a placeholder produce no location issues."""
    service = ExpenseClaimService.__new__(ExpenseClaimService)
    claim = build_claim(expense_type="office", location="待补充")
    issues = service._validate_claim_for_submission(claim)
    assert "业务地点未完善" not in issues
    assert not any("缺少地点" in item for item in issues)


def test_validate_claim_for_submission_still_requires_location_for_travel_claim() -> None:
    """Travel claims with a placeholder location report both location issues."""
    service = ExpenseClaimService.__new__(ExpenseClaimService)
    claim = build_claim(expense_type="travel", location="待补充")
    issues = service._validate_claim_for_submission(claim)
    assert "业务地点未完善" in issues
    assert any("缺少地点" in item for item in issues)


def test_resolve_expense_type_maps_office_supplies_review_value_to_office() -> None:
    """The review-form label for office supplies resolves to the ``office`` type."""
    expense_type = ExpenseClaimService._resolve_expense_type(
        [],
        context_json={"review_form_values": {"expense_type": "办公用品"}},
    )
    assert expense_type == "office"


def test_upsert_draft_from_ontology_defers_multi_document_association_choice() -> None:
    """Several new documents plus an existing draft ask the user before linking.

    The existing draft must be left untouched while the decision is pending.
    """
    user_id = "zhangsan@example.com"
    with build_session() as db:
        employee = Employee(employee_no="E5001", name="张三", email=user_id)
        db.add(employee)
        db.flush()
        existing_claim = _build_existing_transport_draft(
            employee_id=employee.id,
            employee_name="张三",
            claim_no="EXP-202605-010",
            department_name="市场部",
            location="深圳",
            invoice_id="old-trip.png",
        )
        db.add(existing_claim)
        db.commit()

        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我上传了两张交通票据,帮我生成报销草稿",
                user_id=user_id,
            )
        )
        service = ExpenseClaimService(db)
        result = service.upsert_draft_from_ontology(
            run_id=ontology.run_id,
            user_id=user_id,
            message="我上传了两张交通票据,帮我生成报销草稿",
            ontology=ontology,
            context_json={
                "name": "张三",
                "attachment_names": ["didi-trip.png", "parking-ticket.jpg"],
                "attachment_count": 2,
                "draft_claim_id": existing_claim.id,
                "ocr_documents": [
                    {
                        "filename": "didi-trip.png",
                        "summary": "滴滴出行 支付金额 32 元",
                        "text": "滴滴出行 支付金额 32 元",
                    },
                    {
                        "filename": "parking-ticket.jpg",
                        "summary": "停车费 合计 18 元",
                        "text": "停车费 合计 18 元",
                    },
                ],
            },
        )

        db.refresh(existing_claim)
        assert result["pending_association_decision"] is True
        assert result["association_candidate_claim_id"] == existing_claim.id
        assert existing_claim.invoice_count == 1
        assert len(existing_claim.items) == 1
        assert existing_claim.items[0].invoice_id == "old-trip.png"


def test_upsert_draft_from_ontology_keeps_reason_missing_for_attachment_only_upload() -> None:
    """An upload with no user-written text leaves the draft reason as the placeholder."""
    user_id = "wangwu@example.com"
    with build_session() as db:
        employee = Employee(employee_no="E5003", name="王五", email=user_id)
        db.add(employee)
        db.commit()

        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。",
                user_id=user_id,
            )
        )
        service = ExpenseClaimService(db)
        result = service.upsert_draft_from_ontology(
            run_id=ontology.run_id,
            user_id=user_id,
            message="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。\n附件名称:didi-trip.png",
            ontology=ontology,
            context_json={
                "name": "王五",
                "user_input_text": "",
                "attachment_names": ["didi-trip.png"],
                "attachment_count": 1,
                "ocr_documents": [
                    {
                        "filename": "didi-trip.png",
                        "summary": "滴滴出行 支付金额 32 元",
                        "text": "滴滴出行 支付金额 32 元",
                        "document_type": "taxi_receipt",
                        "scene_code": "transport",
                    }
                ],
            },
        )

        claim = db.get(ExpenseClaim, result["claim_id"])
        assert claim is not None
        assert claim.reason == "待补充"


def test_upsert_draft_from_ontology_supports_link_or_create_for_multi_documents() -> None:
    """Both review actions work for multiple documents.

    ``link_to_existing_draft`` appends the documents to the existing draft.
    ``create_new_claim_from_documents`` builds a separate claim from the same
    documents.
    """
    user_id = "lisi@example.com"
    with build_session() as db:
        employee = Employee(employee_no="E5002", name="李四", email=user_id)
        db.add(employee)
        db.flush()
        existing_claim = _build_existing_transport_draft(
            employee_id=employee.id,
            employee_name="李四",
            claim_no="EXP-202605-011",
            department_name="销售部",
            location="上海",
            invoice_id="existing.png",
        )
        db.add(existing_claim)
        db.commit()

        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我上传了两张交通票据,帮我生成报销草稿",
                user_id=user_id,
            )
        )
        service = ExpenseClaimService(db)
        context_json = {
            "name": "李四",
            "attachment_names": ["didi-trip.png", "parking-ticket.jpg"],
            "attachment_count": 2,
            "draft_claim_id": existing_claim.id,
            "ocr_documents": [
                {
                    "filename": "didi-trip.png",
                    "summary": "滴滴出行",
                    "text": "滴滴出行 支付金额 32.50 元",
                    "document_type": "taxi_receipt",
                    "scene_code": "transport",
                    "document_fields": [{"key": "amount", "label": "支付金额", "value": "32.50"}],
                },
                {
                    "filename": "parking-ticket.jpg",
                    "summary": "停车票",
                    "text": "停车费 合计 18 元",
                    "document_type": "parking_toll_receipt",
                    "scene_code": "transport",
                    "document_fields": [
                        {"key": "total_amount", "label": "合计金额", "value": "18"}
                    ],
                },
            ],
        }

        link_result = service.upsert_draft_from_ontology(
            run_id=ontology.run_id,
            user_id=user_id,
            message="把这两张票据关联到已有草稿",
            ontology=ontology,
            context_json={**context_json, "review_action": "link_to_existing_draft"},
        )
        db.refresh(existing_claim)
        assert link_result["claim_id"] == existing_claim.id
        assert existing_claim.invoice_count == 3
        assert len(existing_claim.items) == 3
        # 20.00 existing + 32.50 + 18 — compared as Decimal so money is checked exactly.
        assert existing_claim.amount == Decimal("70.50")

        create_result = service.upsert_draft_from_ontology(
            run_id=f"{ontology.run_id}-new",
            user_id=user_id,
            message="单独新建一张报销单",
            ontology=ontology,
            context_json={**context_json, "review_action": "create_new_claim_from_documents"},
        )
        assert create_result["claim_id"] != existing_claim.id
        new_claim = db.get(ExpenseClaim, create_result["claim_id"])
        assert new_claim is not None
        assert new_claim.invoice_count == 2
        assert len(new_claim.items) == 2
        assert new_claim.amount == Decimal("50.50")


def test_generate_claim_no_uses_max_suffix_instead_of_count() -> None:
    """With suffixes 001 and 003 already taken, the next number is 004, not 003."""
    with build_session() as db:
        db.add_all(
            [
                ExpenseClaim(
                    claim_no="EXP-202605-001",
                    employee_name="张三",
                    department_name="市场部",
                    project_code=None,
                    expense_type="transport",
                    reason="交通报销",
                    location="深圳",
                    amount=Decimal("10.00"),
                    currency="CNY",
                    invoice_count=1,
                    occurred_at=datetime(2026, 5, 10, tzinfo=UTC),
                    status="draft",
                    approval_stage="待提交",
                    risk_flags_json=[],
                ),
                ExpenseClaim(
                    claim_no="EXP-202605-003",
                    employee_name="李四",
                    department_name="销售部",
                    project_code=None,
                    expense_type="transport",
                    reason="交通报销",
                    location="上海",
                    amount=Decimal("20.00"),
                    currency="CNY",
                    invoice_count=1,
                    occurred_at=datetime(2026, 5, 11, tzinfo=UTC),
                    status="submitted",
                    approval_stage="审批中",
                    risk_flags_json=[],
                ),
            ]
        )
        db.commit()

        service = ExpenseClaimService(db)
        assert service._generate_claim_no(datetime(2026, 5, 14, tzinfo=UTC)) == "EXP-202605-004"


def test_upsert_draft_from_ontology_retries_claim_no_conflict() -> None:
    """A generated claim number that already exists is retried with the next candidate."""
    user_id = "zhaoliu-claimno@example.com"
    with build_session() as db:
        employee = Employee(employee_no="E5006", name="赵六", email=user_id)
        db.add(employee)
        db.flush()
        db.add(
            ExpenseClaim(
                claim_no="EXP-202605-004",
                employee_name="历史单据",
                department_name="财务部",
                project_code=None,
                expense_type="other",
                reason="历史草稿",
                location="北京",
                amount=Decimal("0.00"),
                currency="CNY",
                invoice_count=0,
                occurred_at=datetime(2026, 5, 12, tzinfo=UTC),
                status="submitted",
                approval_stage="审批中",
                risk_flags_json=[],
            )
        )
        db.commit()

        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="帮我生成报销草稿,我昨天交通费 13.4 元",
                user_id=user_id,
            )
        )
        service = ExpenseClaimService(db)
        # First candidate collides with the row above; the second must be used.
        generated_claim_nos = iter(["EXP-202605-004", "EXP-202605-005"])
        service._generate_claim_no = lambda occurred_at: next(generated_claim_nos)
        result = service.upsert_draft_from_ontology(
            run_id=ontology.run_id,
            user_id=user_id,
            message="帮我生成报销草稿,我昨天交通费 13.4 元",
            ontology=ontology,
            context_json={
                "name": "赵六",
                "user_input_text": "帮我生成报销草稿,我昨天交通费 13.4 元",
            },
        )

        created_claim = db.get(ExpenseClaim, result["claim_id"])
        assert created_claim is not None
        assert created_claim.claim_no == "EXP-202605-005"
        assert result["claim_no"] == "EXP-202605-005"


def test_create_claim_item_adds_blank_draft_row_without_forcing_attachment() -> None:
    """Adding an item appends a blank row and leaves the claim totals unchanged.

    The new row inherits the claim's expense type and has no attachment.
    """
    current_user = _employee_user()
    with build_session() as db:
        claim = build_claim(expense_type="office", location="深圳南山")
        db.add(claim)
        db.commit()

        service = ExpenseClaimService(db)
        updated = service.create_claim_item(
            claim_id=claim.id,
            payload=ExpenseClaimItemCreate(),
            current_user=current_user,
        )

        assert updated is not None
        assert len(updated.items) == 2
        assert updated.amount == Decimal("88.00")
        assert updated.invoice_count == 1
        new_item = next(item for item in updated.items if item.id != "item-1")
        assert new_item.item_type == "office"
        assert new_item.item_reason == ""
        assert new_item.item_location == ""
        assert new_item.item_amount == Decimal("0.00")
        assert new_item.invoice_id is None


def test_update_claim_item_reanalyzes_existing_attachment(monkeypatch, tmp_path) -> None:
    """Changing an item's type re-runs analysis of the attachment already stored.

    An office invoice passes for an office item. Once the item becomes a
    transport expense, the same attachment must be flagged as not matching
    the required attachment type.
    """
    current_user = _employee_user()
    _stub_ocr_and_attachment_storage(monkeypatch, tmp_path)
    with build_session() as db:
        claim = _build_unattached_office_claim()
        db.add(claim)
        db.commit()
        item_id = claim.items[0].id

        service = ExpenseClaimService(db)
        service.upload_claim_item_attachment(
            claim_id=claim.id,
            item_id=item_id,
            filename="office-note.png",
            content=b"fake-image-bytes",
            media_type="image/png",
            current_user=current_user,
        )

        uploaded_meta = service.get_claim_item_attachment_meta(
            claim_id=claim.id,
            item_id=item_id,
            current_user=current_user,
        )
        assert uploaded_meta is not None
        assert uploaded_meta["preview_kind"] == "image"
        assert uploaded_meta["preview_url"].endswith(
            f"/reimbursements/claims/{claim.id}/items/{item_id}/attachment/preview"
        )
        assert uploaded_meta["analysis"]["severity"] == "pass"
        assert uploaded_meta["document_info"]["document_type"] == "office_invoice"
        assert uploaded_meta["requirement_check"]["matches"] is True

        updated = service.update_claim_item(
            claim_id=claim.id,
            item_id=item_id,
            payload=ExpenseClaimItemUpdate(item_type="transport", item_reason="打车报销"),
            current_user=current_user,
        )
        assert updated is not None
        assert any(flag.get("source") == "attachment_analysis" for flag in updated.risk_flags_json)

        refreshed_meta = service.get_claim_item_attachment_meta(
            claim_id=claim.id,
            item_id=item_id,
            current_user=current_user,
        )
        assert refreshed_meta is not None
        assert refreshed_meta["analysis"]["severity"] == "high"
        assert refreshed_meta["requirement_check"]["matches"] is False
        assert any("附件类型要求" in point for point in refreshed_meta["analysis"]["points"])


def test_delete_claim_item_removes_row_and_attachment_files(monkeypatch, tmp_path) -> None:
    """Deleting an item removes the row, recomputes claim totals, and deletes stored files."""
    current_user = _employee_user()
    _stub_ocr_and_attachment_storage(monkeypatch, tmp_path)
    with build_session() as db:
        claim = _build_unattached_office_claim()
        db.add(claim)
        db.commit()
        item_id = claim.items[0].id

        service = ExpenseClaimService(db)
        upload_payload = service.upload_claim_item_attachment(
            claim_id=claim.id,
            item_id=item_id,
            filename="office-note.png",
            content=b"fake-image-bytes",
            media_type="image/png",
            current_user=current_user,
        )
        assert upload_payload is not None
        attachment_root = tmp_path / claim.id / item_id
        assert attachment_root.exists()

        delete_payload = service.delete_claim_item(
            claim_id=claim.id,
            item_id=item_id,
            current_user=current_user,
        )
        assert delete_payload is not None
        assert delete_payload["claim_id"] == claim.id

        refreshed_claim = service.get_claim(claim.id, current_user)
        assert refreshed_claim is not None
        assert refreshed_claim.items == []
        assert refreshed_claim.amount == Decimal("0.00")
        assert refreshed_claim.invoice_count == 0
        assert not attachment_root.exists()


def test_list_claims_scopes_to_current_user_id_even_when_names_duplicate() -> None:
    """A non-finance user sees only claims tied to their own employee record.

    A same-named colleague's claims must not appear.
    """
    current_user = CurrentUserContext(
        username="zhangsan1@example.com",
        name="张三",
        role_codes=["manager"],
        is_admin=False,
    )
    with build_session() as db:
        employee_a = Employee(employee_no="E2001", name="张三", email="zhangsan1@example.com")
        employee_b = Employee(employee_no="E2002", name="张三", email="zhangsan2@example.com")
        db.add_all([employee_a, employee_b])
        db.flush()
        db.add_all(
            [
                ExpenseClaim(
                    claim_no="EXP-DUP-001",
                    employee_id=employee_a.id,
                    employee_name="张三",
                    department_name="市场部",
                    project_code="PRJ-A",
                    expense_type="travel",
                    reason="本人报销",
                    location="上海",
                    amount=Decimal("120.00"),
                    currency="CNY",
                    invoice_count=1,
                    occurred_at=datetime(2026, 5, 12, 9, 0, tzinfo=UTC),
                    submitted_at=datetime(2026, 5, 12, 10, 0, tzinfo=UTC),
                    status="submitted",
                    approval_stage="finance_review",
                    risk_flags_json=[],
                ),
                ExpenseClaim(
                    claim_no="EXP-DUP-002",
                    employee_id=employee_b.id,
                    employee_name="张三",
                    department_name="销售部",
                    project_code="PRJ-B",
                    expense_type="meal",
                    reason="他人报销",
                    location="杭州",
                    amount=Decimal("300.00"),
                    currency="CNY",
                    invoice_count=1,
                    occurred_at=datetime(2026, 5, 12, 12, 0, tzinfo=UTC),
                    submitted_at=datetime(2026, 5, 12, 13, 0, tzinfo=UTC),
                    status="approved",
                    approval_stage="completed",
                    risk_flags_json=[],
                ),
            ]
        )
        db.commit()

        claims = ExpenseClaimService(db).list_claims(current_user)
        assert len(claims) == 1
        assert claims[0].claim_no == "EXP-DUP-001"


def test_list_claims_allows_finance_to_view_all_records() -> None:
    """A user with the ``finance`` role sees every claim regardless of owner."""
    current_user = CurrentUserContext(
        username="finance@example.com",
        name="财务",
        role_codes=["finance"],
        is_admin=False,
    )
    with build_session() as db:
        db.add_all(
            [
                ExpenseClaim(
                    claim_no="EXP-FIN-101",
                    employee_name="甲",
                    department_name="A部",
                    project_code="PRJ-A",
                    expense_type="travel",
                    reason="A 报销",
                    location="上海",
                    amount=Decimal("120.00"),
                    currency="CNY",
                    invoice_count=1,
                    occurred_at=datetime(2026, 5, 11, 9, 0, tzinfo=UTC),
                    submitted_at=datetime(2026, 5, 11, 10, 0, tzinfo=UTC),
                    status="submitted",
                    approval_stage="finance_review",
                    risk_flags_json=[],
                ),
                ExpenseClaim(
                    claim_no="EXP-FIN-102",
                    employee_name="乙",
                    department_name="B部",
                    project_code="PRJ-B",
                    expense_type="meal",
                    reason="B 报销",
                    location="杭州",
                    amount=Decimal("300.00"),
                    currency="CNY",
                    invoice_count=1,
                    occurred_at=datetime(2026, 5, 11, 12, 0, tzinfo=UTC),
                    submitted_at=datetime(2026, 5, 11, 13, 0, tzinfo=UTC),
                    status="approved",
                    approval_stage="completed",
                    risk_flags_json=[],
                ),
            ]
        )
        db.commit()

        claims = ExpenseClaimService(db).list_claims(current_user)
        assert len(claims) == 2
        assert {claim.claim_no for claim in claims} == {"EXP-FIN-101", "EXP-FIN-102"}