58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from sqlalchemy import select
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.models.agent_run import AgentRun, AgentToolCall, SemanticParseLog
|
||
|
|
|
||
|
|
|
||
|
|
class AgentRunRepository:
    """Data-access helper for agent runs, tool calls and semantic-parse logs.

    The repository works on a SQLAlchemy ``Session`` supplied by the caller.
    Every write method commits immediately and refreshes the instance before
    handing it back.
    """

    def __init__(self, db: Session) -> None:
        """Keep a reference to the session used by all queries and writes."""
        self.db = db

    # NOTE(review): this method name shadows the ``list`` builtin inside the
    # class namespace. It is harmless at runtime here (method bodies resolve
    # ``list`` to the builtin), but static type checkers may complain about
    # ``list[...]`` annotations on this class. The name is kept because
    # renaming it would break callers.
    def list(
        self,
        *,
        agent: str | None = None,
        status: str | None = None,
        source: str | None = None,
        limit: int = 20,
    ) -> list[AgentRun]:
        """Return runs ordered by ``started_at`` descending.

        Args:
            agent: When truthy, keep only runs whose ``agent`` equals it.
            status: When truthy, keep only runs whose ``status`` equals it.
            source: When truthy, keep only runs whose ``source`` equals it.
            limit: Row limit applied to the query.

        Returns:
            The matching ``AgentRun`` rows as a list.
        """
        stmt = select(AgentRun)
        # Falsy filter values (``None`` or an empty string) mean "no filter".
        if agent:
            stmt = stmt.where(AgentRun.agent == agent)
        if status:
            stmt = stmt.where(AgentRun.status == status)
        if source:
            stmt = stmt.where(AgentRun.source == source)
        stmt = stmt.order_by(AgentRun.started_at.desc()).limit(limit)
        return list(self.db.scalars(stmt).all())

    def get_by_run_id(self, run_id: str) -> AgentRun | None:
        """Return the run whose ``run_id`` equals *run_id*, or ``None``."""
        stmt = select(AgentRun).where(AgentRun.run_id == run_id)
        return self.db.scalar(stmt)

    def _persist(self, instance: object) -> None:
        """Add *instance* to the session, commit, then refresh it.

        If the commit raises, the session is rolled back before the original
        exception is re-raised, so the session stays usable for the caller.
        """
        self.db.add(instance)
        try:
            self.db.commit()
        except Exception:
            # A failed commit leaves the transaction unusable until it is
            # rolled back; do that here, then surface the original error.
            self.db.rollback()
            raise
        self.db.refresh(instance)

    def create_run(self, run: AgentRun) -> AgentRun:
        """Add, commit and refresh *run*, then return it."""
        self._persist(run)
        return run

    def save_run(self, run: AgentRun) -> AgentRun:
        """Add, commit and refresh *run* (same sequence as ``create_run``)."""
        self._persist(run)
        return run

    def create_tool_call(self, tool_call: AgentToolCall) -> AgentToolCall:
        """Add, commit and refresh *tool_call*, then return it."""
        self._persist(tool_call)
        return tool_call

    def create_semantic_parse(self, semantic_parse: SemanticParseLog) -> SemanticParseLog:
        """Add, commit and refresh *semantic_parse*, then return it."""
        self._persist(semantic_parse)
        return semantic_parse
|