Files
X-Financial/server/tests/test_steward_action_executor.py

456 lines
15 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations

from collections.abc import Generator
from datetime import UTC, datetime

from fastapi.testclient import TestClient
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.agent_conversation import AgentConversation
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.services import attachment_association_jobs as attachment_jobs_module
def build_session_factory() -> sessionmaker[Session]:
    """Return a session factory bound to a brand-new in-memory SQLite database.

    The engine is configured with ``StaticPool`` and
    ``check_same_thread=False`` so the one in-memory database can be reused
    across sessions and threads (the TestClient handles requests on its own
    worker). Every table registered on ``Base.metadata`` is created before the
    factory is handed back.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(
        bind=memory_engine,
        autoflush=False,
        autocommit=False,
    )
    return factory
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a TestClient whose ``get_db`` dependency uses an isolated database.

    Returns:
        A ``(client, session_factory)`` pair. The factory is the same one the
        overridden dependency draws from, so tests can seed and inspect exactly
        the data the API handlers see.
    """
    factory = build_session_factory()
    application = create_app()

    def _session_for_request() -> Generator[Session, None, None]:
        # Hand out a fresh session and always close it once the request is done.
        session = factory()
        try:
            yield session
        finally:
            session.close()

    application.dependency_overrides[get_db] = _session_for_request
    client = TestClient(application)
    return client, factory
def seed_employee(db: Session) -> None:
    """Insert employee E90001 and their manager E90000, then commit.

    The employee's email and employee number match the identity sent by
    ``auth_headers()``.
    """
    leader = Employee(
        id="steward-action-manager",
        employee_no="E90000",
        name="李总",
        email="leader@example.com",
        position="部门负责人",
        grade="P7",
    )
    # The reporting line is expressed through the ``manager`` relationship.
    staff_member = Employee(
        id="steward-action-employee",
        employee_no="E90001",
        name="张三",
        email="zhangsan@example.com",
        position="实施工程师",
        grade="P4",
        manager=leader,
    )
    db.add_all([leader, staff_member])
    db.commit()
def auth_headers() -> dict[str, str]:
    """Return request headers that identify the caller as a plain ``user``.

    The username and employee number correspond to the employee created by
    ``seed_employee``. A new dict is returned on every call.
    """
    header_pairs = (
        ("x-auth-username", "zhangsan@example.com"),
        ("x-auth-name", "Zhang San"),
        ("x-auth-employee-no", "E90001"),
        ("x-auth-role-codes", "user"),
        ("x-auth-position", "Engineer"),
        ("x-auth-grade", "P4"),
        ("x-auth-manager-name", "Leader"),
    )
    return dict(header_pairs)
def base_application_task(requested_action: str = "save_draft") -> dict[str, object]:
    """Build a fully populated travel-application task (Shanghai trip by train).

    Args:
        requested_action: Stored as ``requested_action``. Confirmation is
            flagged as required only when it equals ``"submit"``.

    Returns:
        A fresh task dict on every call, so callers may mutate it freely.
    """
    needs_confirmation = requested_action == "submit"
    ontology = {
        "expense_type": "travel",
        "time_range": "2026-02-20 至 2026-02-23",
        "location": "上海",
        "reason": "辅助国网仿生产服务器部署",
        "transport_mode": "train",
    }
    task: dict[str, object] = {
        "task_id": "task_app_001",
        "task_type": "expense_application",
        "assigned_agent": "application_assistant",
        "title": "上海出差申请",
        "summary": "2026-02-20 至 2026-02-23 去上海出差,辅助国网仿生产服务器部署,火车出行。",
        "status": "needs_confirmation",
        "confidence": 0.96,
        "requested_action": requested_action,
        "ontology_fields": ontology,
        "missing_fields": [],
        "confirmation_required": needs_confirmation,
        "action_steps": [],
    }
    return task
