Files
X-Financial/server/tests/test_orchestrator_service.py
caoxiaozhu e53c0aa5d1 test(backend): update service tests
- test_orchestrator_service.py: update orchestrator service tests
- test_settings_persistence.py: update settings persistence tests
- test_user_agent_service.py: update user agent service tests
2026-05-12 06:37:59 +00:00

539 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
from app.models.financial_record import ExpenseClaim
from app.schemas.settings import SettingsWrite
from app.services.agent_assets import AgentAssetService
from app.services.settings import SettingsService
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build an API test client wired to a fresh in-memory SQLite database.

    Returns:
        A ``(client, session_factory)`` pair. The session factory is bound to
        the same engine the application uses, so tests can inspect what the
        API wrote.
    """
    # A single shared connection (StaticPool) is used for the in-memory
    # database; the same-thread check is disabled on that connection.
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    make_session = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)

    def provide_session() -> Generator[Session, None, None]:
        """Yield one session per request and always close it afterwards."""
        session = make_session()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    # Replace the application's database dependency with the test session.
    application.dependency_overrides[get_db] = provide_session
    return TestClient(application), make_session
def test_orchestrator_routes_user_query_to_user_agent() -> None:
    """A finance user's receivables question is routed to the user agent."""
    api, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "客户A这个月还有多少应收",
        "context_json": {"role_codes": ["finance"]},
    }

    run_response = api.post("/api/v1/orchestrator/run", json=request_body)
    assert run_response.status_code == 200

    # Routing decision and answer as reported by the run endpoint.
    body = run_response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "read"
    assert body["status"] == "succeeded"
    assert body["conversation_id"]
    assert body["result"]["answer"]
    assert body["result"]["suggested_actions"]
    assert body["trace_summary"]["tool_count"] >= 1

    # The stored run record is fetched separately and must agree.
    detail_url = f"/api/v1/agent-runs/{body['run_id']}"
    detail = api.get(detail_url).json()
    assert detail["agent"] == "user_agent"
    assert detail["route_json"]["selected_agent"] == "user_agent"
    assert detail["semantic_parse"]["scenario"] == "accounts_receivable"
    first_tool_call = detail["tool_calls"][0]
    assert first_tool_call["tool_type"] == "database"
def test_orchestrator_routes_schedule_to_hermes() -> None:
    """A scheduled trigger for the daily risk scan task is routed to hermes."""
    client, session_factory = build_client()
    task_code = "task.hermes.daily_risk_scan"
    with session_factory() as db:
        active_tasks = AgentAssetService(db).list_assets(asset_type="task", status="active")
        task = next((item for item in active_tasks if item.code == task_code), None)
        # Fail with a readable message rather than a bare StopIteration when
        # the expected task asset is not present.
        assert task is not None, f"active task asset {task_code!r} not found"
        # Read the id while the session is still open; the ORM instance is
        # detached once the `with` block exits.
        task_id = task.id

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "schedule",
            "task_id": task_id,
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert response.status_code == 200

    payload = response.json()
    assert payload["selected_agent"] == "hermes"
    assert payload["status"] == "succeeded"
    assert payload["trace_summary"]["tool_count"] == 2

    # The stored run record must show the same agent and tool-call count.
    run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
    assert run_detail["agent"] == "hermes"
    assert run_detail["route_json"]["selected_agent"] == "hermes"
    assert len(run_detail["tool_calls"]) == 2
def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None:
    """A payment request from the ``user`` role is blocked with no tool calls."""
    api, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我直接付款给供应商B",
        "context_json": {"role_codes": ["user"]},
    }

    run_response = api.post("/api/v1/orchestrator/run", json=request_body)
    assert run_response.status_code == 200

    body = run_response.json()
    # No agent is selected and nothing downstream is invoked.
    assert body["selected_agent"] is None
    assert body["permission_level"] == "forbidden"
    assert body["status"] == "blocked"
    assert body["trace_summary"]["tool_count"] == 0

    # The run is recorded against the orchestrator itself.
    detail_url = f"/api/v1/agent-runs/{body['run_id']}"
    detail = api.get(detail_url).json()
    assert detail["agent"] == "orchestrator"
    assert detail["tool_calls"] == []
