# X-Financial/server/tests/test_agent_runs_service.py
# (source listing: 147 lines, 5.3 KiB, Python)
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.core.agent_enums import AgentName, AgentRunSource, AgentRunStatus, AgentToolType
from app.db.base import Base
from app.services.agent_runs import AgentRunService
def build_session() -> Session:
    """Return a session bound to a brand-new in-memory SQLite database.

    Every call builds its own engine, so each test gets an isolated database
    with all ``Base`` tables already created. ``StaticPool`` keeps a single
    connection, which keeps the ``:memory:`` database alive between uses.
    ``check_same_thread=False`` lets that one connection be used from any
    thread.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    return sessionmaker(bind=memory_engine, autocommit=False, autoflush=False)()
def test_agent_run_service_marks_stale_knowledge_sync_run_failed_on_read() -> None:
    """A RUNNING knowledge-index sync with a 31-minute-old heartbeat reads back as FAILED.

    The run must also not appear in a listing of RUNNING Hermes runs.
    """
    with build_session() as session:
        runs = AgentRunService(session)
        # 31 minutes presumably sits just past the service's heartbeat
        # timeout (30 min?) -- NOTE(review): confirm against AgentRunService.
        stale_heartbeat = datetime.now(UTC) - timedelta(minutes=31)
        stale_run = runs.create_run(
            agent=AgentName.HERMES.value,
            source=AgentRunSource.USER_MESSAGE.value,
            status=AgentRunStatus.RUNNING.value,
            route_json={
                "job_type": "knowledge_index_sync",
                "heartbeat_at": stale_heartbeat.isoformat(),
                "requested_document_ids": [],
            },
        )

        # Order matters: get_run executes first, so list_runs observes
        # whatever state get_run left behind.
        reloaded = runs.get_run(stale_run.run_id)
        still_running = runs.list_runs(
            agent=AgentName.HERMES.value,
            status=AgentRunStatus.RUNNING.value,
            limit=100,
        )

        assert reloaded is not None
        assert reloaded.status == AgentRunStatus.FAILED.value
        assert reloaded.error_message == "Knowledge index heartbeat timed out."
        assert not any(item.run_id == stale_run.run_id for item in still_running)
def test_agent_run_service_marks_stale_llm_wiki_run_failed_on_read() -> None:
    """A RUNNING scheduled llm_wiki_sync with a 31-minute-old heartbeat reads back as FAILED.

    The expected error message is the same one used for the knowledge-index
    sync job.
    """
    with build_session() as session:
        runs = AgentRunService(session)
        stale_heartbeat = datetime.now(UTC) - timedelta(minutes=31)
        stale_route = {
            "job_type": "llm_wiki_sync",
            "heartbeat_at": stale_heartbeat.isoformat(),
            "requested_document_ids": [],
        }
        stale_run = runs.create_run(
            agent=AgentName.HERMES.value,
            source=AgentRunSource.SCHEDULE.value,
            status=AgentRunStatus.RUNNING.value,
            route_json=stale_route,
        )

        reloaded = runs.get_run(stale_run.run_id)

        assert reloaded is not None
        assert reloaded.status == AgentRunStatus.FAILED.value
        assert reloaded.error_message == "Knowledge index heartbeat timed out."
def test_agent_run_service_updates_existing_tool_call() -> None:
    """update_tool_call rewrites the recorded call in place instead of adding another one.

    The returned object and the call reloaded through get_run must both
    reflect the new status, duration and response payload.
    """
    with build_session() as session:
        runs = AgentRunService(session)
        parent_run = runs.create_run(
            agent=AgentName.HERMES.value,
            source=AgentRunSource.USER_MESSAGE.value,
            status=AgentRunStatus.RUNNING.value,
            route_json={"job_type": "knowledge_index_sync"},
        )
        pending_call = runs.record_tool_call(
            run_id=parent_run.run_id,
            tool_type=AgentToolType.LLM.value,
            tool_name="lightrag.index_documents",
            request_json={"document_ids": ["doc-1"]},
            response_json={"phase": "indexing"},
            status="running",
            duration_ms=0,
        )

        completed_call = runs.update_tool_call(
            pending_call.id,
            response_json={"track_id": "insert_123"},
            status="succeeded",
            duration_ms=1250,
            error_message=None,
        )
        reloaded = runs.get_run(parent_run.run_id)

        assert completed_call.status == "succeeded"
        assert completed_call.duration_ms == 1250
        assert reloaded is not None
        stored_calls = reloaded.tool_calls
        # Exactly one call: the update must not have appended a new record.
        assert len(stored_calls) == 1
        [only_call] = stored_calls
        assert only_call.status == "succeeded"
        assert only_call.response_json == {"track_id": "insert_123"}
def test_agent_run_service_summarizes_model_and_tool_failures() -> None:
    """summarize_runs reports a rule-fallback orchestrator run and its failed LLM call.

    The run records a rule-fallback parse strategy plus a guardrail reason.
    Its single LLM tool call fails with that same reason. Every related
    counter must be at least one, and the failing tool must appear in
    ``recent_errors``.
    """
    guardrail_reason = "model_conflicts_with_application_stage_signal"
    with build_session() as session:
        runs = AgentRunService(session)
        orchestrator_run = runs.create_run(
            agent=AgentName.ORCHESTRATOR.value,
            source=AgentRunSource.USER_MESSAGE.value,
            status=AgentRunStatus.SUCCEEDED.value,
            ontology_json={
                "parse_strategy": "rule_fallback",
                "model_invocation_summary": {"model_guardrail_reason": guardrail_reason},
            },
        )
        runs.record_tool_call(
            run_id=orchestrator_run.run_id,
            tool_type=AgentToolType.LLM.value,
            tool_name="semantic_ontology.main",
            request_json={"stage": "semantic_parse"},
            response_json={"model_guardrail_reason": guardrail_reason},
            status="failed",
            duration_ms=18,
            error_message=guardrail_reason,
        )

        stats = runs.summarize_runs(agent=AgentName.ORCHESTRATOR.value, limit=20)

        # Lower bounds only: which records feed each counter (run vs. tool
        # call) is decided inside the service, so exact totals are not pinned.
        counters = {
            "total_runs": stats.total_runs,
            "tool_call_count": stats.tool_call_count,
            "failed_tool_call_count": stats.failed_tool_call_count,
            "llm_call_count": stats.llm_call_count,
            "failed_llm_call_count": stats.failed_llm_call_count,
            "model_fallback_count": stats.model_fallback_count,
            "model_guardrail_count": stats.model_guardrail_count,
        }
        assert all(value >= 1 for value in counters.values()), counters
        error_tool_names = [item.get("tool_name") for item in stats.recent_errors]
        assert "semantic_ontology.main" in error_tool_names