Files
X-Financial/server/tests/test_steward_action_executor.py
caoxiaozhu 5311c99d69 refactor(server): steward 决策链路改用 LangGraph 编排
- 新增 StewardGraphPlannerService,用 LangGraph 状态图编排意图识别→流程判断→模型/规则分支→兜底,替代原 planner 内线性调用
- 新增 StewardGraphRuntimeService 编排运行时决策与槽位决策;StewardActionContracts/Executor 统一动作合约与执行
- steward_intent_agent/application_fact_resolver/runtime_chat 适配图执行器,config 暴露图相关开关
- pyproject/uv.lock 新增 langgraph 依赖
- 新增 graph_planner/graph_runtime/action_executor 测试,更新 intent_agent/planner/fact_resolver/runtime_chat/reimbursement 测试
2026-06-24 21:58:35 +08:00

456 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

from collections.abc import Generator
from datetime import UTC, datetime

from fastapi.testclient import TestClient
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.agent_conversation import AgentConversation
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.services import attachment_association_jobs as attachment_jobs_module
def build_session_factory() -> sessionmaker[Session]:
    """Create a session factory bound to a fresh in-memory SQLite database.

    ``StaticPool`` reuses a single connection, so every session created by the
    returned factory sees the same in-memory database. ``check_same_thread=False``
    allows that connection to be used from a thread other than the creating one.
    The full ORM schema is created before the factory is returned.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(autocommit=False, autoflush=False, bind=memory_engine)
    return factory
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build a TestClient whose ``get_db`` dependency uses an isolated in-memory DB.

    Returns:
        The client and the session factory backing it. Tests use the factory to
        seed and inspect the same database the API reads and writes.
    """
    factory = build_session_factory()
    application = create_app()

    def _session_override() -> Generator[Session, None, None]:
        # Leaving the ``with`` block calls Session.close(), exactly like the
        # explicit close() in a try/finally.
        with factory() as session:
            yield session

    application.dependency_overrides[get_db] = _session_override
    return TestClient(application), factory
def seed_employee(db: Session) -> None:
    """Insert the test employee (E90001) and their manager (E90000), then commit.

    The employee's ``manager`` relationship points at the manager row, and
    ``steward-action-employee`` is the id the other fixtures reference.
    """
    leader = Employee(
        id="steward-action-manager",
        employee_no="E90000",
        name="李总",
        email="leader@example.com",
        position="部门负责人",
        grade="P7",
    )
    staff_member = Employee(
        id="steward-action-employee",
        employee_no="E90001",
        name="张三",
        email="zhangsan@example.com",
        position="实施工程师",
        grade="P4",
        manager=leader,
    )
    people = [leader, staff_member]
    db.add_all(people)
    db.commit()
def auth_headers() -> dict[str, str]:
    """Return the ``x-auth-*`` request headers describing the test caller.

    The caller is employee E90001 (``zhangsan@example.com``) with role code
    ``user``. A new dict is returned on every call.
    """
    header_pairs = (
        ("x-auth-username", "zhangsan@example.com"),
        ("x-auth-name", "Zhang San"),
        ("x-auth-employee-no", "E90001"),
        ("x-auth-role-codes", "user"),
        ("x-auth-position", "Engineer"),
        ("x-auth-grade", "P4"),
        ("x-auth-manager-name", "Leader"),
    )
    return dict(header_pairs)
def base_application_task(requested_action: str = "save_draft") -> dict[str, object]:
    """Return a fully-populated expense-application task for a Shanghai business trip.

    Args:
        requested_action: Copied verbatim into ``requested_action``.

    Returns:
        A new task dict (with new nested dict and lists) on every call, so callers
        may mutate it freely. ``confirmation_required`` is True only when
        ``requested_action == "submit"``.
    """
    is_submit = requested_action == "submit"
    fields: dict[str, object] = {
        "expense_type": "travel",
        "time_range": "2026-02-20 至 2026-02-23",
        "location": "上海",
        "reason": "辅助国网仿生产服务器部署",
        "transport_mode": "train",
    }
    task: dict[str, object] = {
        "task_id": "task_app_001",
        "task_type": "expense_application",
        "assigned_agent": "application_assistant",
        "title": "上海出差申请",
        "summary": "2026-02-20 至 2026-02-23 去上海出差,辅助国网仿生产服务器部署,火车出行。",
        "status": "needs_confirmation",
        "confidence": 0.96,
        "requested_action": requested_action,
        "ontology_fields": fields,
        "missing_fields": [],
        "confirmation_required": is_submit,
        "action_steps": [],
    }
    return task