def base_reimbursement_task() -> dict[str, object]:
    """Build a taxi-fare reimbursement task (2026-03-04, 32元).

    The task requests ``save_draft`` and never requires confirmation. A new
    dict is built on every call, so tests can adjust it without leaking state.
    """
    ontology = {
        "expense_type": "transport",
        "time_range": "2026-03-04",
        "location": "客户现场",
        "reason": "客户现场沟通",
        "amount": "32元",
        "transport_mode": "taxi",
    }
    task: dict[str, object] = {
        "task_id": "task_reim_001",
        "task_type": "reimbursement",
        "assigned_agent": "reimbursement_assistant",
        "title": "客户现场交通费报销",
        "summary": "2026-03-04 打车去客户现场,交通费 32 元。",
        "status": "needs_confirmation",
        "confidence": 0.9,
        "requested_action": "save_draft",
        "ontology_fields": ontology,
        "missing_fields": [],
        "confirmation_required": False,
        "action_steps": [],
    }
    return task
def claim_count(db: Session) -> int:
    """Return the number of ``ExpenseClaim`` rows visible through *db*.

    Issues ``SELECT count(*)`` so the database does the counting, instead of
    loading and materializing every claim as an ORM object just to take
    ``len()`` of the resulting list.
    """
    total = db.scalar(select(func.count()).select_from(ExpenseClaim))
    # ``Session.scalar`` is typed Optional; COUNT(*) always yields a row.
    return int(total or 0)
def seed_approved_application(db: Session) -> None:
    """Persist an approved travel application and commit it.

    The application (id ``application-action-approved``, number
    ``AAPPROVED1``) references employee id ``steward-action-employee``, which
    is the employee inserted by ``seed_employee``. Reimbursement tests link
    against it.
    """
    approved = ExpenseClaim(
        # identity
        id="application-action-approved",
        claim_no="AAPPROVED1",
        # owner
        employee_id="steward-action-employee",
        employee_name="张三",
        department_id="dept-delivery",
        department_name="交付部",
        project_code=None,
        # business details
        expense_type="travel_application",
        reason="辅助国网仿生产服务器部署",
        location="上海",
        amount=3000,
        currency="CNY",
        invoice_count=0,
        occurred_at=datetime(2026, 2, 20, tzinfo=UTC),
        # workflow state
        submitted_at=None,
        status="approved",
        approval_stage="已完成",
        risk_flags_json=[],
    )
    db.add(approved)
    db.commit()
