# X-Financial/server/tests/test_agent_trace_service.py
# Captured from a repository file view (Files · Raw · Permalink · Normal View · History);
# the original file is 206 lines (7.5 KiB) of Python.

from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.core.agent_enums import AgentName, AgentRunSource, AgentRunStatus, AgentToolType
from app.db.base import Base
from app.main import create_app
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
from app.services.agent_runs import AgentRunService
from app.services.agent_traces import AgentTraceService
def _build_session_factory() -> sessionmaker[Session]:
    """Create a fresh in-memory SQLite database with all tables and return a session factory.

    ``StaticPool`` makes the engine reuse a single DBAPI connection, so every
    session produced by the factory sees the same ``:memory:`` database (a new
    SQLite connection to ``:memory:`` would otherwise start from an empty one).
    ``check_same_thread=False`` allows that one connection to be used from other
    threads -- presumably needed because ``TestClient`` serves requests off the
    test thread; confirm if this setting is ever changed.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)


def build_session() -> Session:
    """Return a new session bound to a brand-new in-memory database with all tables created."""
    session_factory = _build_session_factory()
    return session_factory()


def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build a ``TestClient`` whose ``get_db`` dependency uses a fresh in-memory database.

    Returns:
        A ``(client, session_factory)`` pair. Tests seed rows through the
        factory; the API endpoints read through sessions from the same factory,
        so both sides share one database.
    """
    session_factory = _build_session_factory()
    app = create_app()

    def override_db() -> Generator[Session, None, None]:
        # A new session for each resolution of the dependency, always closed afterwards.
        db = session_factory()
        try:
            yield db
        finally:
            db.close()

    app.dependency_overrides[get_db] = override_db
    return TestClient(app), session_factory
def test_agent_trace_service_records_events_and_reads_detail() -> None:
    """Explicitly recorded trace events are numbered in order and show up in list/detail reads.

    Seeds one successful orchestrator run, a conversation with a single user
    message, and two trace events. It then checks:
    - the per-run sequence numbers;
    - the keyword-filtered listing;
    - the event order in the detail;
    - the link from the detail back to the conversation and its messages.
    """
    conversation_id = "conv-trace-1"
    user_text = "帮我看报销风险"  # "Help me check the reimbursement risk"

    with build_session() as db:
        runs = AgentRunService(db)
        traces = AgentTraceService(db)
        t0 = datetime.now(UTC) - timedelta(seconds=2)

        run = runs.create_run(
            agent=AgentName.ORCHESTRATOR.value,
            source=AgentRunSource.USER_MESSAGE.value,
            status=AgentRunStatus.SUCCEEDED.value,
            route_json={"conversation_id": conversation_id},
            result_summary="expense answer ready",
            started_at=t0,
            finished_at=t0 + timedelta(seconds=1),
        )

        # Seed the conversation and its user message directly, then persist both.
        db.add_all(
            [
                AgentConversation(
                    conversation_id=conversation_id,
                    user_id="u-1",
                    source=AgentRunSource.USER_MESSAGE.value,
                ),
                AgentConversationMessage(
                    conversation_id=conversation_id,
                    run_id=run.run_id,
                    role="user",
                    content=user_text,
                    message_json={"source": "test"},
                ),
            ]
        )
        db.commit()

        request_event = traces.record_event(
            run_id=run.run_id,
            conversation_id=conversation_id,
            stage="orchestrator",
            event_name="request_received",
            title="接收请求",  # "Request received"
            summary="用户消息进入编排",  # "User message entered orchestration"
            input_json={"message": user_text},
            output_json={"run_id": run.run_id},
            started_at=t0,
            finished_at=t0 + timedelta(milliseconds=20),
        )
        response_event = traces.record_event(
            run_id=run.run_id,
            conversation_id=conversation_id,
            stage="response",
            event_name="response_built",
            title="生成回复",  # "Generate reply"
            status=AgentRunStatus.SUCCEEDED.value,
            output_json={"message": "已完成"},  # "Done"
            started_at=t0 + timedelta(milliseconds=500),
            finished_at=t0 + timedelta(milliseconds=650),
        )

        summaries = traces.list_traces(keyword=run.run_id, limit=10)
        detail = traces.get_trace(run.run_id)

        assert (request_event.sequence, response_event.sequence) == (1, 2)
        assert len(summaries) == 1
        assert summaries[0].event_count == 2

        assert detail is not None
        assert detail.fallback_generated is False
        event_names = [event.event_name for event in detail.events]
        assert event_names == ["request_received", "response_built"]
        assert detail.conversation_id == conversation_id
        assert detail.conversation_messages[0].content == user_text