def base_reimbursement_task() -> dict[str, object]:
    """Return a ``save_draft`` reimbursement task for a 32-yuan taxi ride.

    No confirmation is required and no fields are missing. A new dict (with new
    nested dict and lists) is built on every call, so callers may mutate it.
    """
    fields: dict[str, object] = {
        "expense_type": "transport",
        "time_range": "2026-03-04",
        "location": "客户现场",
        "reason": "客户现场沟通",
        "amount": "32元",
        "transport_mode": "taxi",
    }
    task: dict[str, object] = {
        "task_id": "task_reim_001",
        "task_type": "reimbursement",
        "assigned_agent": "reimbursement_assistant",
        "title": "客户现场交通费报销",
        "summary": "2026-03-04 打车去客户现场,交通费 32 元。",
        "status": "needs_confirmation",
        "confidence": 0.9,
        "requested_action": "save_draft",
        "ontology_fields": fields,
        "missing_fields": [],
        "confirmation_required": False,
        "action_steps": [],
    }
    return task
def claim_count(db: Session) -> int:
    """Return the number of ``ExpenseClaim`` rows visible to ``db``.

    The count is computed in SQL (``SELECT count(*) ... FROM`` the claim table).
    This avoids loading every claim as an ORM object just to take ``len()``.

    Args:
        db: An open session on the test database.

    Returns:
        The number of claim rows; 0 when the table is empty.
    """
    count = db.scalar(select(func.count()).select_from(ExpenseClaim))
    # count(*) always yields exactly one row, but Session.scalar is typed Optional.
    return int(count or 0)
def seed_approved_application(db: Session) -> None:
    """Insert an already-approved travel application (claim ``AAPPROVED1``) and commit.

    Its ``employee_id`` matches the employee created by ``seed_employee``. The
    link-application test references it by id and claim number.
    """
    approved_fields = {
        "id": "application-action-approved",
        "claim_no": "AAPPROVED1",
        "employee_id": "steward-action-employee",
        "employee_name": "张三",
        "department_id": "dept-delivery",
        "department_name": "交付部",
        "project_code": None,
        "expense_type": "travel_application",
        "reason": "辅助国网仿生产服务器部署",
        "location": "上海",
        "amount": 3000,
        "currency": "CNY",
        "invoice_count": 0,
        "occurred_at": datetime(2026, 2, 20, tzinfo=UTC),
        "submitted_at": None,
        "status": "approved",
        "approval_stage": "已完成",
        "risk_flags_json": [],
    }
    approved = ExpenseClaim(**approved_fields)
    db.add(approved)
    db.commit()
