from __future__ import annotations from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session from app.models.agent_run import AgentRun, AgentToolCall, SemanticParseLog class AgentRunRepository: def __init__(self, db: Session) -> None: self.db = db def list( self, *, agent: str | None = None, status: str | None = None, source: str | None = None, limit: int = 20, ) -> list[AgentRun]: stmt = select(AgentRun) if agent: stmt = stmt.where(AgentRun.agent == agent) if status: stmt = stmt.where(AgentRun.status == status) if source: stmt = stmt.where(AgentRun.source == source) stmt = stmt.order_by(AgentRun.started_at.desc()).limit(limit) return list(self.db.scalars(stmt).all()) def list_light( self, *, agent: str | None = None, status: str | None = None, source: str | None = None, limit: int = 20, ) -> list[dict[str, Any]]: stmt = select( AgentRun.id.label("id"), AgentRun.run_id.label("run_id"), AgentRun.agent.label("agent"), AgentRun.source.label("source"), AgentRun.user_id.label("user_id"), AgentRun.task_id.label("task_id"), AgentRun.permission_level.label("permission_level"), AgentRun.status.label("status"), AgentRun.result_summary.label("result_summary"), AgentRun.error_message.label("error_message"), AgentRun.started_at.label("started_at"), AgentRun.finished_at.label("finished_at"), AgentRun.route_json["job_type"].as_string().label("route_job_type"), AgentRun.route_json["task_type"].as_string().label("route_task_type"), AgentRun.route_json["task_code"].as_string().label("route_task_code"), AgentRun.route_json["task_name"].as_string().label("route_task_name"), AgentRun.route_json["task_title"].as_string().label("route_task_title"), AgentRun.route_json["asset_name"].as_string().label("route_asset_name"), AgentRun.route_json["selected_agent"].as_string().label("route_selected_agent"), AgentRun.route_json["phase"].as_string().label("route_phase"), AgentRun.route_json["stage"].as_string().label("route_stage"), AgentRun.route_json["report_type"].as_string().label("route_report_type"), AgentRun.route_json["snapshot_key"].as_string().label("route_snapshot_key"), AgentRun.route_json["folder"].as_string().label("route_folder"), AgentRun.route_json["heartbeat_at"].as_string().label("route_heartbeat_at"), AgentRun.route_json["progress"].label("route_progress"), AgentRun.ontology_json["scenario"].as_string().label("ontology_scenario"), AgentRun.ontology_json["intent"].as_string().label("ontology_intent"), AgentRun.ontology_json["parse_strategy"].as_string().label("ontology_parse_strategy"), ) if agent: stmt = stmt.where(AgentRun.agent == agent) if status: stmt = stmt.where(AgentRun.status == status) if source: stmt = stmt.where(AgentRun.source == source) stmt = stmt.order_by(AgentRun.started_at.desc()).limit(limit) return [dict(item) for item in self.db.execute(stmt).mappings().all()] def list_light_tool_calls(self, run_ids: list[str]) -> list[dict[str, Any]]: if not run_ids: return [] stmt = ( select( AgentToolCall.id.label("id"), AgentToolCall.run_id.label("run_id"), AgentToolCall.tool_type.label("tool_type"), AgentToolCall.tool_name.label("tool_name"), AgentToolCall.status.label("status"), AgentToolCall.duration_ms.label("duration_ms"), AgentToolCall.error_message.label("error_message"), AgentToolCall.created_at.label("created_at"), ) .where(AgentToolCall.run_id.in_(run_ids)) .order_by(AgentToolCall.created_at.asc()) ) return [dict(item) for item in self.db.execute(stmt).mappings().all()] def get_by_run_id(self, run_id: str) -> AgentRun | None: stmt = select(AgentRun).where(AgentRun.run_id == run_id) return self.db.scalar(stmt) def create_run(self, run: AgentRun) -> AgentRun: self.db.add(run) self.db.commit() self.db.refresh(run) return run def save_run(self, run: AgentRun) -> AgentRun: self.db.add(run) self.db.commit() self.db.refresh(run) return run def create_tool_call(self, tool_call: AgentToolCall) -> AgentToolCall: self.db.add(tool_call) self.db.commit() self.db.refresh(tool_call) return tool_call def get_tool_call(self, tool_call_id: str) -> AgentToolCall | None: stmt = select(AgentToolCall).where(AgentToolCall.id == tool_call_id) return self.db.scalar(stmt) def save_tool_call(self, tool_call: AgentToolCall) -> AgentToolCall: self.db.add(tool_call) self.db.commit() self.db.refresh(tool_call) return tool_call def create_semantic_parse(self, semantic_parse: SemanticParseLog) -> SemanticParseLog: self.db.add(semantic_parse) self.db.commit() self.db.refresh(semantic_parse) return semantic_parse