def test_agent_trace_service_builds_fallback_timeline_for_legacy_runs() -> None:
    """A run with no recorded trace events gets a generated (fallback) timeline.

    The run has only a semantic parse and a failed tool call, both recorded
    through ``AgentRunService``. The trace detail must:
    - be flagged as generated;
    - expose the conversation id that appears only in the run's ``route_json``;
    - contain parse and tool events and end with a ``failed`` event.
    The conversation-level trace must list the run.
    """
    with build_session() as db:
        run_service = AgentRunService(db)
        trace_service = AgentTraceService(db)
        run = run_service.create_run(
            agent=AgentName.HERMES.value,
            source=AgentRunSource.SCHEDULE.value,
            status=AgentRunStatus.FAILED.value,
            route_json={"conversation_id": "conv-trace-legacy", "stage": "tooling"},
            result_summary="sync failed",
            error_message="boom",
        )
        run_service.record_semantic_parse(
            run_id=run.run_id,
            user_id="u-1",
            raw_query="同步知识库",  # "Sync the knowledge base"
            scenario="knowledge",
            intent="sync",
            confidence=0.92,
        )
        run_service.record_tool_call(
            run_id=run.run_id,
            tool_type=AgentToolType.LLM.value,
            tool_name="lightrag.index_documents",
            request_json={"document_ids": ["doc-1"]},
            response_json={"fallback": True},
            status=AgentRunStatus.FAILED.value,
            duration_ms=31,
            error_message="boom",
        )

        detail = trace_service.get_trace(run.run_id)
        conversation_detail = trace_service.get_conversation_trace("conv-trace-legacy")

        assert detail is not None
        assert detail.fallback_generated is True
        assert detail.conversation_id == "conv-trace-legacy"
        event_names = [event.event_name for event in detail.events]
        assert "semantic_parsed" in event_names
        assert "tool_invoked" in event_names
        assert event_names[-1] == "failed"
        assert detail.tool_calls[0].tool_name == "lightrag.index_documents"

        # Guard before dereferencing, as done for ``detail``. A missing
        # conversation trace then fails as an assertion, not an AttributeError.
        assert conversation_detail is not None
        assert [item.run.run_id for item in conversation_detail.runs] == [run.run_id]
def test_agent_trace_endpoints_return_admin_trace_detail() -> None:
    """The agent-trace list and detail endpoints expose a seeded run to an admin caller."""
    client, session_factory = build_client()

    # Seed through a dedicated session. The API reads through its own
    # per-request sessions from the same factory.
    with session_factory() as db:
        run_service = AgentRunService(db)
        trace_service = AgentTraceService(db)
        run = run_service.create_run(
            agent=AgentName.ORCHESTRATOR.value,
            source=AgentRunSource.USER_MESSAGE.value,
            status=AgentRunStatus.SUCCEEDED.value,
            route_json={"conversation_id": "conv-api-trace"},
            result_summary="api trace ready",
        )
        # Capture the id while the session is open. After the ``with`` block,
        # ``run`` is detached, and reading an attribute that a commit expired
        # would raise instead of returning the value.
        run_id = run.run_id
        trace_service.record_event(
            run_id=run_id,
            conversation_id="conv-api-trace",
            stage="response",
            event_name="response_built",
            title="生成回复",  # "Generate reply"
            status=AgentRunStatus.SUCCEEDED.value,
            output_json={"message": "ok"},
        )

    # Presumably read by the app's auth dependency as an admin identity; confirm
    # against the auth module if these header names change.
    headers = {
        "x-auth-username": "admin",
        "x-auth-name": "admin",
        "x-auth-is-admin": "true",
    }
    list_response = client.get("/api/v1/agent-traces", headers=headers)
    detail_response = client.get(f"/api/v1/agent-traces/{run_id}", headers=headers)

    assert list_response.status_code == 200
    assert any(item["run_id"] == run_id for item in list_response.json())
    assert detail_response.status_code == 200
    payload = detail_response.json()
    assert payload["run"]["run_id"] == run_id
    assert payload["events"][0]["event_name"] == "response_built"