# Tests for the steward action-execution endpoint
# (POST /api/v1/steward/actions/execute).
#
# Each test builds its own FastAPI app wired to a fresh in-memory SQLite
# database, so tests do not share state with one another.
from __future__ import annotations

from collections.abc import Generator
from datetime import UTC, datetime

from fastapi.testclient import TestClient
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.agent_conversation import AgentConversation
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.services import attachment_association_jobs as attachment_jobs_module

# Route exercised by every test in this module.
EXECUTE_ENDPOINT = "/api/v1/steward/actions/execute"


def build_session_factory() -> sessionmaker[Session]:
    """Return a session factory bound to a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the factory
    is returned.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        # StaticPool + check_same_thread=False: reuse one connection so that
        # every session (including those opened by the app under test) sees
        # the same in-memory database.
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)


def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create an app whose ``get_db`` dependency uses a fresh test database.

    Returns:
        A ``(client, session_factory)`` pair; the factory opens sessions on
        the same database the app uses, so tests can seed and inspect rows.
    """
    session_factory = build_session_factory()
    app = create_app()

    def override_db() -> Generator[Session, None, None]:
        db = session_factory()
        try:
            yield db
        finally:
            db.close()

    app.dependency_overrides[get_db] = override_db
    return TestClient(app), session_factory


def seed_employee(db: Session) -> None:
    """Insert and commit a manager (E90000) and an employee (E90001) reporting to them."""
    manager = Employee(
        id="steward-action-manager",
        employee_no="E90000",
        name="李总",
        email="leader@example.com",
        position="部门负责人",
        grade="P7",
    )
    employee = Employee(
        id="steward-action-employee",
        employee_no="E90001",
        name="张三",
        email="zhangsan@example.com",
        position="实施工程师",
        grade="P4",
        manager=manager,
    )
    db.add_all([manager, employee])
    db.commit()


def auth_headers() -> dict[str, str]:
    """Return the ``x-auth-*`` request headers for the seeded employee E90001."""
    return {
        "x-auth-username": "zhangsan@example.com",
        "x-auth-name": "Zhang San",
        "x-auth-employee-no": "E90001",
        "x-auth-role-codes": "user",
        "x-auth-position": "Engineer",
        "x-auth-grade": "P4",
        "x-auth-manager-name": "Leader",
    }


def base_application_task(requested_action: str = "save_draft") -> dict[str, object]:
    """Return a travel expense-application task payload.

    Args:
        requested_action: Stored under ``requested_action``;
            ``confirmation_required`` is True only when it equals ``"submit"``.
    """
    return {
        "task_id": "task_app_001",
        "task_type": "expense_application",
        "assigned_agent": "application_assistant",
        "title": "上海出差申请",
        "summary": "2026-02-20 至 2026-02-23 去上海出差,辅助国网仿生产服务器部署,火车出行。",
        "status": "needs_confirmation",
        "confidence": 0.96,
        "requested_action": requested_action,
        "ontology_fields": {
            "expense_type": "travel",
            "time_range": "2026-02-20 至 2026-02-23",
            "location": "上海",
            "reason": "辅助国网仿生产服务器部署",
            "transport_mode": "train",
        },
        "missing_fields": [],
        "confirmation_required": requested_action == "submit",
        "action_steps": [],
    }


def base_reimbursement_task() -> dict[str, object]:
    """Return a taxi-fare reimbursement task payload requesting ``save_draft``."""
    return {
        "task_id": "task_reim_001",
        "task_type": "reimbursement",
        "assigned_agent": "reimbursement_assistant",
        "title": "客户现场交通费报销",
        "summary": "2026-03-04 打车去客户现场,交通费 32 元。",
        "status": "needs_confirmation",
        "confidence": 0.9,
        "requested_action": "save_draft",
        "ontology_fields": {
            "expense_type": "transport",
            "time_range": "2026-03-04",
            "location": "客户现场",
            "reason": "客户现场沟通",
            "amount": "32元",
            "transport_mode": "taxi",
        },
        "missing_fields": [],
        "confirmation_required": False,
        "action_steps": [],
    }


def claim_count(db: Session) -> int:
    """Return the number of ``ExpenseClaim`` rows visible to ``db``."""
    return len(db.scalars(select(ExpenseClaim)).all())


def seed_approved_application(db: Session) -> None:
    """Insert and commit one approved travel application (claim_no ``AAPPROVED1``)."""
    application = ExpenseClaim(
        id="application-action-approved",
        claim_no="AAPPROVED1",
        employee_id="steward-action-employee",
        employee_name="张三",
        department_id="dept-delivery",
        department_name="交付部",
        project_code=None,
        expense_type="travel_application",
        reason="辅助国网仿生产服务器部署",
        location="上海",
        amount=3000,
        currency="CNY",
        invoice_count=0,
        occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
        submitted_at=None,
        status="approved",
        approval_stage="已完成",
        risk_flags_json=[],
    )
    db.add(application)
    db.commit()