def test_steward_action_executor_rejects_unknown_action_without_creating_claim() -> None:
    """An unsupported ``action_type`` is reported as blocked and writes no claims."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)
        claims_before = claim_count(db)

    request_body = {
        "action_type": "delete_all_claims",
        "message": "请执行未知动作",
        "task": base_application_task(),
    }
    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json=request_body,
    )

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "blocked"
    assert body["action_type"] == "delete_all_claims"
    # The refusal message must say the action is unsupported.
    assert "不支持" in body["message"]

    with session_factory() as db:
        assert claim_count(db) == claims_before
def test_steward_action_executor_blocks_attachment_action_without_receipts() -> None:
    """Associating attachments by file name alone (no receipt ids) is blocked.

    The task only names ``taxi.png`` in its ontology fields and the request
    carries no ``receipt_ids``. The executor must refuse and create no claims.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)
        before_count = claim_count(db)

    task = base_reimbursement_task()
    ontology_fields = task["ontology_fields"]
    # ``task`` is typed ``dict[str, object]``; narrow the value to a dict before
    # ``**``-unpacking it, otherwise static type checkers reject the merge.
    assert isinstance(ontology_fields, dict)
    task["ontology_fields"] = {**ontology_fields, "attachments": "taxi.png"}

    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json={
            "action_type": "associate_attachments",
            "message": "关联附件 taxi.png",
            "task": task,
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "blocked"
    # The refusal should point the user at receipt ids / receipts.
    assert "receipt_id" in payload["message"] or "票据" in payload["message"]

    with session_factory() as db:
        assert claim_count(db) == before_count
def test_steward_action_executor_records_pending_interrupt_in_conversation_state() -> None:
    """An unconfirmed submit is checkpointed as a pending interrupt on the conversation."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    trace_id = "trace-submit-pending"
    conversation_key = "conv-action-submit"
    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json={
            "action_type": "submit_application",
            "message": "2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车直接提交",
            "conversation_id": conversation_key,
            "client_trace_id": trace_id,
            "task": base_application_task("submit"),
            "confirmed": False,
            # A passing, non-blocking precheck, so only confirmation is missing.
            "context_json": {"precheck_result": {"status": "ok", "blocking": False}},
        },
    )
    assert response.status_code == 200
    assert response.json()["status"] == "needs_confirmation"

    with session_factory() as db:
        stored = db.scalar(
            select(AgentConversation).where(
                AgentConversation.conversation_id == conversation_key
            )
        )
        assert stored is not None
        checkpoint = stored.state_json["steward_action_checkpoint"]
        pending = checkpoint["pending_interrupt"]
        assert pending["client_trace_id"] == trace_id
        assert pending["action_type"] == "submit_application"
        assert checkpoint["actions"][trace_id]["status"] == "needs_confirmation"
def test_steward_action_executor_reuses_checkpoint_for_duplicate_trace_without_duplicate_draft() -> None:
    """Replaying the same ``client_trace_id`` returns the cached result, not a new draft."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    request_payload = {
        "action_type": "save_application_draft",
        "message": "2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车保存草稿",
        "conversation_id": "conv-action-draft",
        "client_trace_id": "trace-save-draft",
        "task": base_application_task("save_draft"),
    }
    # Send the identical request twice, in order.
    responses = [
        client.post(
            "/api/v1/steward/actions/execute",
            headers=auth_headers(),
            json=request_payload,
        )
        for _ in range(2)
    ]
    for resp in responses:
        assert resp.status_code == 200

    original, replay = (resp.json() for resp in responses)
    assert original["status"] == "succeeded"
    assert replay["status"] == "succeeded"
    assert (
        original["result_payload"]["draft_payload"]["claim_id"]
        == replay["result_payload"]["draft_payload"]["claim_id"]
    )
    assert replay["result_payload"]["idempotent_replay"] is True

    with session_factory() as db:
        assert claim_count(db) == 1
def test_steward_action_executor_requires_confirmation_before_submit_side_effect() -> None:
    """Without ``confirmed`` the submit action asks for confirmation and writes nothing."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    precheck_ok = {"precheck_result": {"status": "ok", "blocking": False}}
    request_body = {
        "action_type": "submit_application",
        "message": "2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车直接提交",
        "task": base_application_task("submit"),
        "confirmed": False,
        "context_json": precheck_ok,
    }
    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json=request_body,
    )

    assert response.status_code == 200
    result = response.json()
    assert result["status"] == "needs_confirmation"
    assert result["requires_confirmation"] is True

    # No claim may be persisted before the user confirms.
    with session_factory() as db:
        assert claim_count(db) == 0
def test_steward_action_executor_saves_application_draft_from_action_step() -> None:
    """``save_application_draft`` persists an application claim in ``draft`` status."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    request_body = {
        "action_type": "save_application_draft",
        "message": "2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车保存草稿",
        "task": base_application_task("save_draft"),
    }
    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json=request_body,
    )

    assert response.status_code == 200
    result = response.json()
    assert result["status"] == "succeeded"
    draft = result["result_payload"]["draft_payload"]
    assert draft["draft_type"] == "expense_application"
    assert draft["status"] == "draft"
    # Application claim numbers carry an "A" prefix.
    assert draft["claim_no"].startswith("A")

    with session_factory() as db:
        persisted = db.scalars(select(ExpenseClaim)).one()
        assert persisted.status == "draft"
        assert persisted.reason == "辅助国网仿生产服务器部署"
