from __future__ import annotations

from datetime import UTC, date, datetime
from decimal import Decimal
from typing import Any

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from app.core.logging import get_logger
from app.db.base import Base
from app.models.agent_conversation import AgentConversationMessage
from app.models.agent_run import AgentRun, AgentToolCall, AgentTraceEvent, SemanticParseLog
from app.schemas.agent_run import AgentRunRead, AgentToolCallRead, SemanticParseRead
from app.schemas.agent_trace import (
    AgentConversationTraceRead,
    AgentTraceDetailRead,
    AgentTraceEventRead,
    AgentTraceListItem,
)
from app.schemas.orchestrator import ConversationMessageRead

logger = get_logger("app.services.agent_traces")


class AgentTraceService:
    """Record and read back trace events attached to agent runs.

    Every method works through the SQLAlchemy ``Session`` supplied at
    construction time.
    """

    def __init__(self, db: Session) -> None:
        """Bind the service to *db*."""
        self.db = db
        # Set once ``ensure_storage_ready`` has succeeded for this instance.
        self._storage_ready = False

    def ensure_storage_ready(self) -> None:
        """Run ``create_all`` for the ``AgentTraceEvent`` table.

        The call is memoized per service instance: after the first successful
        invocation later calls return immediately, so read paths that invoke
        this repeatedly (e.g. ``get_conversation_trace``) do not re-issue it.
        """
        # ``getattr`` keeps this safe for instances built without ``__init__``.
        if getattr(self, "_storage_ready", False):
            return
        Base.metadata.create_all(bind=self.db.get_bind(), tables=[AgentTraceEvent.__table__])
        self._storage_ready = True

    def record_event(
        self,
        *,
        run_id: str,
        stage: str,
        event_name: str,
        title: str,
        status: str = "succeeded",
        conversation_id: str | None = None,
        summary: str | None = None,
        input_json: dict[str, Any] | None = None,
        output_json: dict[str, Any] | None = None,
        error_message: str | None = None,
        started_at: datetime | None = None,
        finished_at: datetime | None = None,
        duration_ms: int | None = None,
    ) -> AgentTraceEventRead:
        """Persist one trace event and return its read model.

        Text fields are stripped and given defaults when blank
        (``stage`` -> ``"orchestrator"``, ``event_name`` -> ``"event"``,
        ``status`` -> ``"succeeded"``). ``started_at`` defaults to the current
        UTC time; ``finished_at`` defaults to the start time unless ``status``
        is ``"running"``. The session is committed before returning.

        Exceptions raised by the session propagate to the caller; use
        ``record_event_safe`` for best-effort recording.
        """
        self.ensure_storage_ready()
        # Normalize once so the stored run_id and the sequence lookup agree.
        normalized_run_id = str(run_id or "").strip()
        started = _normalize_datetime(started_at) or datetime.now(UTC)
        finished = _normalize_datetime(finished_at)
        if finished is None and status != "running":
            finished = started
        event = AgentTraceEvent(
            run_id=normalized_run_id,
            conversation_id=_optional_text(conversation_id),
            sequence=self._next_sequence(normalized_run_id),
            stage=str(stage or "orchestrator").strip() or "orchestrator",
            event_name=str(event_name or "").strip() or "event",
            title=str(title or event_name or "Trace event").strip(),
            summary=_optional_text(summary),
            status=str(status or "succeeded").strip() or "succeeded",
            input_json=_json_safe(input_json or {}),
            output_json=_json_safe(output_json or {}),
            error_message=_optional_text(error_message),
            started_at=started,
            finished_at=finished,
            duration_ms=_resolve_duration_ms(started, finished, duration_ms),
        )
        self.db.add(event)
        self.db.commit()
        self.db.refresh(event)
        return AgentTraceEventRead.model_validate(event)

    def record_event_safe(self, **kwargs: Any) -> AgentTraceEventRead | None:
        """Best-effort wrapper around ``record_event``.

        On any exception the session is rolled back, the failure is logged,
        and ``None`` is returned instead of raising.
        """
        try:
            return self.record_event(**kwargs)
        except Exception:
            # Deliberate boundary: tracing must not break the caller.
            self.db.rollback()
            logger.exception("Failed to record agent trace event run_id=%s", kwargs.get("run_id"))
            return None

    def record_tool_event_safe(
        self,
        run_id: str,
        tool_type: str,
        tool_name: str,
        request_json: dict[str, Any],
        response_json: dict[str, Any],
        status: str,
        duration_ms: int,
        context_json: dict[str, Any],
        error_message: str | None = None,
    ) -> AgentTraceEventRead | None:
        """Record a ``tool_invoked`` event in the ``tool`` stage, best-effort.

        The conversation id is read from ``context_json["conversation_id"]``.
        A missing or non-dict ``context_json`` is treated as empty so that this
        method keeps its "never raises" contract.
        """
        context = context_json if isinstance(context_json, dict) else {}
        return self.record_event_safe(
            run_id=run_id,
            conversation_id=_optional_text(context.get("conversation_id")),
            stage="tool",
            event_name="tool_invoked",
            title=tool_name,
            status=status,
            summary=f"{tool_type} / {status}",
            input_json=request_json,
            output_json=response_json,
            error_message=error_message,
            duration_ms=duration_ms,
        )

    def list_traces(
        self,
        *,
        agent: str | None = None,
        status: str | None = None,
        source: str | None = None,
        conversation_id: str | None = None,
        keyword: str | None = None,
        limit: int = 30,
    ) -> list[AgentTraceListItem]:
        """Return recent runs as trace list items, newest first.

        ``limit`` is clamped to the range 1..100. ``agent``, ``status`` and
        ``source`` are applied as equality filters in SQL; ``conversation_id``
        restricts the query to the run ids resolved for that conversation.
        ``keyword`` is matched case-insensitively in Python against the built
        items, so up to four times ``limit`` rows are fetched when a keyword
        is present; matches beyond that window are not found.
        """
        self.ensure_storage_ready()
        normalized_limit = max(1, min(int(limit or 30), 100))
        keyword_text = str(keyword or "").strip().lower()
        # Over-fetch only when a real (non-blank) keyword will filter rows.
        fetch_limit = normalized_limit * 4 if keyword_text else normalized_limit

        stmt = select(AgentRun)
        if agent:
            stmt = stmt.where(AgentRun.agent == agent)
        if status:
            stmt = stmt.where(AgentRun.status == status)
        if source:
            stmt = stmt.where(AgentRun.source == source)
        if conversation_id:
            run_ids = self._conversation_run_ids(conversation_id)
            if not run_ids:
                return []
            stmt = stmt.where(AgentRun.run_id.in_(run_ids))
        stmt = stmt.order_by(AgentRun.started_at.desc()).limit(fetch_limit)

        runs = list(self.db.scalars(stmt).all())
        event_counts = self._event_counts([run.run_id for run in runs])
        items = [
            self._build_trace_list_item(run, event_counts.get(run.run_id, 0))
            for run in runs
        ]
        if keyword_text:
            items = [item for item in items if self._matches_keyword(item, keyword_text)]
        return items[:normalized_limit]

    def get_trace(self, run_id: str) -> AgentTraceDetailRead | None:
        """Return the full trace for *run_id*, or ``None`` if it is unknown.

        When the run has no stored trace events, events are synthesized from
        the run record and ``fallback_generated`` is set on the result.
        """
        self.ensure_storage_ready()
        normalized_run_id = str(run_id or "").strip()
        if not normalized_run_id:
            return None
        run = self.db.scalar(select(AgentRun).where(AgentRun.run_id == normalized_run_id))
        if run is None:
            return None

        db_events = list(
            self.db.scalars(
                select(AgentTraceEvent)
                .where(AgentTraceEvent.run_id == normalized_run_id)
                .order_by(AgentTraceEvent.sequence.asc(), AgentTraceEvent.started_at.asc())
            ).all()
        )
        conversation_id = self._resolve_conversation_id(run, db_events)
        events = [AgentTraceEventRead.model_validate(event) for event in db_events]
        fallback_generated = False
        if not events:
            events = self._build_fallback_events(run, conversation_id)
            fallback_generated = True

        return AgentTraceDetailRead(
            run=self._serialize_run(run),
            conversation_id=conversation_id,
            events=events,
            semantic_parse=self._serialize_semantic_parse(self._first_semantic_parse(run)),
            tool_calls=[AgentToolCallRead.model_validate(item) for item in run.tool_calls],
            conversation_messages=self._conversation_messages(conversation_id),
            fallback_generated=fallback_generated,
        )

    def get_conversation_trace(self, conversation_id: str) -> AgentConversationTraceRead:
        """Return the traces of every run resolved for *conversation_id*.

        Run ids that no longer resolve to a run are skipped.
        """
        normalized_conversation_id = str(conversation_id or "").strip()
        run_ids = self._conversation_run_ids(normalized_conversation_id)
        details = [
            detail
            for run_id in run_ids
            if (detail := self.get_trace(run_id)) is not None
        ]
        return AgentConversationTraceRead(
            conversation_id=normalized_conversation_id,
            runs=details,
        )

    def _next_sequence(self, run_id: str) -> int:
        """Return one more than the highest stored sequence for *run_id* (1 if none)."""
        current = self.db.scalar(
            select(func.max(AgentTraceEvent.sequence)).where(AgentTraceEvent.run_id == run_id)
        )
        return int(current or 0) + 1

    def _event_counts(self, run_ids: list[str]) -> dict[str, int]:
        """Return the stored event count per run id; runs without events are absent."""
        if not run_ids:
            return {}
        rows = self.db.execute(
            select(AgentTraceEvent.run_id, func.count(AgentTraceEvent.id))
            .where(AgentTraceEvent.run_id.in_(run_ids))
            .group_by(AgentTraceEvent.run_id)
        ).all()
        return {str(run_id): int(count or 0) for run_id, count in rows}

    def _build_trace_list_item(self, run: AgentRun, event_count: int) -> AgentTraceListItem:
        """Build the list-view summary for *run*.

        Scenario and intent come from the run's first semantic parse log, if
        any. A missing ``started_at`` falls back to the current UTC time.
        """
        semantic_parse = self._first_semantic_parse(run)
        failed_tools = sum(1 for item in run.tool_calls if item.status == "failed")
        title = self._resolve_run_title(run, semantic_parse)
        finished_at = _normalize_datetime(run.finished_at)
        started_at = _normalize_datetime(run.started_at) or datetime.now(UTC)
        return AgentTraceListItem(
            run_id=run.run_id,
            conversation_id=self._resolve_conversation_id(run, []),
            agent=run.agent,
            source=run.source,
            status=run.status,
            scenario=semantic_parse.scenario if semantic_parse is not None else None,
            intent=semantic_parse.intent if semantic_parse is not None else None,
            title=title,
            summary=run.result_summary,
            event_count=event_count,
            tool_call_count=len(run.tool_calls),
            failed_tool_call_count=failed_tools,
            started_at=started_at,
            finished_at=finished_at,
            duration_ms=_resolve_duration_ms(started_at, finished_at, None),
        )

    @staticmethod
    def _matches_keyword(item: AgentTraceListItem, keyword: str) -> bool:
        """Return True if *keyword* occurs in the item's lower-cased text fields.

        *keyword* is expected to be lower-cased already by the caller.
        """
        corpus = " ".join(
            str(value or "")
            for value in (
                item.run_id,
                item.conversation_id,
                item.agent,
                item.source,
                item.status,
                item.scenario,
                item.intent,
                item.title,
                item.summary,
            )
        ).lower()
        return keyword in corpus

    def _resolve_conversation_id(
        self,
        run: AgentRun,
        events: list[AgentTraceEvent],
    ) -> str | None:
        """Find the conversation id for *run*, or ``None`` if none is known.

        Sources are tried in order: ``run.route_json["conversation_id"]``, the
        first event in *events* carrying one, then the earliest conversation
        message recorded for the run. Blank values are skipped.
        """
        route_id = _optional_text((run.route_json or {}).get("conversation_id"))
        if route_id:
            return route_id
        for event in events:
            event_id = _optional_text(event.conversation_id)
            if event_id:
                return event_id
        # Only the earliest message is needed, so bound the query.
        message = self.db.scalar(
            select(AgentConversationMessage)
            .where(AgentConversationMessage.run_id == run.run_id)
            .order_by(AgentConversationMessage.created_at.asc())
            .limit(1)
        )
        return _optional_text(message.conversation_id) if message is not None else None

    def _conversation_run_ids(self, conversation_id: str) -> list[str]:
        """Collect the distinct run ids associated with *conversation_id*.

        Ids are gathered, in order and without duplicates, from conversation
        messages, then trace events, then the 500 most recently started runs
        whose ``route_json["conversation_id"]`` matches (oldest of those
        first). Older runs linked only through ``route_json`` are not found.
        """
        normalized = str(conversation_id or "").strip()
        if not normalized:
            return []
        self.ensure_storage_ready()
        run_ids: list[str] = []
        seen: set[str] = set()

        def append_run_id(value: str | None) -> None:
            run_id = str(value or "").strip()
            if run_id and run_id not in seen:
                seen.add(run_id)
                run_ids.append(run_id)

        messages = list(
            self.db.scalars(
                select(AgentConversationMessage)
                .where(AgentConversationMessage.conversation_id == normalized)
                .order_by(AgentConversationMessage.created_at.asc())
            ).all()
        )
        for message in messages:
            append_run_id(message.run_id)

        trace_event_run_ids = list(
            self.db.scalars(
                select(AgentTraceEvent.run_id)
                .where(AgentTraceEvent.conversation_id == normalized)
                .order_by(AgentTraceEvent.created_at.asc(), AgentTraceEvent.sequence.asc())
            ).all()
        )
        for run_id in trace_event_run_ids:
            append_run_id(run_id)

        recent_runs = list(
            self.db.scalars(
                select(AgentRun).order_by(AgentRun.started_at.desc()).limit(500)
            ).all()
        )
        for run in reversed(recent_runs):
            if str((run.route_json or {}).get("conversation_id") or "").strip() == normalized:
                append_run_id(run.run_id)
        return run_ids

    def _conversation_messages(self, conversation_id: str | None) -> list[ConversationMessageRead]:
        """Return up to 100 messages of the conversation, oldest first."""
        if not conversation_id:
            return []
        messages = list(
            self.db.scalars(
                select(AgentConversationMessage)
                .where(AgentConversationMessage.conversation_id == conversation_id)
                .order_by(AgentConversationMessage.created_at.asc())
                .limit(100)
            ).all()
        )
        return [
            ConversationMessageRead(
                id=item.id,
                role=item.role,
                content=item.content,
                run_id=item.run_id,
                message_json=item.message_json or {},
                created_at=item.created_at,
            )
            for item in messages
        ]

    def _build_fallback_events(
        self,
        run: AgentRun,
        conversation_id: str | None,
    ) -> list[AgentTraceEventRead]:
        """Synthesize trace events for a run that has no stored events.

        Events are produced, in order, for: the run itself, its first semantic
        parse (if any), its ``route_json`` (if non-empty), each tool call, and
        the final result. Ids have the form ``fallback-<run_id>-<sequence>``
        and nothing is written to the database.
        """
        events: list[AgentTraceEventRead] = []
        started_at = _normalize_datetime(run.started_at) or datetime.now(UTC)
        semantic_parse = self._first_semantic_parse(run)

        def append_event(
            *,
            stage: str,
            event_name: str,
            title: str,
            status: str,
            summary: str | None,
            started: datetime,
            finished: datetime | None = None,
            input_json: dict[str, Any] | None = None,
            output_json: dict[str, Any] | None = None,
            error_message: str | None = None,
        ) -> None:
            sequence = len(events) + 1
            # Without an explicit finish time the event is treated as instantaneous.
            resolved_finished = finished or started
            events.append(
                AgentTraceEventRead(
                    id=f"fallback-{run.run_id}-{sequence}",
                    run_id=run.run_id,
                    conversation_id=conversation_id,
                    sequence=sequence,
                    stage=stage,
                    event_name=event_name,
                    title=title,
                    summary=summary,
                    status=status,
                    input_json=_json_safe(input_json or {}),
                    output_json=_json_safe(output_json or {}),
                    error_message=error_message,
                    started_at=started,
                    finished_at=resolved_finished,
                    duration_ms=_resolve_duration_ms(started, resolved_finished, None),
                    created_at=started,
                )
            )

        append_event(
            stage="orchestrator",
            event_name="run_created",
            title="运行记录",
            status="succeeded",
            summary="由历史 AgentRun 合成的 trace 起点。",
            started=started_at,
            output_json={"agent": run.agent, "source": run.source, "status": run.status},
        )
        if semantic_parse is not None:
            append_event(
                stage="semantic",
                event_name="semantic_parsed",
                title="语义解析",
                status="succeeded",
                summary=f"{semantic_parse.scenario} / {semantic_parse.intent}",
                started=_normalize_datetime(semantic_parse.created_at) or started_at,
                input_json={"raw_query": semantic_parse.raw_query},
                output_json=self._semantic_parse_payload(semantic_parse),
            )
        if run.route_json:
            append_event(
                stage="route",
                event_name="route_resolved",
                title="路由上下文",
                status="succeeded",
                summary=str(
                    run.route_json.get("route_reason")
                    or run.route_json.get("stage")
                    or "已记录路由信息"
                ),
                started=started_at,
                output_json=run.route_json,
            )
        for tool_call in run.tool_calls:
            # Only ``created_at`` is used for tool calls, so start and finish coincide.
            tool_time = _normalize_datetime(tool_call.created_at) or started_at
            append_event(
                stage="tool",
                event_name="tool_invoked",
                title=tool_call.tool_name,
                status=tool_call.status,
                summary=f"{tool_call.tool_type} / {tool_call.status}",
                started=tool_time,
                finished=tool_time,
                input_json=tool_call.request_json,
                output_json=tool_call.response_json,
                error_message=tool_call.error_message,
            )
        append_event(
            stage="response",
            event_name="response_built" if run.status != "failed" else "failed",
            title="最终结果",
            status=run.status,
            summary=run.result_summary or run.error_message,
            started=_normalize_datetime(run.finished_at) or started_at,
            output_json={"result_summary": run.result_summary},
            error_message=run.error_message,
        )
        return events

    @staticmethod
    def _resolve_run_title(run: AgentRun, semantic_parse: SemanticParseLog | None) -> str:
        """Return a display title for *run*.

        Uses ``"<scenario> / <intent>"`` when a semantic parse exists,
        otherwise ``route_json["task_name"]``, ``route_json["selected_agent"]``
        or the agent name, in that order.
        """
        if semantic_parse is not None:
            return f"{semantic_parse.scenario} / {semantic_parse.intent}"
        route_json = run.route_json or {}
        return str(route_json.get("task_name") or route_json.get("selected_agent") or run.agent)

    @staticmethod
    def _first_semantic_parse(run: AgentRun) -> SemanticParseLog | None:
        """Return the first entry of ``run.semantic_parse_logs``, or ``None`` if empty."""
        return run.semantic_parse_logs[0] if run.semantic_parse_logs else None

    @staticmethod
    def _serialize_semantic_parse(item: SemanticParseLog | None) -> SemanticParseRead | None:
        """Convert a semantic parse log to its read model, passing ``None`` through."""
        return SemanticParseRead.model_validate(item) if item is not None else None

    @staticmethod
    def _serialize_run(run: AgentRun) -> AgentRunRead:
        """Convert *run*, its tool calls and its first semantic parse to a read model."""
        semantic_parse = AgentTraceService._first_semantic_parse(run)
        return AgentRunRead(
            id=run.id,
            run_id=run.run_id,
            agent=run.agent,
            source=run.source,
            user_id=run.user_id,
            task_id=run.task_id,
            ontology_json=run.ontology_json,
            route_json=run.route_json,
            permission_level=run.permission_level,
            status=run.status,
            result_summary=run.result_summary,
            error_message=run.error_message,
            started_at=run.started_at,
            finished_at=run.finished_at,
            tool_calls=[AgentToolCallRead.model_validate(item) for item in run.tool_calls],
            semantic_parse=(
                SemanticParseRead.model_validate(semantic_parse)
                if semantic_parse is not None
                else None
            ),
        )

    @staticmethod
    def _semantic_parse_payload(item: SemanticParseLog) -> dict[str, Any]:
        """Return the semantic parse fields as a plain dict for event output."""
        return {
            "scenario": item.scenario,
            "intent": item.intent,
            "confidence": item.confidence,
            "entities": item.entities_json,
            "time_range": item.time_range_json,
            "metrics": item.metrics_json,
            "constraints": item.constraints_json,
            "risk_flags": item.risk_flags_json,
            "permission": item.permission_json,
        }


