Files
X-Financial/server/tests/test_orchestrator_service.py
caoxiaozhu 22d47cbf2b feat(backend): add ontology and orchestrator API endpoints
New endpoints:
- server/src/app/api/v1/endpoints/ontology.py: ontology API
- server/src/app/api/v1/endpoints/orchestrator.py: orchestrator API

New schemas:
- server/src/app/schemas/ontology.py: ontology data schemas
- server/src/app/schemas/orchestrator.py: orchestrator data schemas
- server/src/app/schemas/user_agent.py: user agent data schemas

New services:
- server/src/app/services/ontology.py: ontology business logic
- server/src/app/services/orchestrator.py: orchestrator business logic
- server/src/app/services/runtime_chat.py: runtime chat service
- server/src/app/services/user_agent.py: user agent service

New tests:
- server/tests/test_ontology_service.py
- server/tests/test_orchestrator_service.py
- server/tests/test_user_agent_service.py
2026-05-12 01:24:39 +00:00

242 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from collections.abc import Generator
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.services.agent_assets import AgentAssetService
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a TestClient backed by a fresh in-memory SQLite database.

    The application's ``get_db`` dependency is overridden so that each request
    opens (and finally closes) a session from the returned factory. Tests can
    therefore use the same factory to inspect the data the API wrote.

    Returns:
        A ``(client, session_factory)`` tuple bound to the same test engine.
    """
    # StaticPool plus check_same_thread=False keeps one shared connection,
    # which is what lets an in-memory SQLite database be reused across
    # sessions and threads.
    test_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=test_engine)
    factory = sessionmaker(bind=test_engine, autocommit=False, autoflush=False)

    def _session_per_request() -> Generator[Session, None, None]:
        session = factory()
        try:
            yield session
        finally:
            session.close()

    application = create_app()
    application.dependency_overrides[get_db] = _session_per_request
    return TestClient(application), factory
def test_orchestrator_routes_user_query_to_user_agent() -> None:
    """A read-only receivables question from a finance user goes to ``user_agent``.

    Verifies both the immediate orchestrator response and the persisted
    agent-run record (route, semantic parse and the first tool call).
    """
    client, _ = build_client()
    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "客户A这个月还有多少应收",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "user_agent"
    assert payload["permission_level"] == "read"
    assert payload["status"] == "succeeded"
    assert payload["result"]["answer"]
    assert payload["result"]["suggested_actions"]
    assert payload["trace_summary"]["tool_count"] >= 1

    # Check the HTTP status first so a missing run fails clearly instead of
    # surfacing as a KeyError on an error body.
    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["agent"] == "user_agent"
    assert run_detail["route_json"]["selected_agent"] == "user_agent"
    assert run_detail["semantic_parse"]["scenario"] == "accounts_receivable"
    assert run_detail["tool_calls"], "expected at least one recorded tool call"
    assert run_detail["tool_calls"][0]["tool_type"] == "database"
def test_orchestrator_routes_schedule_to_hermes() -> None:
    """A scheduled run of the daily risk-scan task is routed to ``hermes``.

    Looks up the active ``task.hermes.daily_risk_scan`` asset, triggers it with
    ``source="schedule"`` and checks that two tool calls are traced and persisted.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        task = next(
            (
                item
                for item in AgentAssetService(db).list_assets(asset_type="task", status="active")
                if item.code == "task.hermes.daily_risk_scan"
            ),
            None,
        )
        assert task is not None, "active task 'task.hermes.daily_risk_scan' not found"
        # Capture the id while the session is open: after close() the instance
        # is detached and an expired attribute could no longer be reloaded.
        task_id = task.id

    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "schedule",
            "task_id": task_id,
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "hermes"
    assert payload["status"] == "succeeded"
    assert payload["trace_summary"]["tool_count"] == 2

    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["agent"] == "hermes"
    assert run_detail["route_json"]["selected_agent"] == "hermes"
    assert len(run_detail["tool_calls"]) == 2
def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None:
    """A direct-payment request from a plain ``user`` role is blocked outright.

    No downstream agent is selected, no tools are called, and the persisted run
    is attributed to the orchestrator itself.
    """
    client, _ = build_client()
    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "帮我直接付款给供应商B",
            "context_json": {"role_codes": ["user"]},
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] is None
    assert payload["permission_level"] == "forbidden"
    assert payload["status"] == "blocked"
    assert payload["trace_summary"]["tool_count"] == 0

    # Check the HTTP status first so a missing run fails clearly.
    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["agent"] == "orchestrator"
    assert run_detail["tool_calls"] == []
