2026-05-11 03:51:24 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2026-06-06 17:19:07 +08:00
|
|
|
from typing import Any
|
|
|
|
|
|
2026-05-11 03:51:24 +00:00
|
|
|
from sqlalchemy import select
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from app.models.agent_run import AgentRun, AgentToolCall, SemanticParseLog
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentRunRepository:
    """Data-access helpers for ``AgentRun`` rows and their related records.

    Every write method adds the instance to the session, commits immediately,
    and refreshes the instance afterwards. If the commit fails, the session is
    rolled back before the exception propagates, so the caller's session stays
    usable.
    """

    # ``route_json`` keys that ``list_light`` projects as text, each exposed
    # under the column label ``route_<key>``. Tuple order fixes the key order
    # of the returned dicts.
    _ROUTE_TEXT_KEYS = (
        "job_type",
        "task_type",
        "task_code",
        "task_name",
        "task_title",
        "asset_name",
        "selected_agent",
        "phase",
        "stage",
        "report_type",
        "snapshot_key",
        "folder",
        "heartbeat_at",
    )
    # ``ontology_json`` keys that ``list_light`` projects as text under
    # ``ontology_<key>``.
    _ONTOLOGY_TEXT_KEYS = ("scenario", "intent", "parse_strategy")

    def __init__(self, db: Session) -> None:
        """Bind the repository to an SQLAlchemy session.

        Args:
            db: Session used for every read and write. Write methods commit on
                it directly.
        """
        self.db = db

    # NOTE: this method shadows the builtin ``list`` inside the class body.
    # Method bodies still resolve ``list`` to the builtin, because class scope
    # does not enclose them. The later ``list[...]`` annotations are only safe
    # because this module uses ``from __future__ import annotations``.
    def list(
        self,
        *,
        agent: str | None = None,
        status: str | None = None,
        source: str | None = None,
        limit: int = 20,
    ) -> list[AgentRun]:
        """Return full ``AgentRun`` ORM objects, newest ``started_at`` first.

        Args:
            agent: Exact match on ``AgentRun.agent``; a falsy value disables
                the filter.
            status: Exact match on ``AgentRun.status``; a falsy value disables
                the filter.
            source: Exact match on ``AgentRun.source``; a falsy value disables
                the filter.
            limit: Maximum number of rows returned.

        Returns:
            The matching runs as a list.
        """
        stmt = self._filter_runs(
            select(AgentRun),
            agent=agent,
            status=status,
            source=source,
            limit=limit,
        )
        return list(self.db.scalars(stmt).all())

    def list_light(
        self,
        *,
        agent: str | None = None,
        status: str | None = None,
        source: str | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        """Return lightweight run summaries as plain dicts.

        This avoids loading whole ORM objects. It selects scalar ``AgentRun``
        columns plus individual keys extracted from ``route_json`` (labelled
        ``route_<key>``) and ``ontology_json`` (labelled ``ontology_<key>``).
        Filters, ordering and ``limit`` behave exactly as in ``list``.

        Returns:
            One dict per run, keyed by the column labels below.
        """
        route = AgentRun.route_json
        ontology = AgentRun.ontology_json
        stmt = select(
            AgentRun.id.label("id"),
            AgentRun.run_id.label("run_id"),
            AgentRun.agent.label("agent"),
            AgentRun.source.label("source"),
            AgentRun.user_id.label("user_id"),
            AgentRun.task_id.label("task_id"),
            AgentRun.permission_level.label("permission_level"),
            AgentRun.status.label("status"),
            AgentRun.result_summary.label("result_summary"),
            AgentRun.error_message.label("error_message"),
            AgentRun.started_at.label("started_at"),
            AgentRun.finished_at.label("finished_at"),
            *(
                route[key].as_string().label(f"route_{key}")
                for key in self._ROUTE_TEXT_KEYS
            ),
            # ``progress`` is deliberately left as a JSON value rather than
            # cast to text. Presumably it holds a number or structure;
            # confirm with consumers before changing.
            route["progress"].label("route_progress"),
            *(
                ontology[key].as_string().label(f"ontology_{key}")
                for key in self._ONTOLOGY_TEXT_KEYS
            ),
        )
        stmt = self._filter_runs(
            stmt,
            agent=agent,
            status=status,
            source=source,
            limit=limit,
        )
        return [dict(row) for row in self.db.execute(stmt).mappings().all()]

    @staticmethod
    def _filter_runs(
        stmt: Any,
        *,
        agent: str | None,
        status: str | None,
        source: str | None,
        limit: int,
    ) -> Any:
        """Apply the shared ``AgentRun`` filters, ordering and row limit to *stmt*."""
        if agent:
            stmt = stmt.where(AgentRun.agent == agent)
        if status:
            stmt = stmt.where(AgentRun.status == status)
        if source:
            stmt = stmt.where(AgentRun.source == source)
        # ``id`` breaks ties between equal ``started_at`` values, so the set of
        # rows kept by LIMIT is deterministic.
        return stmt.order_by(
            AgentRun.started_at.desc(), AgentRun.id.desc()
        ).limit(limit)

    def list_light_tool_calls(self, run_ids: list[str]) -> list[dict[str, Any]]:
        """Return summary rows for the tool calls of the given runs.

        Only the summary columns listed below are selected. Rows are ordered
        oldest ``created_at`` first.

        Args:
            run_ids: ``AgentToolCall.run_id`` values to match.

        Returns:
            One dict per tool call. An empty ``run_ids`` returns ``[]``
            without querying the database.
        """
        if not run_ids:
            # Skip the round trip entirely: an empty IN list can never match.
            return []

        stmt = (
            select(
                AgentToolCall.id.label("id"),
                AgentToolCall.run_id.label("run_id"),
                AgentToolCall.tool_type.label("tool_type"),
                AgentToolCall.tool_name.label("tool_name"),
                AgentToolCall.status.label("status"),
                AgentToolCall.duration_ms.label("duration_ms"),
                AgentToolCall.error_message.label("error_message"),
                AgentToolCall.created_at.label("created_at"),
            )
            .where(AgentToolCall.run_id.in_(run_ids))
            # ``id`` breaks ties between calls created at the same instant.
            .order_by(AgentToolCall.created_at.asc(), AgentToolCall.id.asc())
        )
        return [dict(row) for row in self.db.execute(stmt).mappings().all()]

    def get_by_run_id(self, run_id: str) -> AgentRun | None:
        """Return the run whose ``run_id`` equals *run_id*, or ``None``."""
        stmt = select(AgentRun).where(AgentRun.run_id == run_id)
        return self.db.scalar(stmt)

    def _persist(self, instance: Any) -> None:
        """Add *instance*, commit, then refresh it from the database.

        If the commit raises, the session is rolled back before re-raising.
        Otherwise the session would be left in a failed transaction, and any
        later use of it by the caller would fail as well.
        """
        self.db.add(instance)
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        self.db.refresh(instance)

    def create_run(self, run: AgentRun) -> AgentRun:
        """Insert *run*, commit, and return it refreshed from the database."""
        self._persist(run)
        return run

    def save_run(self, run: AgentRun) -> AgentRun:
        """Persist changes to *run*, commit, and return it refreshed."""
        self._persist(run)
        return run

    def create_tool_call(self, tool_call: AgentToolCall) -> AgentToolCall:
        """Insert *tool_call*, commit, and return it refreshed from the database."""
        self._persist(tool_call)
        return tool_call

    def get_tool_call(self, tool_call_id: str) -> AgentToolCall | None:
        """Return the tool call whose ``id`` equals *tool_call_id*, or ``None``."""
        stmt = select(AgentToolCall).where(AgentToolCall.id == tool_call_id)
        return self.db.scalar(stmt)

    def save_tool_call(self, tool_call: AgentToolCall) -> AgentToolCall:
        """Persist changes to *tool_call*, commit, and return it refreshed."""
        self._persist(tool_call)
        return tool_call

    def create_semantic_parse(self, semantic_parse: SemanticParseLog) -> SemanticParseLog:
        """Insert *semantic_parse*, commit, and return it refreshed from the database."""
        self._persist(semantic_parse)
        return semantic_parse
|