def test_steward_action_executor_rejects_unknown_action_without_creating_claim() -> None:
    """An unsupported action type is reported as blocked and creates no claim."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)
        initial_claims = claim_count(db)

    endpoint = "/api/v1/steward/actions/execute"
    request_body = {
        "action_type": "delete_all_claims",
        "message": "请执行未知动作",
        "task": base_application_task(),
    }
    response = client.post(endpoint, headers=auth_headers(), json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "blocked"
    assert body["action_type"] == "delete_all_claims"
    assert "不支持" in body["message"]

    with session_factory() as db:
        assert claim_count(db) == initial_claims
def test_steward_action_executor_blocks_attachment_action_without_receipts() -> None:
    """Associating attachments is blocked when the task names a file but no receipts.

    The task only carries an ``attachments`` filename; the request context has no
    receipt ids. The action must be blocked and no claim may be created.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)
        claims_before = claim_count(db)

    task = base_reimbursement_task()
    # Copy the ontology fields and append an attachment filename.
    ontology = dict(task["ontology_fields"])
    ontology["attachments"] = "taxi.png"
    task["ontology_fields"] = ontology

    endpoint = "/api/v1/steward/actions/execute"
    request_body = {
        "action_type": "associate_attachments",
        "message": "关联附件 taxi.png",
        "task": task,
    }
    response = client.post(endpoint, headers=auth_headers(), json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "blocked"
    message = body["message"]
    assert "receipt_id" in message or "票据" in message

    with session_factory() as db:
        assert claim_count(db) == claims_before
def test_steward_action_executor_records_pending_interrupt_in_conversation_state() -> None:
    """An unconfirmed submit is checkpointed as a pending interrupt on the conversation.

    The conversation's ``state_json`` must hold a ``steward_action_checkpoint``.
    Its ``pending_interrupt`` names the trace id and action. Its ``actions`` map
    records that trace as ``needs_confirmation``.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    endpoint = "/api/v1/steward/actions/execute"
    precheck = {
        "status": "ok",
        "blocking": False,
    }
    request_body = {
        "action_type": "submit_application",
        "message": "2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车直接提交",
        "conversation_id": "conv-action-submit",
        "client_trace_id": "trace-submit-pending",
        "task": base_application_task("submit"),
        "confirmed": False,
        "context_json": {"precheck_result": precheck},
    }
    response = client.post(endpoint, headers=auth_headers(), json=request_body)

    assert response.status_code == 200
    assert response.json()["status"] == "needs_confirmation"

    with session_factory() as db:
        statement = select(AgentConversation).where(
            AgentConversation.conversation_id == "conv-action-submit"
        )
        conversation = db.scalar(statement)
        assert conversation is not None
        checkpoint = conversation.state_json["steward_action_checkpoint"]
        pending = checkpoint["pending_interrupt"]
        assert pending["client_trace_id"] == "trace-submit-pending"
        assert pending["action_type"] == "submit_application"
        assert checkpoint["actions"]["trace-submit-pending"]["status"] == "needs_confirmation"
def test_steward_action_executor_reuses_checkpoint_for_duplicate_trace_without_duplicate_draft() -> None:
    """Replaying a request with the same ``client_trace_id`` reuses the first draft.

    Both responses succeed and report the same ``claim_id``. The second response
    is flagged ``idempotent_replay``, and exactly one claim exists afterwards.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    endpoint = "/api/v1/steward/actions/execute"
    request_payload = {
        "action_type": "save_application_draft",
        "message": "2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车保存草稿",
        "conversation_id": "conv-action-draft",
        "client_trace_id": "trace-save-draft",
        "task": base_application_task("save_draft"),
    }
    # Send the identical request twice, in order.
    first_response, second_response = [
        client.post(endpoint, headers=auth_headers(), json=request_payload)
        for _ in range(2)
    ]

    assert first_response.status_code == 200
    assert second_response.status_code == 200
    first_body = first_response.json()
    second_body = second_response.json()
    assert first_body["status"] == "succeeded"
    assert second_body["status"] == "succeeded"
    first_claim_id = first_body["result_payload"]["draft_payload"]["claim_id"]
    second_claim_id = second_body["result_payload"]["draft_payload"]["claim_id"]
    assert first_claim_id == second_claim_id
    assert second_body["result_payload"]["idempotent_replay"] is True

    with session_factory() as db:
        assert claim_count(db) == 1
def test_steward_action_executor_requires_confirmation_before_submit_side_effect() -> None:
    """An unconfirmed submit asks for confirmation and writes no claim."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    endpoint = "/api/v1/steward/actions/execute"
    precheck = {
        "status": "ok",
        "blocking": False,
    }
    request_body = {
        "action_type": "submit_application",
        "message": "2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车直接提交",
        "task": base_application_task("submit"),
        "confirmed": False,
        "context_json": {"precheck_result": precheck},
    }
    response = client.post(endpoint, headers=auth_headers(), json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "needs_confirmation"
    assert body["requires_confirmation"] is True

    with session_factory() as db:
        assert claim_count(db) == 0
def test_steward_action_executor_saves_application_draft_from_action_step() -> None:
    """``save_application_draft`` persists one draft application with the task's reason."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    endpoint = "/api/v1/steward/actions/execute"
    request_body = {
        "action_type": "save_application_draft",
        "message": "2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车保存草稿",
        "task": base_application_task("save_draft"),
    }
    response = client.post(endpoint, headers=auth_headers(), json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "succeeded"
    draft = body["result_payload"]["draft_payload"]
    assert draft["draft_type"] == "expense_application"
    assert draft["status"] == "draft"
    assert draft["claim_no"].startswith("A")

    with session_factory() as db:
        stored = db.scalars(select(ExpenseClaim)).one()
        assert stored.status == "draft"
        assert stored.reason == "辅助国网仿生产服务器部署"
def test_steward_action_executor_creates_reimbursement_draft_from_action_step() -> None:
    """``create_reimbursement_draft`` persists one transport draft from the task fields."""
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    endpoint = "/api/v1/steward/actions/execute"
    request_body = {
        "action_type": "create_reimbursement_draft",
        "message": "2026-03-04打车去客户现场交通费32元保存草稿",
        "task": base_reimbursement_task(),
    }
    response = client.post(endpoint, headers=auth_headers(), json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "succeeded"
    result = body["result_payload"]
    assert result["status"] == "draft"
    assert result["claim_id"]

    with session_factory() as db:
        stored = db.scalars(select(ExpenseClaim)).one()
        assert stored.status == "draft"
        assert stored.expense_type == "transport"
        assert stored.reason == "客户现场沟通"
def test_steward_action_executor_links_application_when_creating_reimbursement_draft() -> None:
    """Linking an approved application yields one reimbursement draft that records the link.

    The request points at the seeded approved application ``AAPPROVED1``. The
    single new draft must carry a risk flag whose ``source`` is
    ``application_link`` and whose ``application_claim_no`` is ``AAPPROVED1``.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)
        seed_approved_application(db)

    response = client.post(
        "/api/v1/steward/actions/execute",
        headers=auth_headers(),
        json={
            "action_type": "link_existing_application",
            "message": "关联申请单 AAPPROVED1并保存报销草稿",
            "task": base_reimbursement_task(),
            "context_json": {
                "application_claim_id": "application-action-approved",
                "application_claim_no": "AAPPROVED1",
                "application_reason": "辅助国网仿生产服务器部署",
                "application_location": "上海",
                "application_business_time": "2026-02-20 至 2026-02-23",
            },
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "succeeded"
    assert payload["result_payload"]["status"] == "draft"

    with session_factory() as db:
        claims = db.scalars(select(ExpenseClaim)).all()
        # Every claim other than the seeded application was created by the action.
        # Filtering into a list (instead of calling next() on a generator) turns a
        # missing draft into a readable assertion failure rather than a bare
        # StopIteration, and also catches accidental duplicate drafts.
        reimbursements = [
            claim for claim in claims if claim.id != "application-action-approved"
        ]
        assert len(reimbursements) == 1, (
            f"expected exactly one reimbursement draft, found {len(reimbursements)}"
        )
        reimbursement = reimbursements[0]
        assert reimbursement.status == "draft"
        link_flags = [
            flag
            for flag in reimbursement.risk_flags_json or []
            if isinstance(flag, dict) and flag.get("source") == "application_link"
        ]
        assert link_flags, "reimbursement draft has no application_link risk flag"
        assert link_flags[0]["application_claim_no"] == "AAPPROVED1"
def test_steward_action_executor_associates_receipt_attachments(monkeypatch) -> None:
    """``associate_attachments`` forwards receipt ids to the association job runner.

    ``AttachmentAssociationJobRunner.run`` is replaced by a stub that records its
    keyword arguments and returns a canned result. The test checks that the stub
    was called exactly once with the request's receipt ids and the caller's
    username, and that the stub's ``claim_no`` is surfaced in the response.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        seed_employee(db)

    calls: list[dict[str, object]] = []

    def stub_run(self, *, receipt_ids, current_user):
        # Installed on the class, so it is invoked as a bound method.
        calls.append(
            dict(
                receipt_ids=list(receipt_ids),
                username=current_user.username,
            )
        )
        return {
            "claim_id": "claim-associated",
            "claim_no": "BX-20260220-001",
            "uploaded_count": 2,
            "skipped_count": 0,
        }

    monkeypatch.setattr(attachment_jobs_module.AttachmentAssociationJobRunner, "run", stub_run)

    endpoint = "/api/v1/steward/actions/execute"
    request_body = {
        "action_type": "associate_attachments",
        "message": "把两张火车票关联到报销草稿",
        "task": base_reimbursement_task(),
        "context_json": {
            "receipt_ids": ["receipt-001", "receipt-002"],
        },
    }
    response = client.post(endpoint, headers=auth_headers(), json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "succeeded"
    assert body["result_payload"]["claim_no"] == "BX-20260220-001"
    expected_calls = [
        {
            "receipt_ids": ["receipt-001", "receipt-002"],
            "username": "zhangsan@example.com",
        }
    ]
    assert calls == expected_calls