Files
X-Financial/server/src/app/services/agent_traces.py
caoxiaozhu 92444e7eae feat: 扩展风险规则体系、审批动态路由与预算中心列表化改造
- 新增 25+ 条风险规则(预算/报销/申请/通用类),完善风险规则模拟与反馈发布机制
- 引入费用审批动态路由、平台风险分级、预审与风险阶段管理
- 预算中心列表化改造,优化票据夹仪表盘与数字员工工作看板
- 新增 Hermes 风险线索收集器、Agent 链路追踪中心
- 扩展数字员工能力库(18 个领域 Skill)与交通费用自动预估
- 完善报销申请快速预览、权限控制与前端测试覆盖
2026-06-01 17:07:14 +08:00

531 lines
20 KiB
Python

from __future__ import annotations

import math
from datetime import UTC, date, datetime
from decimal import Decimal
from typing import Any

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from app.core.logging import get_logger
from app.db.base import Base
from app.models.agent_conversation import AgentConversationMessage
from app.models.agent_run import AgentRun, AgentToolCall, AgentTraceEvent, SemanticParseLog
from app.schemas.agent_run import AgentRunRead, AgentToolCallRead, SemanticParseRead
from app.schemas.agent_trace import (
    AgentConversationTraceRead,
    AgentTraceDetailRead,
    AgentTraceEventRead,
    AgentTraceListItem,
)
from app.schemas.orchestrator import ConversationMessageRead
# Module logger; AgentTraceService.record_event_safe uses it to report trace
# writes that failed and were rolled back.
logger = get_logger("app.services.agent_traces")
class AgentTraceService:
def __init__(self, db: Session) -> None:
self.db = db
def ensure_storage_ready(self) -> None:
Base.metadata.create_all(bind=self.db.get_bind(), tables=[AgentTraceEvent.__table__])
def record_event(
self,
*,
run_id: str,
stage: str,
event_name: str,
title: str,
status: str = "succeeded",
conversation_id: str | None = None,
summary: str | None = None,
input_json: dict[str, Any] | None = None,
output_json: dict[str, Any] | None = None,
error_message: str | None = None,
started_at: datetime | None = None,
finished_at: datetime | None = None,
duration_ms: int | None = None,
) -> AgentTraceEventRead:
self.ensure_storage_ready()
started = _normalize_datetime(started_at) or datetime.now(UTC)
finished = _normalize_datetime(finished_at)
if finished is None and status != "running":
finished = started
event = AgentTraceEvent(
run_id=str(run_id or "").strip(),
conversation_id=_optional_text(conversation_id),
sequence=self._next_sequence(run_id),
stage=str(stage or "orchestrator").strip() or "orchestrator",
event_name=str(event_name or "").strip() or "event",
title=str(title or event_name or "Trace event").strip(),
summary=_optional_text(summary),
status=str(status or "succeeded").strip() or "succeeded",
input_json=_json_safe(input_json or {}),
output_json=_json_safe(output_json or {}),
error_message=_optional_text(error_message),
started_at=started,
finished_at=finished,
duration_ms=_resolve_duration_ms(started, finished, duration_ms),
)
self.db.add(event)
self.db.commit()
self.db.refresh(event)
return AgentTraceEventRead.model_validate(event)
def record_event_safe(self, **kwargs: Any) -> AgentTraceEventRead | None:
try:
return self.record_event(**kwargs)
except Exception:
self.db.rollback()
logger.exception("Failed to record agent trace event run_id=%s", kwargs.get("run_id"))
return None
def record_tool_event_safe(
self,
run_id: str,
tool_type: str,
tool_name: str,
request_json: dict[str, Any],
response_json: dict[str, Any],
status: str,
duration_ms: int,
context_json: dict[str, Any],
error_message: str | None = None,
) -> AgentTraceEventRead | None:
return self.record_event_safe(
run_id=run_id,
conversation_id=str(context_json.get("conversation_id") or "").strip() or None,
stage="tool",
event_name="tool_invoked",
title=tool_name,
status=status,
summary=f"{tool_type} / {status}",
input_json=request_json,
output_json=response_json,
error_message=error_message,
duration_ms=duration_ms,
)
def list_traces(
self,
*,
agent: str | None = None,
status: str | None = None,
source: str | None = None,
conversation_id: str | None = None,
keyword: str | None = None,
limit: int = 30,
) -> list[AgentTraceListItem]:
self.ensure_storage_ready()
normalized_limit = max(1, min(int(limit or 30), 100))
fetch_limit = normalized_limit * 4 if keyword else normalized_limit
stmt = select(AgentRun)
if agent:
stmt = stmt.where(AgentRun.agent == agent)
if status:
stmt = stmt.where(AgentRun.status == status)
if source:
stmt = stmt.where(AgentRun.source == source)
if conversation_id:
run_ids = self._conversation_run_ids(conversation_id)
if not run_ids:
return []
stmt = stmt.where(AgentRun.run_id.in_(run_ids))
stmt = stmt.order_by(AgentRun.started_at.desc()).limit(fetch_limit)
runs = list(self.db.scalars(stmt).all())
event_counts = self._event_counts([run.run_id for run in runs])
keyword_text = str(keyword or "").strip().lower()
items = [self._build_trace_list_item(run, event_counts.get(run.run_id, 0)) for run in runs]
if keyword_text:
items = [item for item in items if self._matches_keyword(item, keyword_text)]
return items[:normalized_limit]
def get_trace(self, run_id: str) -> AgentTraceDetailRead | None:
self.ensure_storage_ready()
normalized_run_id = str(run_id or "").strip()
if not normalized_run_id:
return None
run = self.db.scalar(select(AgentRun).where(AgentRun.run_id == normalized_run_id))
if run is None:
return None
db_events = list(
self.db.scalars(
select(AgentTraceEvent)
.where(AgentTraceEvent.run_id == normalized_run_id)
.order_by(AgentTraceEvent.sequence.asc(), AgentTraceEvent.started_at.asc())
).all()
)
conversation_id = self._resolve_conversation_id(run, db_events)
events = [AgentTraceEventRead.model_validate(event) for event in db_events]
fallback_generated = False
if not events:
events = self._build_fallback_events(run, conversation_id)
fallback_generated = True
return AgentTraceDetailRead(
run=self._serialize_run(run),
conversation_id=conversation_id,
events=events,
semantic_parse=self._serialize_semantic_parse(self._first_semantic_parse(run)),
tool_calls=[AgentToolCallRead.model_validate(item) for item in run.tool_calls],
conversation_messages=self._conversation_messages(conversation_id),
fallback_generated=fallback_generated,
)
def get_conversation_trace(self, conversation_id: str) -> AgentConversationTraceRead:
normalized_conversation_id = str(conversation_id or "").strip()
run_ids = self._conversation_run_ids(normalized_conversation_id)
details = []
for run_id in run_ids:
detail = self.get_trace(run_id)
if detail is not None:
details.append(detail)
return AgentConversationTraceRead(
conversation_id=normalized_conversation_id,
runs=details,
)
def _next_sequence(self, run_id: str) -> int:
current = self.db.scalar(
select(func.max(AgentTraceEvent.sequence)).where(AgentTraceEvent.run_id == run_id)
)
return int(current or 0) + 1
def _event_counts(self, run_ids: list[str]) -> dict[str, int]:
if not run_ids:
return {}
rows = self.db.execute(
select(AgentTraceEvent.run_id, func.count(AgentTraceEvent.id))
.where(AgentTraceEvent.run_id.in_(run_ids))
.group_by(AgentTraceEvent.run_id)
).all()
return {str(run_id): int(count or 0) for run_id, count in rows}
def _build_trace_list_item(self, run: AgentRun, event_count: int) -> AgentTraceListItem:
semantic_parse = self._first_semantic_parse(run)
failed_tools = sum(1 for item in run.tool_calls if item.status == "failed")
title = self._resolve_run_title(run, semantic_parse)
finished_at = _normalize_datetime(run.finished_at)
started_at = _normalize_datetime(run.started_at) or datetime.now(UTC)
return AgentTraceListItem(
run_id=run.run_id,
conversation_id=self._resolve_conversation_id(run, []),
agent=run.agent,
source=run.source,
status=run.status,
scenario=semantic_parse.scenario if semantic_parse is not None else None,
intent=semantic_parse.intent if semantic_parse is not None else None,
title=title,
summary=run.result_summary,
event_count=event_count,
tool_call_count=len(run.tool_calls),
failed_tool_call_count=failed_tools,
started_at=started_at,
finished_at=finished_at,
duration_ms=_resolve_duration_ms(started_at, finished_at, None),
)
@staticmethod
def _matches_keyword(item: AgentTraceListItem, keyword: str) -> bool:
corpus = " ".join(
str(value or "")
for value in (
item.run_id,
item.conversation_id,
item.agent,
item.source,
item.status,
item.scenario,
item.intent,
item.title,
item.summary,
)
).lower()
return keyword in corpus
def _resolve_conversation_id(
self,
run: AgentRun,
events: list[AgentTraceEvent],
) -> str | None:
route_value = (run.route_json or {}).get("conversation_id")
if route_value:
return str(route_value).strip() or None
for event in events:
if event.conversation_id:
return str(event.conversation_id).strip() or None
message = self.db.scalar(
select(AgentConversationMessage)
.where(AgentConversationMessage.run_id == run.run_id)
.order_by(AgentConversationMessage.created_at.asc())
)
return str(message.conversation_id).strip() if message is not None else None
def _conversation_run_ids(self, conversation_id: str) -> list[str]:
normalized = str(conversation_id or "").strip()
if not normalized:
return []
self.ensure_storage_ready()
run_ids: list[str] = []
seen: set[str] = set()
def append_run_id(value: str | None) -> None:
run_id = str(value or "").strip()
if run_id and run_id not in seen:
seen.add(run_id)
run_ids.append(run_id)
messages = list(
self.db.scalars(
select(AgentConversationMessage)
.where(AgentConversationMessage.conversation_id == normalized)
.order_by(AgentConversationMessage.created_at.asc())
).all()
)
for message in messages:
append_run_id(message.run_id)
trace_event_run_ids = list(
self.db.scalars(
select(AgentTraceEvent.run_id)
.where(AgentTraceEvent.conversation_id == normalized)
.order_by(AgentTraceEvent.created_at.asc(), AgentTraceEvent.sequence.asc())
).all()
)
for run_id in trace_event_run_ids:
append_run_id(run_id)
recent_runs = list(
self.db.scalars(
select(AgentRun).order_by(AgentRun.started_at.desc()).limit(500)
).all()
)
for run in reversed(recent_runs):
if str((run.route_json or {}).get("conversation_id") or "").strip() == normalized:
append_run_id(run.run_id)
return run_ids
def _conversation_messages(self, conversation_id: str | None) -> list[ConversationMessageRead]:
if not conversation_id:
return []
messages = list(
self.db.scalars(
select(AgentConversationMessage)
.where(AgentConversationMessage.conversation_id == conversation_id)
.order_by(AgentConversationMessage.created_at.asc())
.limit(100)
).all()
)
return [
ConversationMessageRead(
id=item.id,
role=item.role,
content=item.content,
run_id=item.run_id,
message_json=item.message_json or {},
created_at=item.created_at,
)
for item in messages
]
def _build_fallback_events(
self,
run: AgentRun,
conversation_id: str | None,
) -> list[AgentTraceEventRead]:
events: list[AgentTraceEventRead] = []
started_at = _normalize_datetime(run.started_at) or datetime.now(UTC)
semantic_parse = self._first_semantic_parse(run)
def append_event(
*,
stage: str,
event_name: str,
title: str,
status: str,
summary: str | None,
started: datetime,
finished: datetime | None = None,
input_json: dict[str, Any] | None = None,
output_json: dict[str, Any] | None = None,
error_message: str | None = None,
) -> None:
sequence = len(events) + 1
resolved_finished = finished or started
events.append(
AgentTraceEventRead(
id=f"fallback-{run.run_id}-{sequence}",
run_id=run.run_id,
conversation_id=conversation_id,
sequence=sequence,
stage=stage,
event_name=event_name,
title=title,
summary=summary,
status=status,
input_json=_json_safe(input_json or {}),
output_json=_json_safe(output_json or {}),
error_message=error_message,
started_at=started,
finished_at=resolved_finished,
duration_ms=_resolve_duration_ms(started, resolved_finished, None),
created_at=started,
)
)
append_event(
stage="orchestrator",
event_name="run_created",
title="运行记录",
status="succeeded",
summary="由历史 AgentRun 合成的 trace 起点。",
started=started_at,
output_json={"agent": run.agent, "source": run.source, "status": run.status},
)
if semantic_parse is not None:
append_event(
stage="semantic",
event_name="semantic_parsed",
title="语义解析",
status="succeeded",
summary=f"{semantic_parse.scenario} / {semantic_parse.intent}",
started=_normalize_datetime(semantic_parse.created_at) or started_at,
input_json={"raw_query": semantic_parse.raw_query},
output_json=self._semantic_parse_payload(semantic_parse),
)
if run.route_json:
append_event(
stage="route",
event_name="route_resolved",
title="路由上下文",
status="succeeded",
summary=str(run.route_json.get("route_reason") or run.route_json.get("stage") or "已记录路由信息"),
started=started_at,
output_json=run.route_json,
)
for tool_call in run.tool_calls:
append_event(
stage="tool",
event_name="tool_invoked",
title=tool_call.tool_name,
status=tool_call.status,
summary=f"{tool_call.tool_type} / {tool_call.status}",
started=_normalize_datetime(tool_call.created_at) or started_at,
finished=_normalize_datetime(tool_call.created_at) or started_at,
input_json=tool_call.request_json,
output_json=tool_call.response_json,
error_message=tool_call.error_message,
)
append_event(
stage="response",
event_name="response_built" if run.status != "failed" else "failed",
title="最终结果",
status=run.status,
summary=run.result_summary or run.error_message,
started=_normalize_datetime(run.finished_at) or started_at,
output_json={"result_summary": run.result_summary},
error_message=run.error_message,
)
return events
@staticmethod
def _resolve_run_title(run: AgentRun, semantic_parse: SemanticParseLog | None) -> str:
if semantic_parse is not None:
return f"{semantic_parse.scenario} / {semantic_parse.intent}"
route_json = run.route_json or {}
return str(route_json.get("task_name") or route_json.get("selected_agent") or run.agent)
@staticmethod
def _first_semantic_parse(run: AgentRun) -> SemanticParseLog | None:
return run.semantic_parse_logs[0] if run.semantic_parse_logs else None
@staticmethod
def _serialize_semantic_parse(item: SemanticParseLog | None) -> SemanticParseRead | None:
return SemanticParseRead.model_validate(item) if item is not None else None
@staticmethod
def _serialize_run(run: AgentRun) -> AgentRunRead:
semantic_parse = AgentTraceService._first_semantic_parse(run)
return AgentRunRead(
id=run.id,
run_id=run.run_id,
agent=run.agent,
source=run.source,
user_id=run.user_id,
task_id=run.task_id,
ontology_json=run.ontology_json,
route_json=run.route_json,
permission_level=run.permission_level,
status=run.status,
result_summary=run.result_summary,
error_message=run.error_message,
started_at=run.started_at,
finished_at=run.finished_at,
tool_calls=[AgentToolCallRead.model_validate(item) for item in run.tool_calls],
semantic_parse=SemanticParseRead.model_validate(semantic_parse)
if semantic_parse is not None
else None,
)
@staticmethod
def _semantic_parse_payload(item: SemanticParseLog) -> dict[str, Any]:
return {
"scenario": item.scenario,
"intent": item.intent,
"confidence": item.confidence,
"entities": item.entities_json,
"time_range": item.time_range_json,
"metrics": item.metrics_json,
"constraints": item.constraints_json,
"risk_flags": item.risk_flags_json,
"permission": item.permission_json,
}
def _resolve_duration_ms(
started_at: datetime | None,
finished_at: datetime | None,
duration_ms: int | None,
) -> int:
if duration_ms is not None:
return max(0, int(duration_ms or 0))
if started_at is None or finished_at is None:
return 0
try:
return max(0, int((finished_at - started_at).total_seconds() * 1000))
except TypeError:
return 0
def _normalize_datetime(value: datetime | None) -> datetime | None:
if value is None:
return None
if value.tzinfo is None:
return value.replace(tzinfo=UTC)
return value
def _optional_text(value: Any) -> str | None:
text = str(value or "").strip()
return text or None
def _json_safe(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, Decimal):
return str(value)
if isinstance(value, (datetime, date)):
return value.isoformat()
if isinstance(value, dict):
return {str(key): _json_safe(item) for key, item in value.items()}
if isinstance(value, (list, tuple, set)):
return [_json_safe(item) for item in value]
if hasattr(value, "model_dump"):
return _json_safe(value.model_dump())
return str(value)