feat: 扩展风险规则体系、审批动态路由与预算中心列表化改造
- 新增 25+ 条风险规则(预算/报销/申请/通用类),完善风险规则模拟与反馈发布机制 - 引入费用审批动态路由、平台风险分级、预审与风险阶段管理 - 预算中心列表化改造,优化票据夹仪表盘与数字员工工作看板 - 新增 Hermes 风险线索收集器、Agent 链路追踪中心 - 扩展数字员工能力库(18 个领域 Skill)与交通费用自动预估 - 完善报销申请快速预览、权限控制与前端测试覆盖
This commit is contained in:
530
server/src/app/services/agent_traces.py
Normal file
530
server/src/app/services/agent_traces.py
Normal file
@@ -0,0 +1,530 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.base import Base
|
||||
from app.models.agent_conversation import AgentConversationMessage
|
||||
from app.models.agent_run import AgentRun, AgentToolCall, AgentTraceEvent, SemanticParseLog
|
||||
from app.schemas.agent_run import AgentRunRead, AgentToolCallRead, SemanticParseRead
|
||||
from app.schemas.agent_trace import (
|
||||
AgentConversationTraceRead,
|
||||
AgentTraceDetailRead,
|
||||
AgentTraceEventRead,
|
||||
AgentTraceListItem,
|
||||
)
|
||||
from app.schemas.orchestrator import ConversationMessageRead
|
||||
|
||||
logger = get_logger("app.services.agent_traces")
|
||||
|
||||
|
||||
class AgentTraceService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def ensure_storage_ready(self) -> None:
|
||||
Base.metadata.create_all(bind=self.db.get_bind(), tables=[AgentTraceEvent.__table__])
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
stage: str,
|
||||
event_name: str,
|
||||
title: str,
|
||||
status: str = "succeeded",
|
||||
conversation_id: str | None = None,
|
||||
summary: str | None = None,
|
||||
input_json: dict[str, Any] | None = None,
|
||||
output_json: dict[str, Any] | None = None,
|
||||
error_message: str | None = None,
|
||||
started_at: datetime | None = None,
|
||||
finished_at: datetime | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> AgentTraceEventRead:
|
||||
self.ensure_storage_ready()
|
||||
started = _normalize_datetime(started_at) or datetime.now(UTC)
|
||||
finished = _normalize_datetime(finished_at)
|
||||
if finished is None and status != "running":
|
||||
finished = started
|
||||
|
||||
event = AgentTraceEvent(
|
||||
run_id=str(run_id or "").strip(),
|
||||
conversation_id=_optional_text(conversation_id),
|
||||
sequence=self._next_sequence(run_id),
|
||||
stage=str(stage or "orchestrator").strip() or "orchestrator",
|
||||
event_name=str(event_name or "").strip() or "event",
|
||||
title=str(title or event_name or "Trace event").strip(),
|
||||
summary=_optional_text(summary),
|
||||
status=str(status or "succeeded").strip() or "succeeded",
|
||||
input_json=_json_safe(input_json or {}),
|
||||
output_json=_json_safe(output_json or {}),
|
||||
error_message=_optional_text(error_message),
|
||||
started_at=started,
|
||||
finished_at=finished,
|
||||
duration_ms=_resolve_duration_ms(started, finished, duration_ms),
|
||||
)
|
||||
self.db.add(event)
|
||||
self.db.commit()
|
||||
self.db.refresh(event)
|
||||
return AgentTraceEventRead.model_validate(event)
|
||||
|
||||
def record_event_safe(self, **kwargs: Any) -> AgentTraceEventRead | None:
|
||||
try:
|
||||
return self.record_event(**kwargs)
|
||||
except Exception:
|
||||
self.db.rollback()
|
||||
logger.exception("Failed to record agent trace event run_id=%s", kwargs.get("run_id"))
|
||||
return None
|
||||
|
||||
def record_tool_event_safe(
|
||||
self,
|
||||
run_id: str,
|
||||
tool_type: str,
|
||||
tool_name: str,
|
||||
request_json: dict[str, Any],
|
||||
response_json: dict[str, Any],
|
||||
status: str,
|
||||
duration_ms: int,
|
||||
context_json: dict[str, Any],
|
||||
error_message: str | None = None,
|
||||
) -> AgentTraceEventRead | None:
|
||||
return self.record_event_safe(
|
||||
run_id=run_id,
|
||||
conversation_id=str(context_json.get("conversation_id") or "").strip() or None,
|
||||
stage="tool",
|
||||
event_name="tool_invoked",
|
||||
title=tool_name,
|
||||
status=status,
|
||||
summary=f"{tool_type} / {status}",
|
||||
input_json=request_json,
|
||||
output_json=response_json,
|
||||
error_message=error_message,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
def list_traces(
|
||||
self,
|
||||
*,
|
||||
agent: str | None = None,
|
||||
status: str | None = None,
|
||||
source: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
keyword: str | None = None,
|
||||
limit: int = 30,
|
||||
) -> list[AgentTraceListItem]:
|
||||
self.ensure_storage_ready()
|
||||
normalized_limit = max(1, min(int(limit or 30), 100))
|
||||
fetch_limit = normalized_limit * 4 if keyword else normalized_limit
|
||||
stmt = select(AgentRun)
|
||||
if agent:
|
||||
stmt = stmt.where(AgentRun.agent == agent)
|
||||
if status:
|
||||
stmt = stmt.where(AgentRun.status == status)
|
||||
if source:
|
||||
stmt = stmt.where(AgentRun.source == source)
|
||||
if conversation_id:
|
||||
run_ids = self._conversation_run_ids(conversation_id)
|
||||
if not run_ids:
|
||||
return []
|
||||
stmt = stmt.where(AgentRun.run_id.in_(run_ids))
|
||||
stmt = stmt.order_by(AgentRun.started_at.desc()).limit(fetch_limit)
|
||||
runs = list(self.db.scalars(stmt).all())
|
||||
event_counts = self._event_counts([run.run_id for run in runs])
|
||||
keyword_text = str(keyword or "").strip().lower()
|
||||
items = [self._build_trace_list_item(run, event_counts.get(run.run_id, 0)) for run in runs]
|
||||
if keyword_text:
|
||||
items = [item for item in items if self._matches_keyword(item, keyword_text)]
|
||||
return items[:normalized_limit]
|
||||
|
||||
def get_trace(self, run_id: str) -> AgentTraceDetailRead | None:
|
||||
self.ensure_storage_ready()
|
||||
normalized_run_id = str(run_id or "").strip()
|
||||
if not normalized_run_id:
|
||||
return None
|
||||
|
||||
run = self.db.scalar(select(AgentRun).where(AgentRun.run_id == normalized_run_id))
|
||||
if run is None:
|
||||
return None
|
||||
|
||||
db_events = list(
|
||||
self.db.scalars(
|
||||
select(AgentTraceEvent)
|
||||
.where(AgentTraceEvent.run_id == normalized_run_id)
|
||||
.order_by(AgentTraceEvent.sequence.asc(), AgentTraceEvent.started_at.asc())
|
||||
).all()
|
||||
)
|
||||
conversation_id = self._resolve_conversation_id(run, db_events)
|
||||
events = [AgentTraceEventRead.model_validate(event) for event in db_events]
|
||||
fallback_generated = False
|
||||
if not events:
|
||||
events = self._build_fallback_events(run, conversation_id)
|
||||
fallback_generated = True
|
||||
|
||||
return AgentTraceDetailRead(
|
||||
run=self._serialize_run(run),
|
||||
conversation_id=conversation_id,
|
||||
events=events,
|
||||
semantic_parse=self._serialize_semantic_parse(self._first_semantic_parse(run)),
|
||||
tool_calls=[AgentToolCallRead.model_validate(item) for item in run.tool_calls],
|
||||
conversation_messages=self._conversation_messages(conversation_id),
|
||||
fallback_generated=fallback_generated,
|
||||
)
|
||||
|
||||
def get_conversation_trace(self, conversation_id: str) -> AgentConversationTraceRead:
|
||||
normalized_conversation_id = str(conversation_id or "").strip()
|
||||
run_ids = self._conversation_run_ids(normalized_conversation_id)
|
||||
details = []
|
||||
for run_id in run_ids:
|
||||
detail = self.get_trace(run_id)
|
||||
if detail is not None:
|
||||
details.append(detail)
|
||||
return AgentConversationTraceRead(
|
||||
conversation_id=normalized_conversation_id,
|
||||
runs=details,
|
||||
)
|
||||
|
||||
def _next_sequence(self, run_id: str) -> int:
|
||||
current = self.db.scalar(
|
||||
select(func.max(AgentTraceEvent.sequence)).where(AgentTraceEvent.run_id == run_id)
|
||||
)
|
||||
return int(current or 0) + 1
|
||||
|
||||
def _event_counts(self, run_ids: list[str]) -> dict[str, int]:
|
||||
if not run_ids:
|
||||
return {}
|
||||
rows = self.db.execute(
|
||||
select(AgentTraceEvent.run_id, func.count(AgentTraceEvent.id))
|
||||
.where(AgentTraceEvent.run_id.in_(run_ids))
|
||||
.group_by(AgentTraceEvent.run_id)
|
||||
).all()
|
||||
return {str(run_id): int(count or 0) for run_id, count in rows}
|
||||
|
||||
def _build_trace_list_item(self, run: AgentRun, event_count: int) -> AgentTraceListItem:
|
||||
semantic_parse = self._first_semantic_parse(run)
|
||||
failed_tools = sum(1 for item in run.tool_calls if item.status == "failed")
|
||||
title = self._resolve_run_title(run, semantic_parse)
|
||||
finished_at = _normalize_datetime(run.finished_at)
|
||||
started_at = _normalize_datetime(run.started_at) or datetime.now(UTC)
|
||||
return AgentTraceListItem(
|
||||
run_id=run.run_id,
|
||||
conversation_id=self._resolve_conversation_id(run, []),
|
||||
agent=run.agent,
|
||||
source=run.source,
|
||||
status=run.status,
|
||||
scenario=semantic_parse.scenario if semantic_parse is not None else None,
|
||||
intent=semantic_parse.intent if semantic_parse is not None else None,
|
||||
title=title,
|
||||
summary=run.result_summary,
|
||||
event_count=event_count,
|
||||
tool_call_count=len(run.tool_calls),
|
||||
failed_tool_call_count=failed_tools,
|
||||
started_at=started_at,
|
||||
finished_at=finished_at,
|
||||
duration_ms=_resolve_duration_ms(started_at, finished_at, None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _matches_keyword(item: AgentTraceListItem, keyword: str) -> bool:
|
||||
corpus = " ".join(
|
||||
str(value or "")
|
||||
for value in (
|
||||
item.run_id,
|
||||
item.conversation_id,
|
||||
item.agent,
|
||||
item.source,
|
||||
item.status,
|
||||
item.scenario,
|
||||
item.intent,
|
||||
item.title,
|
||||
item.summary,
|
||||
)
|
||||
).lower()
|
||||
return keyword in corpus
|
||||
|
||||
def _resolve_conversation_id(
|
||||
self,
|
||||
run: AgentRun,
|
||||
events: list[AgentTraceEvent],
|
||||
) -> str | None:
|
||||
route_value = (run.route_json or {}).get("conversation_id")
|
||||
if route_value:
|
||||
return str(route_value).strip() or None
|
||||
for event in events:
|
||||
if event.conversation_id:
|
||||
return str(event.conversation_id).strip() or None
|
||||
message = self.db.scalar(
|
||||
select(AgentConversationMessage)
|
||||
.where(AgentConversationMessage.run_id == run.run_id)
|
||||
.order_by(AgentConversationMessage.created_at.asc())
|
||||
)
|
||||
return str(message.conversation_id).strip() if message is not None else None
|
||||
|
||||
def _conversation_run_ids(self, conversation_id: str) -> list[str]:
|
||||
normalized = str(conversation_id or "").strip()
|
||||
if not normalized:
|
||||
return []
|
||||
self.ensure_storage_ready()
|
||||
run_ids: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
def append_run_id(value: str | None) -> None:
|
||||
run_id = str(value or "").strip()
|
||||
if run_id and run_id not in seen:
|
||||
seen.add(run_id)
|
||||
run_ids.append(run_id)
|
||||
|
||||
messages = list(
|
||||
self.db.scalars(
|
||||
select(AgentConversationMessage)
|
||||
.where(AgentConversationMessage.conversation_id == normalized)
|
||||
.order_by(AgentConversationMessage.created_at.asc())
|
||||
).all()
|
||||
)
|
||||
for message in messages:
|
||||
append_run_id(message.run_id)
|
||||
|
||||
trace_event_run_ids = list(
|
||||
self.db.scalars(
|
||||
select(AgentTraceEvent.run_id)
|
||||
.where(AgentTraceEvent.conversation_id == normalized)
|
||||
.order_by(AgentTraceEvent.created_at.asc(), AgentTraceEvent.sequence.asc())
|
||||
).all()
|
||||
)
|
||||
for run_id in trace_event_run_ids:
|
||||
append_run_id(run_id)
|
||||
|
||||
recent_runs = list(
|
||||
self.db.scalars(
|
||||
select(AgentRun).order_by(AgentRun.started_at.desc()).limit(500)
|
||||
).all()
|
||||
)
|
||||
for run in reversed(recent_runs):
|
||||
if str((run.route_json or {}).get("conversation_id") or "").strip() == normalized:
|
||||
append_run_id(run.run_id)
|
||||
return run_ids
|
||||
|
||||
def _conversation_messages(self, conversation_id: str | None) -> list[ConversationMessageRead]:
|
||||
if not conversation_id:
|
||||
return []
|
||||
messages = list(
|
||||
self.db.scalars(
|
||||
select(AgentConversationMessage)
|
||||
.where(AgentConversationMessage.conversation_id == conversation_id)
|
||||
.order_by(AgentConversationMessage.created_at.asc())
|
||||
.limit(100)
|
||||
).all()
|
||||
)
|
||||
return [
|
||||
ConversationMessageRead(
|
||||
id=item.id,
|
||||
role=item.role,
|
||||
content=item.content,
|
||||
run_id=item.run_id,
|
||||
message_json=item.message_json or {},
|
||||
created_at=item.created_at,
|
||||
)
|
||||
for item in messages
|
||||
]
|
||||
|
||||
def _build_fallback_events(
|
||||
self,
|
||||
run: AgentRun,
|
||||
conversation_id: str | None,
|
||||
) -> list[AgentTraceEventRead]:
|
||||
events: list[AgentTraceEventRead] = []
|
||||
started_at = _normalize_datetime(run.started_at) or datetime.now(UTC)
|
||||
semantic_parse = self._first_semantic_parse(run)
|
||||
|
||||
def append_event(
|
||||
*,
|
||||
stage: str,
|
||||
event_name: str,
|
||||
title: str,
|
||||
status: str,
|
||||
summary: str | None,
|
||||
started: datetime,
|
||||
finished: datetime | None = None,
|
||||
input_json: dict[str, Any] | None = None,
|
||||
output_json: dict[str, Any] | None = None,
|
||||
error_message: str | None = None,
|
||||
) -> None:
|
||||
sequence = len(events) + 1
|
||||
resolved_finished = finished or started
|
||||
events.append(
|
||||
AgentTraceEventRead(
|
||||
id=f"fallback-{run.run_id}-{sequence}",
|
||||
run_id=run.run_id,
|
||||
conversation_id=conversation_id,
|
||||
sequence=sequence,
|
||||
stage=stage,
|
||||
event_name=event_name,
|
||||
title=title,
|
||||
summary=summary,
|
||||
status=status,
|
||||
input_json=_json_safe(input_json or {}),
|
||||
output_json=_json_safe(output_json or {}),
|
||||
error_message=error_message,
|
||||
started_at=started,
|
||||
finished_at=resolved_finished,
|
||||
duration_ms=_resolve_duration_ms(started, resolved_finished, None),
|
||||
created_at=started,
|
||||
)
|
||||
)
|
||||
|
||||
append_event(
|
||||
stage="orchestrator",
|
||||
event_name="run_created",
|
||||
title="运行记录",
|
||||
status="succeeded",
|
||||
summary="由历史 AgentRun 合成的 trace 起点。",
|
||||
started=started_at,
|
||||
output_json={"agent": run.agent, "source": run.source, "status": run.status},
|
||||
)
|
||||
if semantic_parse is not None:
|
||||
append_event(
|
||||
stage="semantic",
|
||||
event_name="semantic_parsed",
|
||||
title="语义解析",
|
||||
status="succeeded",
|
||||
summary=f"{semantic_parse.scenario} / {semantic_parse.intent}",
|
||||
started=_normalize_datetime(semantic_parse.created_at) or started_at,
|
||||
input_json={"raw_query": semantic_parse.raw_query},
|
||||
output_json=self._semantic_parse_payload(semantic_parse),
|
||||
)
|
||||
if run.route_json:
|
||||
append_event(
|
||||
stage="route",
|
||||
event_name="route_resolved",
|
||||
title="路由上下文",
|
||||
status="succeeded",
|
||||
summary=str(run.route_json.get("route_reason") or run.route_json.get("stage") or "已记录路由信息"),
|
||||
started=started_at,
|
||||
output_json=run.route_json,
|
||||
)
|
||||
for tool_call in run.tool_calls:
|
||||
append_event(
|
||||
stage="tool",
|
||||
event_name="tool_invoked",
|
||||
title=tool_call.tool_name,
|
||||
status=tool_call.status,
|
||||
summary=f"{tool_call.tool_type} / {tool_call.status}",
|
||||
started=_normalize_datetime(tool_call.created_at) or started_at,
|
||||
finished=_normalize_datetime(tool_call.created_at) or started_at,
|
||||
input_json=tool_call.request_json,
|
||||
output_json=tool_call.response_json,
|
||||
error_message=tool_call.error_message,
|
||||
)
|
||||
append_event(
|
||||
stage="response",
|
||||
event_name="response_built" if run.status != "failed" else "failed",
|
||||
title="最终结果",
|
||||
status=run.status,
|
||||
summary=run.result_summary or run.error_message,
|
||||
started=_normalize_datetime(run.finished_at) or started_at,
|
||||
output_json={"result_summary": run.result_summary},
|
||||
error_message=run.error_message,
|
||||
)
|
||||
return events
|
||||
|
||||
@staticmethod
|
||||
def _resolve_run_title(run: AgentRun, semantic_parse: SemanticParseLog | None) -> str:
|
||||
if semantic_parse is not None:
|
||||
return f"{semantic_parse.scenario} / {semantic_parse.intent}"
|
||||
route_json = run.route_json or {}
|
||||
return str(route_json.get("task_name") or route_json.get("selected_agent") or run.agent)
|
||||
|
||||
@staticmethod
|
||||
def _first_semantic_parse(run: AgentRun) -> SemanticParseLog | None:
|
||||
return run.semantic_parse_logs[0] if run.semantic_parse_logs else None
|
||||
|
||||
@staticmethod
|
||||
def _serialize_semantic_parse(item: SemanticParseLog | None) -> SemanticParseRead | None:
|
||||
return SemanticParseRead.model_validate(item) if item is not None else None
|
||||
|
||||
@staticmethod
|
||||
def _serialize_run(run: AgentRun) -> AgentRunRead:
|
||||
semantic_parse = AgentTraceService._first_semantic_parse(run)
|
||||
return AgentRunRead(
|
||||
id=run.id,
|
||||
run_id=run.run_id,
|
||||
agent=run.agent,
|
||||
source=run.source,
|
||||
user_id=run.user_id,
|
||||
task_id=run.task_id,
|
||||
ontology_json=run.ontology_json,
|
||||
route_json=run.route_json,
|
||||
permission_level=run.permission_level,
|
||||
status=run.status,
|
||||
result_summary=run.result_summary,
|
||||
error_message=run.error_message,
|
||||
started_at=run.started_at,
|
||||
finished_at=run.finished_at,
|
||||
tool_calls=[AgentToolCallRead.model_validate(item) for item in run.tool_calls],
|
||||
semantic_parse=SemanticParseRead.model_validate(semantic_parse)
|
||||
if semantic_parse is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _semantic_parse_payload(item: SemanticParseLog) -> dict[str, Any]:
|
||||
return {
|
||||
"scenario": item.scenario,
|
||||
"intent": item.intent,
|
||||
"confidence": item.confidence,
|
||||
"entities": item.entities_json,
|
||||
"time_range": item.time_range_json,
|
||||
"metrics": item.metrics_json,
|
||||
"constraints": item.constraints_json,
|
||||
"risk_flags": item.risk_flags_json,
|
||||
"permission": item.permission_json,
|
||||
}
|
||||
|
||||
|
||||
def _resolve_duration_ms(
|
||||
started_at: datetime | None,
|
||||
finished_at: datetime | None,
|
||||
duration_ms: int | None,
|
||||
) -> int:
|
||||
if duration_ms is not None:
|
||||
return max(0, int(duration_ms or 0))
|
||||
if started_at is None or finished_at is None:
|
||||
return 0
|
||||
try:
|
||||
return max(0, int((finished_at - started_at).total_seconds() * 1000))
|
||||
except TypeError:
|
||||
return 0
|
||||
|
||||
|
||||
def _normalize_datetime(value: datetime | None) -> datetime | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=UTC)
|
||||
return value
|
||||
|
||||
|
||||
def _optional_text(value: Any) -> str | None:
|
||||
text = str(value or "").strip()
|
||||
return text or None
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, Decimal):
|
||||
return str(value)
|
||||
if isinstance(value, (datetime, date)):
|
||||
return value.isoformat()
|
||||
if isinstance(value, dict):
|
||||
return {str(key): _json_safe(item) for key, item in value.items()}
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [_json_safe(item) for item in value]
|
||||
if hasattr(value, "model_dump"):
|
||||
return _json_safe(value.model_dump())
|
||||
return str(value)
|
||||
Reference in New Issue
Block a user