def test_orchestrator_approval_required_returns_confirmation_result() -> None:
    """A payment-scheduling request from finance requires explicit confirmation.

    The router still selects ``user_agent``, but the run is blocked, flagged as
    requiring confirmation, and the result message asks the user to confirm.
    """
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我安排付款给供应商B",
        "context_json": {"role_codes": ["finance"]},
    }
    client, _ = build_client()
    resp = client.post("/api/v1/orchestrator/run", json=request_body)
    assert resp.status_code == 200

    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "approval_required"
    assert body["requires_confirmation"] is True
    assert body["status"] == "blocked"

    confirmation_text = body["result"]["message"]
    assert "确认" in confirmation_text
def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
    """Asking for a travel-reimbursement draft yields a structured draft payload.

    The draft must be marked as needing confirmation and come with suggested
    follow-up actions.
    """
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "帮我生成张三4月差旅报销草稿",
        "context_json": {"role_codes": ["finance"]},
    }
    client, _ = build_client()
    resp = client.post("/api/v1/orchestrator/run", json=request_body)
    assert resp.status_code == 200

    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["status"] == "succeeded"

    result = body["result"]
    assert result["draft_payload"]["confirmation_required"] is True
    assert result["suggested_actions"]
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
    """An expense narrative mentioning a client must not be parsed as a receivables query.

    The message describes entertaining a client on site and spending 1000 yuan.
    It should be understood as an expense draft that lacks details, so the run
    is blocked pending clarification. No tools are called, and the reply asks
    for more information rather than quoting receivables data.
    """
    request_body = {
        "source": "user_message",
        "user_id": "pytest",
        "message": "我今天去客户现场招待了客户花销了1000元",
        "context_json": {"role_codes": ["finance"]},
    }
    client, _ = build_client()
    resp = client.post("/api/v1/orchestrator/run", json=request_body)
    assert resp.status_code == 200

    body = resp.json()
    assert body["selected_agent"] == "user_agent"
    assert body["permission_level"] == "draft_write"
    assert body["status"] == "blocked"
    assert body["route_reason"] == "clarification_required"

    trace = body["trace_summary"]
    assert trace["scenario"] == "expense"
    assert trace["intent"] == "draft"
    assert trace["tool_count"] == 0

    reply = body["result"]["message"]
    assert "应收场景数据" not in reply
    assert "请补充" in reply
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
    """A simulated database-tool failure is recorded while the run still succeeds.

    ``simulate_tool_failure`` in the request context makes the database tool
    fail. The trace summary must report one failed tool and a degraded run, and
    the persisted tool call must carry the simulated error message.
    """
    client, _ = build_client()
    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "查一下本周报销金额",
            "context_json": {
                "role_codes": ["finance"],
                "simulate_tool_failure": "database",
            },
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_agent"] == "user_agent"
    assert payload["status"] == "succeeded"
    assert payload["trace_summary"]["failed_tool_count"] == 1
    assert payload["trace_summary"]["degraded"] is True

    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["tool_calls"], "expected the failed tool call to be recorded"
    first_call = run_detail["tool_calls"][0]
    assert first_call["status"] == "failed"
    # Guard against a null error_message so the failure is an assertion, not a TypeError.
    assert "simulated database failure" in (first_call["error_message"] or "")
def test_orchestrator_exception_is_written_to_agent_run() -> None:
    """An exception inside the orchestrator is persisted on the agent run.

    ``simulate_orchestrator_exception`` forces a failure. The endpoint still
    answers 200 with ``status == "failed"``, and the stored run carries the
    error message.
    """
    client, _ = build_client()
    response = client.post(
        "/api/v1/orchestrator/run",
        json={
            "source": "user_message",
            "user_id": "pytest",
            "message": "查一下本周报销金额",
            "context_json": {
                "role_codes": ["finance"],
                "simulate_orchestrator_exception": True,
            },
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "failed"

    detail_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
    assert detail_response.status_code == 200
    run_detail = detail_response.json()
    assert run_detail["status"] == "failed"
    # Guard against a null error_message so the failure is an assertion, not a TypeError.
    assert "simulated orchestrator exception" in (run_detail["error_message"] or "")