"""Tests for ``AgentRunService`` backed by an in-memory SQLite database."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.core.agent_enums import AgentName, AgentRunSource, AgentRunStatus, AgentToolType
from app.db.base import Base
from app.services.agent_runs import AgentRunService


def build_session() -> Session:
    """Return a new session bound to a fresh in-memory SQLite engine.

    All tables registered on ``Base.metadata`` are created before the
    session is handed back. The engine uses ``StaticPool`` with
    ``check_same_thread`` disabled.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=memory_engine)
    make_session = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return make_session()


def test_agent_run_service_marks_stale_knowledge_sync_run_failed_on_read() -> None:
    """A RUNNING knowledge-index sync with a 31-minute-old heartbeat reads as FAILED.

    The run must also be missing from a listing filtered on RUNNING status.
    """
    with build_session() as db:
        service = AgentRunService(db)
        # Heartbeat 31 minutes in the past; presumably just beyond the
        # service's staleness window -- confirm against AgentRunService.
        stale_heartbeat = (datetime.now(UTC) - timedelta(minutes=31)).isoformat()
        route_payload = {
            "job_type": "knowledge_index_sync",
            "heartbeat_at": stale_heartbeat,
            "requested_document_ids": [],
        }
        stale_run = service.create_run(
            agent=AgentName.HERMES.value,
            source=AgentRunSource.USER_MESSAGE.value,
            status=AgentRunStatus.RUNNING.value,
            route_json=route_payload,
        )

        # Read the single run first, then list; same order as the scenario
        # under test ("failed on read").
        reloaded = service.get_run(stale_run.run_id)
        still_running = service.list_runs(
            agent=AgentName.HERMES.value,
            status=AgentRunStatus.RUNNING.value,
            limit=100,
        )

        assert reloaded is not None
        assert reloaded.status == AgentRunStatus.FAILED.value
        assert reloaded.error_message == "Knowledge index heartbeat timed out."
        assert not any(entry.run_id == stale_run.run_id for entry in still_running)


def test_agent_run_service_marks_stale_llm_wiki_run_failed_on_read() -> None:
    """A RUNNING ``llm_wiki_sync`` run with a 31-minute-old heartbeat reads as FAILED."""
    with build_session() as db:
        service = AgentRunService(db)
        stale_heartbeat = (datetime.now(UTC) - timedelta(minutes=31)).isoformat()
        route_payload = {
            "job_type": "llm_wiki_sync",
            "heartbeat_at": stale_heartbeat,
            "requested_document_ids": [],
        }
        stale_run = service.create_run(
            agent=AgentName.HERMES.value,
            source=AgentRunSource.SCHEDULE.value,
            status=AgentRunStatus.RUNNING.value,
            route_json=route_payload,
        )

        reloaded = service.get_run(stale_run.run_id)

        assert reloaded is not None
        assert reloaded.status == AgentRunStatus.FAILED.value
        assert reloaded.error_message == "Knowledge index heartbeat timed out."


def test_agent_run_service_updates_existing_tool_call() -> None:
    """Updating a recorded tool call changes that record instead of adding another."""
    with build_session() as db:
        service = AgentRunService(db)
        parent_run = service.create_run(
            agent=AgentName.HERMES.value,
            source=AgentRunSource.USER_MESSAGE.value,
            status=AgentRunStatus.RUNNING.value,
            route_json={"job_type": "knowledge_index_sync"},
        )
        pending_call = service.record_tool_call(
            run_id=parent_run.run_id,
            tool_type=AgentToolType.LLM.value,
            tool_name="lightrag.index_documents",
            request_json={"document_ids": ["doc-1"]},
            response_json={"phase": "indexing"},
            status="running",
            duration_ms=0,
        )

        finished_call = service.update_tool_call(
            pending_call.id,
            response_json={"track_id": "insert_123"},
            status="succeeded",
            duration_ms=1250,
            error_message=None,
        )
        reloaded = service.get_run(parent_run.run_id)

        assert finished_call.status == "succeeded"
        assert finished_call.duration_ms == 1250
        assert reloaded is not None
        recorded_calls = reloaded.tool_calls
        assert len(recorded_calls) == 1
        assert recorded_calls[0].status == "succeeded"
        assert recorded_calls[0].response_json == {"track_id": "insert_123"}


def test_agent_run_service_summarizes_model_and_tool_failures() -> None:
    """``summarize_runs`` counts a failed LLM tool call, a fallback and a guardrail hit."""
    with build_session() as db:
        service = AgentRunService(db)
        guardrail_reason = "model_conflicts_with_application_stage_signal"
        ontology_payload = {
            "parse_strategy": "rule_fallback",
            "model_invocation_summary": {
                "model_guardrail_reason": guardrail_reason,
            },
        }
        orchestrator_run = service.create_run(
            agent=AgentName.ORCHESTRATOR.value,
            source=AgentRunSource.USER_MESSAGE.value,
            status=AgentRunStatus.SUCCEEDED.value,
            ontology_json=ontology_payload,
        )
        service.record_tool_call(
            run_id=orchestrator_run.run_id,
            tool_type=AgentToolType.LLM.value,
            tool_name="semantic_ontology.main",
            request_json={"stage": "semantic_parse"},
            response_json={"model_guardrail_reason": guardrail_reason},
            status="failed",
            duration_ms=18,
            error_message=guardrail_reason,
        )

        summary = service.summarize_runs(agent=AgentName.ORCHESTRATOR.value, limit=20)

        # Lower bounds only, exactly as the scenario requires.
        assert summary.total_runs >= 1
        assert summary.tool_call_count >= 1
        assert summary.failed_tool_call_count >= 1
        assert summary.llm_call_count >= 1
        assert summary.failed_llm_call_count >= 1
        assert summary.model_fallback_count >= 1
        assert summary.model_guardrail_count >= 1
        assert any(
            error.get("tool_name") == "semantic_ontology.main"
            for error in summary.recent_errors
        )