def _seeded_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build a client and seed the manager/employee pair every test relies on."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)
    return client, session_factory


def _execute_action(client: TestClient, payload: dict[str, object]):
    """POST ``payload`` as JSON to the execute endpoint with the test auth headers.

    Returns the response object produced by ``client.post``.
    """
    return client.post(EXECUTE_ENDPOINT, headers=auth_headers(), json=payload)


def test_steward_action_executor_rejects_unknown_action_without_creating_claim() -> None:
    """An unsupported ``action_type`` is answered with ``blocked`` and adds no claim."""
    client, session_factory = _seeded_client()
    with session_factory() as db:
        before_count = claim_count(db)

    response = _execute_action(
        client,
        {
            "action_type": "delete_all_claims",
            "message": "请执行未知动作",
            "task": base_application_task(),
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "blocked"
    assert payload["action_type"] == "delete_all_claims"
    assert "不支持" in payload["message"]
    with session_factory() as db:
        assert claim_count(db) == before_count


def test_steward_action_executor_blocks_attachment_action_without_receipts() -> None:
    """``associate_attachments`` without ``receipt_ids`` in context is blocked."""
    client, session_factory = _seeded_client()
    with session_factory() as db:
        before_count = claim_count(db)

    # Only a file name is supplied on the task; no receipt ids are provided.
    task = base_reimbursement_task()
    task["ontology_fields"] = {
        **task["ontology_fields"],
        "attachments": "taxi.png",
    }

    response = _execute_action(
        client,
        {
            "action_type": "associate_attachments",
            "message": "关联附件 taxi.png",
            "task": task,
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "blocked"
    assert "receipt_id" in payload["message"] or "票据" in payload["message"]
    with session_factory() as db:
        assert claim_count(db) == before_count


def test_steward_action_executor_records_pending_interrupt_in_conversation_state() -> None:
    """An unconfirmed submit stores a pending interrupt on the conversation state."""
    client, session_factory = _seeded_client()

    response = _execute_action(
        client,
        {
            "action_type": "submit_application",
            "message": "2026-02-20 至 2026-02-23,去上海出差,辅助国网仿生产服务器部署,交通火车,直接提交",
            "conversation_id": "conv-action-submit",
            "client_trace_id": "trace-submit-pending",
            "task": base_application_task("submit"),
            "confirmed": False,
            "context_json": {
                "precheck_result": {
                    "status": "ok",
                    "blocking": False,
                }
            },
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "needs_confirmation"
    with session_factory() as db:
        conversation = db.scalar(
            select(AgentConversation).where(
                AgentConversation.conversation_id == "conv-action-submit"
            )
        )
        assert conversation is not None
        checkpoint = conversation.state_json["steward_action_checkpoint"]
        pending = checkpoint["pending_interrupt"]
        assert pending["client_trace_id"] == "trace-submit-pending"
        assert pending["action_type"] == "submit_application"
        assert checkpoint["actions"]["trace-submit-pending"]["status"] == "needs_confirmation"


def test_steward_action_executor_reuses_checkpoint_for_duplicate_trace_without_duplicate_draft() -> None:
    """Replaying the same ``client_trace_id`` returns the same draft and creates one claim."""
    client, session_factory = _seeded_client()
    request_payload = {
        "action_type": "save_application_draft",
        "message": "2026-02-20 至 2026-02-23,去上海出差,辅助国网仿生产服务器部署,交通火车,保存草稿",
        "conversation_id": "conv-action-draft",
        "client_trace_id": "trace-save-draft",
        "task": base_application_task("save_draft"),
    }

    first_response = _execute_action(client, request_payload)
    second_response = _execute_action(client, request_payload)

    assert first_response.status_code == 200
    assert second_response.status_code == 200
    first_payload = first_response.json()
    second_payload = second_response.json()
    assert first_payload["status"] == "succeeded"
    assert second_payload["status"] == "succeeded"
    assert (
        first_payload["result_payload"]["draft_payload"]["claim_id"]
        == second_payload["result_payload"]["draft_payload"]["claim_id"]
    )
    assert second_payload["result_payload"]["idempotent_replay"] is True
    with session_factory() as db:
        assert claim_count(db) == 1


def test_steward_action_executor_requires_confirmation_before_submit_side_effect() -> None:
    """An unconfirmed submit asks for confirmation and creates no claim."""
    client, session_factory = _seeded_client()

    response = _execute_action(
        client,
        {
            "action_type": "submit_application",
            "message": "2026-02-20 至 2026-02-23,去上海出差,辅助国网仿生产服务器部署,交通火车,直接提交",
            "task": base_application_task("submit"),
            "confirmed": False,
            "context_json": {
                "precheck_result": {
                    "status": "ok",
                    "blocking": False,
                }
            },
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "needs_confirmation"
    assert payload["requires_confirmation"] is True
    with session_factory() as db:
        assert claim_count(db) == 0


def test_steward_action_executor_saves_application_draft_from_action_step() -> None:
    """``save_application_draft`` succeeds and persists exactly one draft claim."""
    client, session_factory = _seeded_client()

    response = _execute_action(
        client,
        {
            "action_type": "save_application_draft",
            "message": "2026-02-20 至 2026-02-23,去上海出差,辅助国网仿生产服务器部署,交通火车,保存草稿",
            "task": base_application_task("save_draft"),
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "succeeded"
    draft_payload = payload["result_payload"]["draft_payload"]
    assert draft_payload["draft_type"] == "expense_application"
    assert draft_payload["status"] == "draft"
    assert draft_payload["claim_no"].startswith("A")
    with session_factory() as db:
        # .one() also asserts that exactly one claim row exists.
        claim = db.scalars(select(ExpenseClaim)).one()
        assert claim.status == "draft"
        assert claim.reason == "辅助国网仿生产服务器部署"


def test_steward_action_executor_creates_reimbursement_draft_from_action_step() -> None:
    """``create_reimbursement_draft`` succeeds and persists one transport draft claim."""
    client, session_factory = _seeded_client()

    response = _execute_action(
        client,
        {
            "action_type": "create_reimbursement_draft",
            "message": "2026-03-04,打车去客户现场,交通费32元,保存草稿",
            "task": base_reimbursement_task(),
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "succeeded"
    assert payload["result_payload"]["status"] == "draft"
    assert payload["result_payload"]["claim_id"]
    with session_factory() as db:
        # .one() also asserts that exactly one claim row exists.
        claim = db.scalars(select(ExpenseClaim)).one()
        assert claim.status == "draft"
        assert claim.expense_type == "transport"
        assert claim.reason == "客户现场沟通"


def test_steward_action_executor_links_application_when_creating_reimbursement_draft() -> None:
    """``link_existing_application`` records an ``application_link`` flag on the new draft."""
    client, session_factory = _seeded_client()
    with session_factory() as db:
        seed_approved_application(db)

    response = _execute_action(
        client,
        {
            "action_type": "link_existing_application",
            "message": "关联申请单 AAPPROVED1,并保存报销草稿",
            "task": base_reimbursement_task(),
            "context_json": {
                "application_claim_id": "application-action-approved",
                "application_claim_no": "AAPPROVED1",
                "application_reason": "辅助国网仿生产服务器部署",
                "application_location": "上海",
                "application_business_time": "2026-02-20 至 2026-02-23",
            },
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "succeeded"
    assert payload["result_payload"]["status"] == "draft"
    with session_factory() as db:
        claims = db.scalars(select(ExpenseClaim)).all()
        # Use a default so a missing draft fails with a readable assertion
        # instead of a bare StopIteration.
        reimbursement = next(
            (claim for claim in claims if claim.id != "application-action-approved"),
            None,
        )
        assert reimbursement is not None, (
            "expected a reimbursement claim besides the seeded application"
        )
        assert reimbursement.status == "draft"
        link_flags = [
            flag
            for flag in list(reimbursement.risk_flags_json or [])
            if isinstance(flag, dict) and flag.get("source") == "application_link"
        ]
        assert link_flags
        assert link_flags[0]["application_claim_no"] == "AAPPROVED1"


def test_steward_action_executor_associates_receipt_attachments(monkeypatch) -> None:
    """``associate_attachments`` forwards receipt ids and the caller to the job runner.

    ``AttachmentAssociationJobRunner.run`` is replaced with a recording stub,
    so the real association job does not run in this test.
    """
    client, session_factory = _seeded_client()
    calls: list[dict[str, object]] = []

    def fake_run(self, *, receipt_ids, current_user):
        calls.append(
            {
                "receipt_ids": list(receipt_ids),
                "username": current_user.username,
            }
        )
        return {
            "claim_id": "claim-associated",
            "claim_no": "BX-20260220-001",
            "uploaded_count": 2,
            "skipped_count": 0,
        }

    monkeypatch.setattr(
        attachment_jobs_module.AttachmentAssociationJobRunner, "run", fake_run
    )

    response = _execute_action(
        client,
        {
            "action_type": "associate_attachments",
            "message": "把两张火车票关联到报销草稿",
            "task": base_reimbursement_task(),
            "context_json": {
                "receipt_ids": ["receipt-001", "receipt-002"],
            },
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "succeeded"
    assert payload["result_payload"]["claim_no"] == "BX-20260220-001"
    assert calls == [
        {
            "receipt_ids": ["receipt-001", "receipt-002"],
            "username": "zhangsan@example.com",
        }
    ]