# Source listing: X-Financial/server/tests/test_orchestrator_service.py
# (694 lines, 24 KiB, Python)

from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.schemas.settings import SettingsWrite
from app.services.agent_assets import AgentAssetService
from app.services.settings import SettingsService
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a TestClient backed by a fresh in-memory SQLite database.

    The engine uses ``StaticPool`` so every session shares one connection (and
    therefore one in-memory database), and ``check_same_thread=False`` allows
    that connection to be used from a thread other than the one that created
    it. The app's ``get_db`` dependency is overridden to hand out sessions from
    the returned factory.

    Returns:
        A ``(client, session_factory)`` pair bound to the same engine, so tests
        can seed or inspect the database directly.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)

    def _session_override() -> Generator[Session, None, None]:
        # One session per request, always closed afterwards.
        session = factory()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = _session_override
    return TestClient(application), factory
def test_orchestrator_routes_user_query_to_user_agent() -> None:
    """A finance user's receivables question is routed to the user agent with read permission."""
    client, _ = build_client()
    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "客户A这个月还有多少应收",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "user_agent"
    assert payload["permission_level"] == "read"
    assert payload["status"] == "succeeded"
    assert payload["conversation_id"]
    assert payload["result"]["answer"]
    assert payload["result"]["suggested_actions"]
    assert payload["trace_summary"]["tool_count"] >= 1

    run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    # Check the lookup itself succeeded so a failed fetch is reported directly
    # rather than as a confusing KeyError on the error body.
    assert run_response.status_code == 200
    run_detail = run_response.json()
    # The persisted run must agree with the routing decision returned above.
    assert run_detail["agent"] == "user_agent"
    assert run_detail["route_json"]["selected_agent"] == "user_agent"
    assert run_detail["semantic_parse"]["scenario"] == "accounts_receivable"
    assert run_detail["tool_calls"][0]["tool_type"] == "database"
def test_orchestrator_routes_schedule_to_hermes() -> None:
    """A scheduled task run is routed to the Hermes agent and records two tool calls."""
    client, session_factory = build_client()
    with session_factory() as db:
        task = next(
            (
                item
                for item in AgentAssetService(db).list_assets(asset_type="task", status="active")
                if item.code == "task.hermes.daily_risk_scan"
            ),
            None,
        )
        # Fail with a clear message instead of a bare StopIteration when the
        # expected task asset is not present.
        assert task is not None, "active task 'task.hermes.daily_risk_scan' not found"
        # Read the id while the session is still open rather than relying on
        # attribute access on a detached instance.
        task_id = task.id

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "schedule",
            "task_id": task_id,
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "hermes"
    assert payload["status"] == "succeeded"
    assert payload["trace_summary"]["tool_count"] == 2

    run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    # Check the lookup itself succeeded so a failed fetch is reported directly.
    assert run_response.status_code == 200
    run_detail = run_response.json()
    assert run_detail["agent"] == "hermes"
    assert run_detail["route_json"]["selected_agent"] == "hermes"
    assert len(run_detail["tool_calls"]) == 2
def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None:
    """A direct-payment request from a "user"-role caller is blocked with no agent or tool calls."""
    client, _ = build_client()
    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "帮我直接付款给供应商B",
            "context_json": {"role_codes": ["user"]},
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] is None
    assert payload["permission_level"] == "forbidden"
    assert payload["status"] == "blocked"
    assert payload["trace_summary"]["tool_count"] == 0

    run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    # Check the lookup itself succeeded so a failed fetch is reported directly.
    assert run_response.status_code == 200
    run_detail = run_response.json()
    # With no downstream agent selected, the run is attributed to the orchestrator.
    assert run_detail["agent"] == "orchestrator"
    assert run_detail["tool_calls"] == []