def test_orchestrator_approval_required_returns_confirmation_result() -> None:
    """A finance user's payment request is blocked pending confirmation."""
    api, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我安排付款给供应商B",
        "context_json": {"role_codes": ["finance"]},
    }

    run_response = api.post("/api/v1/orchestrator/run", json=request_body)
    assert run_response.status_code == 200

    body = run_response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "approval_required"
    assert body["requires_confirmation"] is True
    assert body["status"] == "blocked"
    # The user-facing message must mention confirmation.
    result_message = body["result"]["message"]
    assert "确认" in result_message
def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
    """A draft request returns a structured payload and stores a draft claim."""
    api, make_session = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我生成张三4月差旅报销草稿",
        "context_json": {"role_codes": ["finance"]},
    }

    run_response = api.post("/api/v1/orchestrator/run", json=request_body)
    assert run_response.status_code == 200

    body = run_response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"

    draft = body["result"]["draft_payload"]
    assert draft["confirmation_required"] is True
    assert body["result"]["review_payload"]["slot_cards"]
    assert draft["claim_id"]
    assert draft["claim_no"].startswith("EXP-")
    assert draft["status"] == "draft"
    assert body["result"]["suggested_actions"]

    # The claim referenced by the response must exist in the database.
    with make_session() as session:
        claim_query = select(ExpenseClaim).where(ExpenseClaim.id == draft["claim_id"])
        stored_claim = session.scalar(claim_query)
        assert stored_claim is not None
        assert stored_claim.claim_no == draft["claim_no"]
        assert stored_claim.status == "draft"
        assert stored_claim.items
def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None:
    """A follow-up in the same conversation updates the existing draft claim."""
    api, make_session = build_client()
    run_url = "/api/v1/orchestrator/run"

    # Turn 1: create the draft, with attachment metadata in the context.
    opening_request = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我生成一份差旅报销草稿我昨天去上海出差交通费680元",
        "context_json": {
            "role_codes": ["finance"],
            "attachment_names": ["行程单.png"],
            "attachment_count": 1,
            "ocr_summary": "行程单金额680元",
        },
    }
    opening_response = api.post(run_url, json=opening_request)
    assert opening_response.status_code == 200
    opening_body = opening_response.json()
    conversation_id = opening_body["conversation_id"]
    claim_id = opening_body["result"]["draft_payload"]["claim_id"]
    assert conversation_id
    assert claim_id

    # Turn 2: change the amount, passing the conversation id back.
    follow_up_request = {
        "source": "user_message",
        "user_id": "pytest",
        "conversation_id": conversation_id,
        "message": "金额改成800元",
        "context_json": {
            "role_codes": ["finance"],
        },
    }
    follow_up_response = api.post(run_url, json=follow_up_request)
    assert follow_up_response.status_code == 200
    follow_up_body = follow_up_response.json()
    assert follow_up_body["conversation_id"] == conversation_id
    trace = follow_up_body["trace_summary"]
    assert trace["scenario"] == "expense"
    assert trace["intent"] == "draft"
    assert follow_up_body["result"]["draft_payload"]["claim_id"] == claim_id

    with make_session() as session:
        # The same claim row carries the updated amount.
        stored_claim = session.scalar(
            select(ExpenseClaim).where(ExpenseClaim.id == claim_id)
        )
        assert stored_claim is not None
        assert float(stored_claim.amount) == 800.0

        # The conversation row tracks the draft and the latest scenario/intent.
        stored_conversation = session.scalar(
            select(AgentConversation).where(
                AgentConversation.conversation_id == conversation_id
            )
        )
        assert stored_conversation is not None
        assert stored_conversation.draft_claim_id == claim_id
        assert stored_conversation.last_scenario == "expense"
        assert stored_conversation.last_intent == "draft"

        # Four messages are expected after the two request/response turns.
        count_query = (
            select(func.count())
            .select_from(AgentConversationMessage)
            .where(AgentConversationMessage.conversation_id == conversation_id)
        )
        stored_message_count = session.scalar(count_query)
        assert stored_message_count == 4
