Files
X-Financial/server/src/app/repositories/agent_run.py
caoxiaozhu e124e4bbcb feat: 报销审批流重构与管家计划全链路贯通
- 重构报销状态注册表、审批流路由与平台风险标记
- 完善管家意图规划器与模型计划构建器全链路
- 新增 OCR Worker 脚本、数据库会话管理与通知状态
- 优化文档中心、日志视图、预算中心与员工管理交互
- 增强工作台摘要、图标资源与全局主题样式
- 补充审批路由、状态注册、OCR 服务与管家规划器测试覆盖
2026-06-06 17:19:07 +08:00

138 lines
5.5 KiB
Python

from __future__ import annotations
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models.agent_run import AgentRun, AgentToolCall, SemanticParseLog
class AgentRunRepository:
    """Data-access helpers for ``AgentRun``, ``AgentToolCall`` and ``SemanticParseLog``.

    Every write method commits immediately on the session passed to the
    constructor and refreshes the instance afterwards.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    @staticmethod
    def _apply_run_filters(
        stmt: Any,
        *,
        agent: str | None,
        status: str | None,
        source: str | None,
        limit: int,
    ) -> Any:
        """Narrow an ``AgentRun`` select by the optional filters, newest first, capped at ``limit``.

        Filters are equality matches. Falsy values (``None`` *or* ``""``) are
        ignored rather than matched literally.
        """
        if agent:
            stmt = stmt.where(AgentRun.agent == agent)
        if status:
            stmt = stmt.where(AgentRun.status == status)
        if source:
            stmt = stmt.where(AgentRun.source == source)
        return stmt.order_by(AgentRun.started_at.desc()).limit(limit)

    def _persist(self, instance: Any) -> Any:
        """Add ``instance`` to the session, commit, refresh it and return it.

        If the commit raises, the session is rolled back before the exception
        is re-raised. SQLAlchemy leaves a session in an unusable state after a
        failed flush/commit until ``rollback()`` is called, so this keeps the
        shared session usable for the caller.
        """
        self.db.add(instance)
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        self.db.refresh(instance)
        return instance

    def list(
        self,
        *,
        agent: str | None = None,
        status: str | None = None,
        source: str | None = None,
        limit: int = 20,
    ) -> list[AgentRun]:
        """Return full ``AgentRun`` ORM objects, newest ``started_at`` first.

        Args:
            agent: Only runs whose ``agent`` equals this value (ignored if falsy).
            status: Only runs whose ``status`` equals this value (ignored if falsy).
            source: Only runs whose ``source`` equals this value (ignored if falsy).
            limit: Maximum number of rows returned.
        """
        stmt = self._apply_run_filters(
            select(AgentRun), agent=agent, status=status, source=source, limit=limit
        )
        # ``list`` here is the builtin: class attributes (including this method)
        # are not visible from inside method bodies.
        return list(self.db.scalars(stmt).all())

    def list_light(
        self,
        *,
        agent: str | None = None,
        status: str | None = None,
        source: str | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        """Return lightweight run summaries as plain dicts, newest ``started_at`` first.

        Only scalar columns are loaded, plus selected keys pulled out of the
        ``route_json`` / ``ontology_json`` JSON columns. Those keys are exposed
        as ``route_<key>`` / ``ontology_<key>``.
        - Most JSON keys are extracted as strings.
        - ``route_progress`` is kept as a raw JSON value.
        - A key missing from the JSON document presumably yields ``None``;
          confirm against the target database.

        Filters and ``limit`` behave exactly as in :meth:`list`.
        """
        stmt = select(
            AgentRun.id.label("id"),
            AgentRun.run_id.label("run_id"),
            AgentRun.agent.label("agent"),
            AgentRun.source.label("source"),
            AgentRun.user_id.label("user_id"),
            AgentRun.task_id.label("task_id"),
            AgentRun.permission_level.label("permission_level"),
            AgentRun.status.label("status"),
            AgentRun.result_summary.label("result_summary"),
            AgentRun.error_message.label("error_message"),
            AgentRun.started_at.label("started_at"),
            AgentRun.finished_at.label("finished_at"),
            AgentRun.route_json["job_type"].as_string().label("route_job_type"),
            AgentRun.route_json["task_type"].as_string().label("route_task_type"),
            AgentRun.route_json["task_code"].as_string().label("route_task_code"),
            AgentRun.route_json["task_name"].as_string().label("route_task_name"),
            AgentRun.route_json["task_title"].as_string().label("route_task_title"),
            AgentRun.route_json["asset_name"].as_string().label("route_asset_name"),
            AgentRun.route_json["selected_agent"].as_string().label("route_selected_agent"),
            AgentRun.route_json["phase"].as_string().label("route_phase"),
            AgentRun.route_json["stage"].as_string().label("route_stage"),
            AgentRun.route_json["report_type"].as_string().label("route_report_type"),
            AgentRun.route_json["snapshot_key"].as_string().label("route_snapshot_key"),
            AgentRun.route_json["folder"].as_string().label("route_folder"),
            AgentRun.route_json["heartbeat_at"].as_string().label("route_heartbeat_at"),
            # Not cast to string: progress is returned as its JSON value.
            AgentRun.route_json["progress"].label("route_progress"),
            AgentRun.ontology_json["scenario"].as_string().label("ontology_scenario"),
            AgentRun.ontology_json["intent"].as_string().label("ontology_intent"),
            AgentRun.ontology_json["parse_strategy"].as_string().label("ontology_parse_strategy"),
        )
        stmt = self._apply_run_filters(
            stmt, agent=agent, status=status, source=source, limit=limit
        )
        return [dict(row) for row in self.db.execute(stmt).mappings().all()]

    def list_light_tool_calls(self, run_ids: list[str]) -> list[dict[str, Any]]:
        """Return lightweight tool-call summaries for the given runs, oldest first.

        An empty ``run_ids`` short-circuits to ``[]`` without querying the database.
        """
        if not run_ids:
            return []
        stmt = (
            select(
                AgentToolCall.id.label("id"),
                AgentToolCall.run_id.label("run_id"),
                AgentToolCall.tool_type.label("tool_type"),
                AgentToolCall.tool_name.label("tool_name"),
                AgentToolCall.status.label("status"),
                AgentToolCall.duration_ms.label("duration_ms"),
                AgentToolCall.error_message.label("error_message"),
                AgentToolCall.created_at.label("created_at"),
            )
            .where(AgentToolCall.run_id.in_(run_ids))
            .order_by(AgentToolCall.created_at.asc())
        )
        return [dict(row) for row in self.db.execute(stmt).mappings().all()]

    def get_by_run_id(self, run_id: str) -> AgentRun | None:
        """Return the run whose ``run_id`` matches, or ``None`` if there is none."""
        stmt = select(AgentRun).where(AgentRun.run_id == run_id)
        return self.db.scalar(stmt)

    def create_run(self, run: AgentRun) -> AgentRun:
        """Insert ``run``, commit, and return it refreshed from the database."""
        return self._persist(run)

    def save_run(self, run: AgentRun) -> AgentRun:
        """Persist changes to ``run``, commit, and return it refreshed."""
        return self._persist(run)

    def create_tool_call(self, tool_call: AgentToolCall) -> AgentToolCall:
        """Insert ``tool_call``, commit, and return it refreshed from the database."""
        return self._persist(tool_call)

    def get_tool_call(self, tool_call_id: str) -> AgentToolCall | None:
        """Return the tool call with primary key ``tool_call_id``, or ``None``."""
        stmt = select(AgentToolCall).where(AgentToolCall.id == tool_call_id)
        return self.db.scalar(stmt)

    def save_tool_call(self, tool_call: AgentToolCall) -> AgentToolCall:
        """Persist changes to ``tool_call``, commit, and return it refreshed."""
        return self._persist(tool_call)

    def create_semantic_parse(self, semantic_parse: SemanticParseLog) -> SemanticParseLog:
        """Insert ``semantic_parse``, commit, and return it refreshed from the database."""
        return self._persist(semantic_parse)