def test_orchestrator_approval_required_returns_confirmation_result() -> None:
    """Arranging a payment needs approval, so the run is blocked pending confirmation."""
    client, _ = build_client()
    finance_ctx = {"role_codes": ["finance"]}
    resp = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "帮我安排付款给供应商B",
            "context_json": finance_ctx,
        },
    )

    assert resp.status_code == 200
    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "approval_required"
    assert body["requires_confirmation"] is True
    assert body["status"] == "blocked"
    # The user-facing message must ask for confirmation ("确认").
    assert "确认" in body["result"]["message"]
def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
    """Requesting an expense draft persists a draft claim and returns its structured details."""
    client, session_factory = build_client()
    resp = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "帮我生成张三4月差旅报销草稿",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"

    result = body["result"]
    draft = result["draft_payload"]
    assert draft["confirmation_required"] is True
    assert result["review_payload"]["slot_cards"]
    assert draft["claim_id"]
    assert draft["claim_no"].startswith("EXP-")
    assert draft["status"] == "draft"
    assert result["suggested_actions"]

    # The claim referenced by the response must exist in the database as a draft.
    with session_factory() as db:
        stored = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == draft["claim_id"]))
        assert stored is not None
        assert stored.claim_no == draft["claim_no"]
        assert stored.status == "draft"
        # Checked while the session is open in case `items` is lazily loaded.
        assert stored.items
def test_orchestrator_blocks_fourth_expense_draft_for_same_user() -> None:
    """Once a user has three saved drafts, saving a fourth is refused without creating a claim."""
    client, session_factory = build_client()
    user_id = "zhangsan@example.com"
    with session_factory() as db:
        db.add(Employee(employee_no="E1001", name="张三", email=user_id))
        db.commit()

    def post_message(message: str, **extra_context: str) -> dict:
        # Every request comes from the same finance user; extra keys are merged
        # into context_json after the shared ones.
        response = client.post(
            "/api/v1/orchestrator/run",
            json={
                "source": "user_message",
                "user_id": user_id,
                "message": message,
                "context_json": {"role_codes": ["finance"], "name": "张三", **extra_context},
            },
        )
        assert response.status_code == 200
        return response.json()

    # Create three separate drafts for this user.
    for amount, city in ((120, "上海"), (240, "北京"), (360, "深圳")):
        created = post_message(f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}")
        assert created["result"]["draft_payload"]["claim_no"].startswith("EXP-")

    # The fourth save is answered but reported as blocked, with no claim attached.
    blocked = post_message("帮我生成报销草稿我昨天去杭州出差交通费480元", review_action="save_draft")
    assert blocked["status"] == "succeeded"
    assert "你当前已保存 3 个草稿" in blocked["result"]["answer"]
    blocked_draft = blocked["result"]["draft_payload"]
    assert blocked_draft["claim_id"] is None
    assert blocked_draft["claim_no"] is None
    assert blocked_draft["status"] == "blocked"

    with session_factory() as db:
        drafts_in_db = db.scalar(
            select(func.count()).select_from(ExpenseClaim).where(ExpenseClaim.status == "draft")
        )
    assert drafts_in_db == 3