def _resolve_duration_ms(
    started_at: datetime | None,
    finished_at: datetime | None,
    duration_ms: int | None,
) -> int:
    """Return a non-negative duration in milliseconds.

    An explicit *duration_ms* wins (clamped at 0). Otherwise the duration is
    derived from the two timestamps; it is 0 when either is missing or when
    they cannot be subtracted (e.g. naive vs. aware datetimes).
    """
    if duration_ms is not None:
        return max(0, int(duration_ms or 0))
    if started_at is None or finished_at is None:
        return 0
    try:
        return max(0, int((finished_at - started_at).total_seconds() * 1000))
    except TypeError:
        return 0


def _normalize_datetime(value: datetime | None) -> datetime | None:
    """Return *value* with UTC attached if it is naive; aware values and ``None`` pass through."""
    if value is None:
        return None
    if value.tzinfo is None:
        return value.replace(tzinfo=UTC)
    return value


def _optional_text(value: Any) -> str | None:
    """Return the stripped string form of *value*, or ``None`` if falsy or blank."""
    text = str(value or "").strip()
    return text or None


def _json_safe(value: Any) -> Any:
    """Recursively convert *value* into JSON-serializable primitives.

    ``Decimal`` becomes a string, dates and datetimes become ISO-8601 strings,
    dict keys are stringified, lists/tuples/sets become lists, objects exposing
    ``model_dump`` are dumped and converted, and anything else falls back to
    ``str(value)``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, Decimal):
        return str(value)
    if isinstance(value, (datetime, date)):
        return value.isoformat()
    if isinstance(value, dict):
        return {str(key): _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [_json_safe(item) for item in value]
    if hasattr(value, "model_dump"):
        return _json_safe(value.model_dump())
    return str(value)