def test_orchestrator_does_not_reuse_conversation_when_user_changes() -> None:
    """Another user sending an existing conversation id gets a new conversation."""
    api, make_session = build_client()
    run_url = "/api/v1/orchestrator/run"

    # user_a opens a conversation.
    owner_response = api.post(
        run_url,
        json={
            "source": "user_message",
            "user_id": "user_a",
            "message": "帮我生成一份差旅报销草稿我昨天去上海出差交通费680元",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert owner_response.status_code == 200
    owner_conversation_id = owner_response.json()["conversation_id"]
    assert owner_conversation_id

    # user_b sends user_a's conversation id.
    intruder_response = api.post(
        run_url,
        json={
            "source": "user_message",
            "user_id": "user_b",
            "conversation_id": owner_conversation_id,
            "message": "查一下本周报销金额",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert intruder_response.status_code == 200
    intruder_body = intruder_response.json()
    assert intruder_body["conversation_id"]
    assert intruder_body["conversation_id"] != owner_conversation_id

    # Both conversations exist and each belongs to the user who created it.
    with make_session() as session:
        owner_row = session.scalar(
            select(AgentConversation).where(
                AgentConversation.conversation_id == owner_conversation_id
            )
        )
        intruder_row = session.scalar(
            select(AgentConversation).where(
                AgentConversation.conversation_id == intruder_body["conversation_id"]
            )
        )
        assert owner_row is not None
        assert intruder_row is not None
        assert owner_row.user_id == "user_a"
        assert intruder_row.user_id == "user_b"
def test_orchestrator_prunes_conversations_older_than_configured_retention_days() -> None:
    """A conversation older than the retention setting is gone after a new run."""
    api, make_session = build_client()
    stale_id = "conv_expired"
    two_days_ago = datetime.now(UTC) - timedelta(days=2)

    with make_session() as session:
        # Set retention to one day through the settings service.
        settings = SettingsService(session)
        snapshot = settings.get_settings_snapshot().model_dump()
        snapshot["sessionForm"]["conversationRetentionDays"] = 1
        settings.save_settings_snapshot(SettingsWrite(**snapshot))

        # Seed a conversation (and one message) dated two days in the past.
        stale_conversation = AgentConversation(
            conversation_id=stale_id,
            user_id="expired_user",
            source="user_message",
            state_json={},
            message_count=1,
            created_at=two_days_ago,
            updated_at=two_days_ago,
        )
        session.add(stale_conversation)
        session.flush()
        stale_message = AgentConversationMessage(
            conversation_id=stale_id,
            role="user",
            content="旧会话消息",
            created_at=two_days_ago,
        )
        session.add(stale_message)
        session.commit()

    # A run by an unrelated user is the only action between seeding and checking.
    run_response = api.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "fresh_user",
            "message": "查一下本周报销金额",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert run_response.status_code == 200

    with make_session() as session:
        remaining_conversation = session.scalar(
            select(AgentConversation).where(
                AgentConversation.conversation_id == stale_id
            )
        )
        remaining_messages = session.scalar(
            select(func.count())
            .select_from(AgentConversationMessage)
            .where(AgentConversationMessage.conversation_id == stale_id)
        )
        assert remaining_conversation is None
        assert remaining_messages == 0
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
    """A spending narrative is parsed as an expense draft that needs clarification."""
    api, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "我今天去客户现场招待了客户花销了1000元",
        "context_json": {"role_codes": ["finance"]},
    }

    run_response = api.post("/api/v1/orchestrator/run", json=request_body)
    assert run_response.status_code == 200

    body = run_response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "draft_write"
    assert body["status"] == "blocked"
    assert body["route_reason"] == "clarification_required"

    trace = body["trace_summary"]
    assert trace["scenario"] == "expense"
    assert trace["intent"] == "draft"
    assert trace["tool_count"] == 0

    # The reply asks for more input and does not mention receivables data.
    result_message = body["result"]["message"]
    assert "应收场景数据" not in result_message
    assert "请补充" in result_message
def test_orchestrator_can_restore_latest_user_conversation() -> None:
    """The latest-conversation endpoint returns the user's full message history."""
    api, _ = build_client()
    run_url = "/api/v1/orchestrator/run"

    # Turn 1: create a draft with attachment metadata.
    opening_response = api.post(
        run_url,
        json={
            "source": "user_message",
            "user_id": "restore_user",
            "message": "帮我生成一份差旅报销草稿我昨天去上海出差交通费680元",
            "context_json": {
                "role_codes": ["finance"],
                "attachment_names": ["行程单.png"],
                "attachment_count": 1,
                "ocr_summary": "行程单金额680元",
            },
        },
    )
    assert opening_response.status_code == 200
    opening_body = opening_response.json()

    # Turn 2: follow up inside the same conversation.
    follow_up_response = api.post(
        run_url,
        json={
            "source": "user_message",
            "user_id": "restore_user",
            "conversation_id": opening_body["conversation_id"],
            "message": "金额改成800元",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert follow_up_response.status_code == 200

    latest_response = api.get(
        "/api/v1/orchestrator/conversations/latest",
        params={"user_id": "restore_user"},
    )
    assert latest_response.status_code == 200
    latest_body = latest_response.json()
    assert latest_body["found"] is True

    restored = latest_body["conversation"]
    assert restored["conversation_id"] == opening_body["conversation_id"]
    assert restored["draft_claim_id"] == opening_body["result"]["draft_payload"]["claim_id"]

    # Two turns yield four messages; the first is the user's opening message.
    history = restored["messages"]
    assert len(history) == 4
    assert history[0]["role"] == "user"
    assert history[0]["message_json"]["attachment_names"] == ["行程单.png"]
    assert history[1]["message_json"]["orchestrator_payload"]["run_id"]
def test_orchestrator_can_delete_all_user_conversations() -> None:
    """Deleting one user's conversations leaves another user's untouched."""
    api, make_session = build_client()
    run_url = "/api/v1/orchestrator/run"

    def run_request(user_id: str, message: str) -> dict:
        """Build the run payload for a finance-role user message."""
        return {
            "source": "user_message",
            "user_id": user_id,
            "message": message,
            "context_json": {"role_codes": ["finance"]},
        }

    # Two separate runs for delete_user, each without a conversation id.
    for text in ("查一下本周报销金额", "帮我生成差旅报销草稿"):
        run_response = api.post(run_url, json=run_request("delete_user", text))
        assert run_response.status_code == 200

    # One run for a different user that must survive the delete.
    bystander_response = api.post(
        run_url, json=run_request("other_user", "查一下供应商待付款")
    )
    assert bystander_response.status_code == 200

    removal_response = api.delete(
        "/api/v1/orchestrator/conversations",
        params={"user_id": "delete_user"},
    )
    assert removal_response.status_code == 200
    removal_body = removal_response.json()
    assert removal_body["deleted_count"] == 2

    with make_session() as session:
        remaining_for_deleted = session.scalar(
            select(func.count())
            .select_from(AgentConversation)
            .where(AgentConversation.user_id == "delete_user")
        )
        remaining_for_other = session.scalar(
            select(func.count())
            .select_from(AgentConversation)
            .where(AgentConversation.user_id == "other_user")
        )
        assert remaining_for_deleted == 0
        assert remaining_for_other == 1
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
    """A simulated database tool failure is recorded while the run still succeeds."""
    api, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "查一下本周报销金额",
        "context_json": {
            "role_codes": ["finance"],
            "simulate_tool_failure": "database",
        },
    }

    run_response = api.post("/api/v1/orchestrator/run", json=request_body)
    assert run_response.status_code == 200

    body = run_response.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"
    trace = body["trace_summary"]
    assert trace["failed_tool_count"] == 1
    assert trace["degraded"] is True

    # The failed tool call and its error text are kept on the run record.
    detail_url = f"/api/v1/agent-runs/{body['run_id']}"
    detail = api.get(detail_url).json()
    first_tool_call = detail["tool_calls"][0]
    assert first_tool_call["status"] == "failed"
    assert "simulated database failure" in first_tool_call["error_message"]
def test_orchestrator_exception_is_written_to_agent_run() -> None:
    """A simulated orchestrator exception gives HTTP 200 with a failed run record."""
    api, _ = build_client()
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "查一下本周报销金额",
        "context_json": {
            "role_codes": ["finance"],
            "simulate_orchestrator_exception": True,
        },
    }

    run_response = api.post("/api/v1/orchestrator/run", json=request_body)
    assert run_response.status_code == 200

    body = run_response.json()
    assert body["status"] == "failed"

    # The error text must be stored on the run record.
    detail_url = f"/api/v1/agent-runs/{body['run_id']}"
    detail = api.get(detail_url).json()
    assert detail["status"] == "failed"
    assert "simulated orchestrator exception" in detail["error_message"]