def test_orchestrator_allows_existing_draft_update_when_user_already_has_three_drafts() -> None:
    """Editing an existing draft still works after the user already holds three drafts."""
    client, session_factory = build_client()
    user_id = "lisi@example.com"
    with session_factory() as db:
        db.add(Employee(employee_no="E1002", name="李四", email=user_id))
        db.commit()

    def post_message(message: str, conversation_id: str | None = None) -> dict:
        # Requests from the same finance user; optionally continue a conversation.
        request_json = {
            "source": "user_message",
            "user_id": user_id,
            "message": message,
            "context_json": {"role_codes": ["finance"], "name": "李四"},
        }
        if conversation_id is not None:
            request_json["conversation_id"] = conversation_id
        response = client.post("/api/v1/orchestrator/run", json=request_json)
        assert response.status_code == 200
        return response.json()

    first = post_message("帮我生成报销草稿我昨天去上海出差交通费120元")
    claim_id = first["result"]["draft_payload"]["claim_id"]
    conversation_id = first["conversation_id"]
    assert claim_id
    assert conversation_id

    # Two more drafts, each started without a conversation id, bring the user to three drafts.
    for amount, city in ((240, "北京"), (360, "深圳")):
        extra = post_message(f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}")
        assert extra["result"]["draft_payload"]["claim_no"].startswith("EXP-")

    # Continuing the first conversation edits its draft rather than creating a new one.
    updated = post_message("金额改成888元", conversation_id=conversation_id)
    assert updated["result"]["draft_payload"]["claim_id"] == claim_id

    with session_factory() as db:
        claim = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == claim_id))
        assert claim is not None
        assert float(claim.amount) == 888.0
        draft_total = db.scalar(
            select(func.count())
            .select_from(ExpenseClaim)
            .where(ExpenseClaim.employee_id == claim.employee_id, ExpenseClaim.status == "draft")
        )
    assert draft_total == 3
def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None:
    """A follow-up message in the same conversation edits the draft created by the first turn."""
    client, session_factory = build_client()
    run_url = "/api/v1/orchestrator/run"

    opening = client.post(
        run_url,
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "帮我生成一份差旅报销草稿我昨天去上海出差交通费680元",
            "context_json": {
                "role_codes": ["finance"],
                "attachment_names": ["行程单.png"],
                "attachment_count": 1,
                "ocr_summary": "行程单金额680元",
            },
        },
    )
    assert opening.status_code == 200
    opening_body = opening.json()
    conversation_id = opening_body["conversation_id"]
    claim_id = opening_body["result"]["draft_payload"]["claim_id"]
    assert conversation_id
    assert claim_id

    follow_up = client.post(
        run_url,
        json={
            "source": "user_message",
            "user_id": "pytest",
            "conversation_id": conversation_id,
            "message": "金额改成800元",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert follow_up.status_code == 200
    follow_up_body = follow_up.json()
    assert follow_up_body["conversation_id"] == conversation_id
    trace = follow_up_body["trace_summary"]
    assert trace["scenario"] == "expense"
    assert trace["intent"] == "draft"
    assert follow_up_body["result"]["draft_payload"]["claim_id"] == claim_id

    with session_factory() as db:
        claim = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == claim_id))
        assert claim is not None
        assert float(claim.amount) == 800.0

        conversation = db.scalar(
            select(AgentConversation).where(AgentConversation.conversation_id == conversation_id)
        )
        assert conversation is not None
        assert conversation.draft_claim_id == claim_id
        assert conversation.last_scenario == "expense"
        assert conversation.last_intent == "draft"

        # Two turns were sent; presumably each stores a user message plus a reply.
        stored_messages = db.scalar(
            select(func.count())
            .select_from(AgentConversationMessage)
            .where(AgentConversationMessage.conversation_id == conversation_id)
        )
    assert stored_messages == 4
def test_orchestrator_does_not_reuse_conversation_when_user_changes() -> None:
    """A conversation id created by one user is not resumed for a different user."""
    client, session_factory = build_client()

    owner_resp = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "user_a",
            "message": "帮我生成一份差旅报销草稿我昨天去上海出差交通费680元",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert owner_resp.status_code == 200
    owner_conversation_id = owner_resp.json()["conversation_id"]
    assert owner_conversation_id

    # user_b supplies user_a's conversation id and must get a new conversation instead.
    other_resp = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "user_b",
            "conversation_id": owner_conversation_id,
            "message": "查一下本周报销金额",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert other_resp.status_code == 200
    other_conversation_id = other_resp.json()["conversation_id"]
    assert other_conversation_id
    assert other_conversation_id != owner_conversation_id

    def load_conversation(db: Session, conversation_id: str) -> AgentConversation | None:
        return db.scalar(
            select(AgentConversation).where(AgentConversation.conversation_id == conversation_id)
        )

    with session_factory() as db:
        owner_conversation = load_conversation(db, owner_conversation_id)
        other_conversation = load_conversation(db, other_conversation_id)
        assert owner_conversation is not None
        assert other_conversation is not None
        # Each conversation stays attributed to the user who created it.
        assert owner_conversation.user_id == "user_a"
        assert other_conversation.user_id == "user_b"