def test_steward_action_executor_creates_reimbursement_draft_from_action_step() -> None:
    """``create_reimbursement_draft`` persists a transport claim in ``draft`` status."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    request_body = {
        "action_type": "create_reimbursement_draft",
        "message": "2026-03-04打车去客户现场交通费32元保存草稿",
        "task": base_reimbursement_task(),
    }
    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json=request_body,
    )

    assert response.status_code == 200
    result = response.json()
    assert result["status"] == "succeeded"
    outcome = result["result_payload"]
    assert outcome["status"] == "draft"
    assert outcome["claim_id"]

    with session_factory() as db:
        persisted = db.scalars(select(ExpenseClaim)).one()
        assert persisted.status == "draft"
        assert persisted.expense_type == "transport"
        assert persisted.reason == "客户现场沟通"
def test_steward_action_executor_links_application_when_creating_reimbursement_draft() -> None:
    """``link_existing_application`` saves a draft tagged with an ``application_link`` flag.

    The seeded approved application ``AAPPROVED1`` is passed through
    ``context_json``. The resulting reimbursement draft must carry a risk flag
    whose ``source`` is ``application_link`` and which names that application.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)
        seed_approved_application(db)

    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json={
            "action_type": "link_existing_application",
            "message": "关联申请单 AAPPROVED1并保存报销草稿",
            "task": base_reimbursement_task(),
            "context_json": {
                "application_claim_id": "application-action-approved",
                "application_claim_no": "AAPPROVED1",
                "application_reason": "辅助国网仿生产服务器部署",
                "application_location": "上海",
                "application_business_time": "2026-02-20 至 2026-02-23",
            },
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "succeeded"
    assert payload["result_payload"]["status"] == "draft"

    with session_factory() as db:
        claims = db.scalars(select(ExpenseClaim)).all()
        # Supply a default so a missing draft fails with a clear assertion
        # instead of a bare StopIteration escaping the test.
        reimbursement = next(
            (claim for claim in claims if claim.id != "application-action-approved"),
            None,
        )
        assert reimbursement is not None, "expected a reimbursement draft to be created"
        assert reimbursement.status == "draft"
        link_flags = [
            flag
            for flag in reimbursement.risk_flags_json or []
            if isinstance(flag, dict) and flag.get("source") == "application_link"
        ]
        assert link_flags
        assert link_flags[0]["application_claim_no"] == "AAPPROVED1"
def test_steward_action_executor_associates_receipt_attachments(monkeypatch) -> None:
    """Receipt ids from ``context_json`` are forwarded to the attachment job runner.

    ``AttachmentAssociationJobRunner.run`` is replaced with a recording fake so
    the test can check which receipts and which user were passed, and that the
    fake's result is surfaced in ``result_payload``.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    recorded: list[dict[str, object]] = []

    def fake_run(self, *, receipt_ids, current_user):
        # Capture the call instead of touching real attachments.
        recorded.append(
            {"receipt_ids": list(receipt_ids), "username": current_user.username}
        )
        return {
            "claim_id": "claim-associated",
            "claim_no": "BX-20260220-001",
            "uploaded_count": 2,
            "skipped_count": 0,
        }

    monkeypatch.setattr(
        attachment_jobs_module.AttachmentAssociationJobRunner, "run", fake_run
    )

    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json={
            "action_type": "associate_attachments",
            "message": "把两张火车票关联到报销草稿",
            "task": base_reimbursement_task(),
            "context_json": {"receipt_ids": ["receipt-001", "receipt-002"]},
        },
    )

    assert response.status_code == 200
    result = response.json()
    assert result["status"] == "succeeded"
    assert result["result_payload"]["claim_no"] == "BX-20260220-001"
    assert recorded == [
        {
            "receipt_ids": ["receipt-001", "receipt-002"],
            "username": "zhangsan@example.com",
        }
    ]