242 lines
7.9 KiB
Python
242 lines
7.9 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
from collections.abc import Generator
|
|||
|
|
|
|||
|
|
from fastapi.testclient import TestClient
|
|||
|
|
from sqlalchemy import create_engine
|
|||
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|||
|
|
from sqlalchemy.pool import StaticPool
|
|||
|
|
|
|||
|
|
from app.api.deps import get_db
|
|||
|
|
from app.db.base import Base
|
|||
|
|
from app.main import create_app
|
|||
|
|
from app.services.agent_assets import AgentAssetService
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a TestClient backed by a fresh in-memory SQLite database.

    Every call builds its own engine and schema, so tests do not share state.
    The app's ``get_db`` dependency is overridden to hand out sessions from the
    returned factory. Tests can therefore also open sessions directly against
    the same database the API uses.

    Returns:
        A ``(client, session_factory)`` pair.
    """
    # StaticPool keeps a single connection, so every session sees the same
    # in-memory database; check_same_thread=False lets SQLite use that
    # connection from a thread other than the one that created it.
    test_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=test_engine)
    make_session = sessionmaker(bind=test_engine, autocommit=False, autoflush=False)

    def _session_override() -> Generator[Session, None, None]:
        # The Session context manager closes the session on exit, including
        # when an exception is thrown into this generator.
        with make_session() as session:
            yield session

    application = create_app()
    application.dependency_overrides[get_db] = _session_override
    return TestClient(application), make_session
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_orchestrator_routes_user_query_to_user_agent() -> None:
    """A receivables question from a finance user is answered by ``user_agent``.

    Also checks that the persisted run record carries the routing decision,
    the ``accounts_receivable`` semantic parse, and a database tool call.
    """
    client, _ = build_client()

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            # "How much is still receivable from customer A this month?"
            "message": "客户A这个月还有多少应收",
            "context_json": {"role_codes": ["finance"]},
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "user_agent"
    assert payload["permission_level"] == "read"
    assert payload["status"] == "succeeded"
    assert payload["result"]["answer"]
    assert payload["result"]["suggested_actions"]
    assert payload["trace_summary"]["tool_count"] >= 1

    # Check the detail request itself succeeded so a 404/500 fails here with a
    # clear status, not later as a confusing KeyError on the error body.
    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["agent"] == "user_agent"
    assert run_detail["route_json"]["selected_agent"] == "user_agent"
    assert run_detail["semantic_parse"]["scenario"] == "accounts_receivable"
    # Guard the index so an empty list fails the assertion, not an IndexError.
    assert run_detail["tool_calls"]
    assert run_detail["tool_calls"][0]["tool_type"] == "database"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_orchestrator_routes_schedule_to_hermes() -> None:
    """A scheduled run for a Hermes task is routed to the ``hermes`` agent.

    The active ``task.hermes.daily_risk_scan`` asset is looked up through
    ``AgentAssetService`` and submitted as a ``schedule``-sourced run. The run
    is expected to record exactly two tool calls.
    """
    client, session_factory = build_client()

    with session_factory() as db:
        task = next(
            (
                item
                for item in AgentAssetService(db).list_assets(asset_type="task", status="active")
                if item.code == "task.hermes.daily_risk_scan"
            ),
            None,
        )
        # Fail with a readable message instead of a bare StopIteration when the
        # expected asset is missing.
        assert task is not None, "active asset task.hermes.daily_risk_scan not found"
        # Read the id while the session is still open; once the `with` block
        # closes the session the ORM instance is detached.
        task_id = task.id

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "schedule",
            "task_id": task_id,
            "context_json": {"role_codes": ["finance"]},
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "hermes"
    assert payload["status"] == "succeeded"
    assert payload["trace_summary"]["tool_count"] == 2

    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["agent"] == "hermes"
    assert run_detail["route_json"]["selected_agent"] == "hermes"
    assert len(run_detail["tool_calls"]) == 2
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None:
    """A forbidden request is blocked by the orchestrator itself.

    A plain ``user`` role asks for a direct payment. No downstream agent is
    selected, no tools run, and the run is attributed to ``orchestrator``.
    """
    client, _ = build_client()

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            # "Pay supplier B directly for me."
            "message": "帮我直接付款给供应商B",
            "context_json": {"role_codes": ["user"]},
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] is None
    assert payload["permission_level"] == "forbidden"
    assert payload["status"] == "blocked"
    assert payload["trace_summary"]["tool_count"] == 0

    # Check the detail request succeeded before inspecting its body.
    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["agent"] == "orchestrator"
    assert run_detail["tool_calls"] == []
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_orchestrator_approval_required_returns_confirmation_result() -> None:
    """A payment-arrangement request must be confirmed before it proceeds.

    The message is routed to ``user_agent`` at ``approval_required``
    permission. The run is blocked, and the result asks the user to confirm.
    """
    client = build_client()[0]
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        # "Arrange a payment to supplier B for me."
        "message": "帮我安排付款给供应商B",
        "context_json": {"role_codes": ["finance"]},
    }

    resp = client.post("/api/v1/orchestrator/run", json=request_body)
    assert resp.status_code == 200

    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "approval_required"
    assert body["requires_confirmation"] is True
    assert body["status"] == "blocked"
    # "确认" means "confirm"; the blocked result must prompt for it.
    confirmation_text = body["result"]["message"]
    assert "确认" in confirmation_text
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
    """A draft-generation request yields a structured draft payload.

    The run succeeds through ``user_agent``. The draft is flagged as needing
    confirmation, and suggested follow-up actions are included.
    """
    client = build_client()[0]
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        # "Generate a draft travel-expense reimbursement for Zhang San for April."
        "message": "帮我生成张三4月差旅报销草稿",
        "context_json": {"role_codes": ["finance"]},
    }

    resp = client.post("/api/v1/orchestrator/run", json=request_body)
    assert resp.status_code == 200

    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"

    result = body["result"]
    assert result["draft_payload"]["confirmation_required"] is True
    assert result["suggested_actions"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
    """An expense narrative is parsed as an expense draft, not a receivables query.

    The mention of a customer must not pull the request into the
    accounts-receivable scenario. The run should instead stop and ask for
    clarification, without calling any tools.
    """
    client = build_client()[0]
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        # "Today I went to a customer site, entertained the customer, and spent 1000 yuan."
        "message": "我今天去客户现场,招待了客户,花销了1000元",
        "context_json": {"role_codes": ["finance"]},
    }

    resp = client.post("/api/v1/orchestrator/run", json=request_body)
    assert resp.status_code == 200

    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "draft_write"
    assert body["status"] == "blocked"
    assert body["route_reason"] == "clarification_required"

    summary = body["trace_summary"]
    assert summary["scenario"] == "expense"
    assert summary["intent"] == "draft"
    assert summary["tool_count"] == 0

    reply = body["result"]["message"]
    # "应收场景数据" = "receivables scenario data" (must be absent);
    # "请补充" = "please provide more" (the clarification prompt).
    assert "应收场景数据" not in reply
    assert "请补充" in reply
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
    """A simulated database tool failure degrades the run instead of failing it.

    ``simulate_tool_failure`` in the context makes the database tool fail. The
    run still reports ``succeeded``, but it is marked degraded, and the failed
    tool call is persisted with its error message.
    """
    client, _ = build_client()

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            # "Check this week's reimbursement amount."
            "message": "查一下本周报销金额",
            "context_json": {
                "role_codes": ["finance"],
                "simulate_tool_failure": "database",
            },
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "user_agent"
    assert payload["status"] == "succeeded"
    assert payload["trace_summary"]["failed_tool_count"] == 1
    assert payload["trace_summary"]["degraded"] is True

    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    # Guard the index so an empty list fails the assertion, not an IndexError.
    assert run_detail["tool_calls"]
    first_call = run_detail["tool_calls"][0]
    assert first_call["status"] == "failed"
    # A null error_message should fail the assertion, not raise TypeError on `in`.
    assert "simulated database failure" in (first_call["error_message"] or "")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_orchestrator_exception_is_written_to_agent_run() -> None:
    """An orchestrator exception is recorded on the run rather than raised.

    ``simulate_orchestrator_exception`` in the context triggers the failure.
    The endpoint still returns 200 with status ``failed``, and the persisted
    run stores the error message.
    """
    client, _ = build_client()

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            # "Check this week's reimbursement amount."
            "message": "查一下本周报销金额",
            "context_json": {
                "role_codes": ["finance"],
                "simulate_orchestrator_exception": True,
            },
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "failed"

    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["status"] == "failed"
    # A null error_message should fail the assertion, not raise TypeError on `in`.
    assert "simulated orchestrator exception" in (run_detail["error_message"] or "")
|