def test_orchestrator_prunes_conversations_older_than_configured_retention_days() -> None:
    """Conversations older than the configured retention window are removed by a later run."""
    client, session_factory = build_client()
    stale_id = "conv_expired"
    stale_time = datetime.now(UTC) - timedelta(days=2)

    with session_factory() as db:
        # Set retention to one day so a two-day-old conversation is past the window.
        settings_service = SettingsService(db)
        snapshot = settings_service.get_settings_snapshot().model_dump()
        snapshot["sessionForm"]["conversationRetentionDays"] = 1
        settings_service.save_settings_snapshot(SettingsWrite(**snapshot))

        db.add(
            AgentConversation(
                conversation_id=stale_id,
                user_id="expired_user",
                source="user_message",
                state_json={},
                message_count=1,
                created_at=stale_time,
                updated_at=stale_time,
            )
        )
        # Flush the conversation row before adding a message that references it.
        db.flush()
        db.add(
            AgentConversationMessage(
                conversation_id=stale_id,
                role="user",
                content="旧会话消息",
                created_at=stale_time,
            )
        )
        db.commit()

    # An unrelated user's run is expected to trigger pruning of the stale conversation.
    resp = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "fresh_user",
            "message": "查一下本周报销金额",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert resp.status_code == 200

    with session_factory() as db:
        leftover = db.scalar(
            select(AgentConversation).where(AgentConversation.conversation_id == stale_id)
        )
        leftover_messages = db.scalar(
            select(func.count())
            .select_from(AgentConversationMessage)
            .where(AgentConversationMessage.conversation_id == stale_id)
        )
    assert leftover is None
    assert leftover_messages == 0
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
    """A story about entertaining a client is parsed as an expense draft, not a receivables query."""
    client, _ = build_client()
    resp = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "我今天去客户现场招待了客户花销了1000元",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "draft_write"
    assert body["status"] == "blocked"
    assert body["route_reason"] == "clarification_required"

    trace = body["trace_summary"]
    assert trace["scenario"] == "expense"
    assert trace["intent"] == "draft"
    assert trace["tool_count"] == 0

    message = body["result"]["message"]
    # Must not fall back to the accounts-receivable wording ("应收场景数据").
    assert "应收场景数据" not in message
    assert message.startswith("我先根据你当前提供的信息完成了初步识别")

    review = body["result"]["review_payload"]
    assert review["intent_summary"].startswith("我理解你这次想报销业务招待费。")
    assert review["missing_slots"] == ["客户名称", "参与人员", "票据附件"]
    slots_by_key = {card["key"]: card for card in review["slot_cards"]}
    assert slots_by_key["time_range"]["raw_value"] == "今天"
    assert slots_by_key["location"]["value"] == "客户现场"
    assert slots_by_key["amount"]["value"] == "1000.00元"
def test_orchestrator_can_restore_latest_user_conversation() -> None:
    """The latest-conversation endpoint returns the user's conversation with its stored messages."""
    client, _ = build_client()

    first = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "restore_user",
            "message": "帮我生成一份差旅报销草稿我昨天去上海出差交通费680元",
            "context_json": {
                "role_codes": ["finance"],
                "attachment_names": ["行程单.png"],
                "attachment_count": 1,
                "ocr_summary": "行程单金额680元",
            },
        },
    )
    assert first.status_code == 200
    first_body = first.json()
    conversation_id = first_body["conversation_id"]

    second = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "restore_user",
            "conversation_id": conversation_id,
            "message": "金额改成800元",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert second.status_code == 200

    restored = client.get(
        "/api/v1/orchestrator/conversations/latest",
        params={"user_id": "restore_user"},
    )
    assert restored.status_code == 200
    restored_body = restored.json()
    assert restored_body["found"] is True

    conversation = restored_body["conversation"]
    assert conversation["conversation_id"] == conversation_id
    assert conversation["draft_claim_id"] == first_body["result"]["draft_payload"]["claim_id"]
    messages = conversation["messages"]
    assert len(messages) == 4
    opening_message, first_reply = messages[0], messages[1]
    assert opening_message["role"] == "user"
    # Attachment metadata from the first request is kept on the user message.
    assert opening_message["message_json"]["attachment_names"] == ["行程单.png"]
    assert first_reply["message_json"]["orchestrator_payload"]["run_id"]
def test_orchestrator_can_delete_all_user_conversations() -> None:
    """Deleting a user's conversations removes only that user's conversations."""
    client, session_factory = build_client()

    def start_conversation(user_id: str, message: str) -> None:
        # Each request omits conversation_id, so each one starts its own conversation.
        response = client.post(
            "/api/v1/orchestrator/run",
            json={
                "source": "user_message",
                "user_id": user_id,
                "message": message,
                "context_json": {"role_codes": ["finance"]},
            },
        )
        assert response.status_code == 200

    for message in ("查一下本周报销金额", "帮我生成差旅报销草稿"):
        start_conversation("delete_user", message)
    start_conversation("other_user", "查一下供应商待付款")

    delete_resp = client.delete(
        "/api/v1/orchestrator/conversations",
        params={"user_id": "delete_user"},
    )
    assert delete_resp.status_code == 200
    assert delete_resp.json()["deleted_count"] == 2

    def count_for(db: Session, user_id: str) -> int | None:
        return db.scalar(
            select(func.count())
            .select_from(AgentConversation)
            .where(AgentConversation.user_id == user_id)
        )

    with session_factory() as db:
        remaining_for_deleted_user = count_for(db, "delete_user")
        remaining_for_other_user = count_for(db, "other_user")
    assert remaining_for_deleted_user == 0
    assert remaining_for_other_user == 1
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
    """A simulated database tool failure is recorded on the run and marks the result degraded."""
    client, _ = build_client()
    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "查一下本周报销金额",
            "context_json": {
                "role_codes": ["finance"],
                "simulate_tool_failure": "database",
            },
        },
    )
    assert response.status_code == 200
    payload = response.json()
    # The run as a whole still succeeds; the failure surfaces in the trace summary.
    assert payload["selected_agent"] == "user_agent"
    assert payload["status"] == "succeeded"
    assert payload["trace_summary"]["failed_tool_count"] == 1
    assert payload["trace_summary"]["degraded"] is True

    run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    # Check the lookup itself succeeded so a failed fetch is reported directly
    # rather than as a confusing KeyError on the error body.
    assert run_response.status_code == 200
    run_detail = run_response.json()
    assert run_detail["tool_calls"][0]["status"] == "failed"
    assert "simulated database failure" in run_detail["tool_calls"][0]["error_message"]
def test_orchestrator_exception_is_written_to_agent_run() -> None:
    """A simulated orchestrator exception yields a failed run whose error message is persisted."""
    client, _ = build_client()
    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "查一下本周报销金额",
            "context_json": {
                "role_codes": ["finance"],
                "simulate_orchestrator_exception": True,
            },
        },
    )
    # The endpoint reports the failure in the body rather than via an HTTP error.
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "failed"

    run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    # Check the lookup itself succeeded so a failed fetch is reported directly.
    assert run_response.status_code == 200
    run_detail = run_response.json()
    assert run_detail["status"] == "failed"
    assert "simulated orchestrator exception" in run_detail["error_message"]