"""流式工具执行器 - Phase 6.3 支持流式输出的工具执行器。 """ import asyncio import json from typing import Any, AsyncGenerator from app.agents.tools.hooks.executor import get_hook_executor from app.agents.tools.hooks.types import ExecutionContext from app.agents.tools.registry import get_tool_registry class StreamingToolExecutor: """流式工具执行器 支持: - 普通工具的同步/异步执行 - 流式工具的流式输出 - Hook 拦截(pre/post/error) """ def __init__(self): self._registry = get_tool_registry() self._hook_executor = get_hook_executor() async def execute( self, tool_name: str, tool_input: dict[str, Any], user_id: str | None = None, ) -> Any: """执行工具(非流式) Args: tool_name: 工具名称 tool_input: 工具输入参数 user_id: 用户 ID(可选) Returns: 工具执行结果 """ # 创建执行上下文 context = ExecutionContext( tool_name=tool_name, tool_input=tool_input, user_id=user_id, ) # 获取工具和执行器 manifest = self._registry.get(tool_name) if manifest is None: raise ValueError(f"Tool not found: {tool_name}") executor = self._registry.get_executor(tool_name) if executor is None: raise ValueError(f"Executor not found for tool: {tool_name}") # 检查是否跳过 if await self._hook_executor.execute_skip_check(context): return {"skipped": True, "tool": tool_name} # 执行 pre-hooks continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context) if not continue_execution: return {"pre_hook_aborted": True, "tool": tool_name} # 执行工具 try: context.start_time = asyncio.get_event_loop().time() # 判断是同步还是异步执行 if asyncio.iscoroutinefunction(executor): result = await executor(**modified_input) else: # 同步函数在线程池中执行 loop = asyncio.get_event_loop() result = await loop.run_in_executor(None, lambda: executor(**modified_input)) context.result = result context.end_time = asyncio.get_event_loop().time() # 执行 post-hooks result = await self._hook_executor.execute_post_hooks(context, result) return result except Exception as e: context.error = e context.end_time = asyncio.get_event_loop().time() # 执行 error-hooks error_result = await self._hook_executor.execute_error_hooks(context, e) if error_result is not None: return error_result raise async def execute_streaming( self, tool_name: str, tool_input: dict[str, Any], user_id: str | None = None, ) -> AsyncGenerator[dict[str, Any], None]: """执行流式工具 Args: tool_name: 工具名称 tool_input: 工具输入参数 user_id: 用户 ID(可选) Yields: 流式输出片段 """ # 创建执行上下文 context = ExecutionContext( tool_name=tool_name, tool_input=tool_input, user_id=user_id, ) # 获取工具和执行器 manifest = self._registry.get(tool_name) if manifest is None: raise ValueError(f"Tool not found: {tool_name}") if not manifest.is_streaming: raise ValueError(f"Tool is not streaming: {tool_name}") executor = self._registry.get_executor(tool_name) if executor is None: raise ValueError(f"Executor not found for tool: {tool_name}") # 检查是否跳过 if await self._hook_executor.execute_skip_check(context): yield {"type": "skipped", "tool": tool_name} return # 执行 pre-hooks continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context) if not continue_execution: yield {"type": "pre_hook_aborted", "tool": tool_name} return # 执行流式工具 try: context.start_time = asyncio.get_event_loop().time() # 调用执行器(应该返回 AsyncGenerator) result = executor(**modified_input) # 如果是 async generator if asyncio.isasyncgen(result): async for chunk in result: yield {"type": "chunk", "data": chunk} else: # 普通协程 data = await result yield {"type": "chunk", "data": data} context.end_time = asyncio.get_event_loop().time() yield {"type": "done", "tool": tool_name} except Exception as e: context.error = e context.end_time = asyncio.get_event_loop().time() yield {"type": "error", "error": str(e), "tool": tool_name} # 执行 error-hooks await self._hook_executor.execute_error_hooks(context, e) async def execute_batch( self, tool_calls: list[dict[str, Any]], user_id: str | None = None, ) -> list[Any]: """批量执行工具 Args: tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input user_id: 用户 ID(可选) Returns: 执行结果列表 """ tasks = [] for call in tool_calls: tool_name = call.get("tool_name") or call.get("name") tool_input = call.get("tool_input") or call.get("input") or {} tasks.append(self.execute(tool_name, tool_input, user_id)) return await asyncio.gather(*tasks, return_exceptions=True) # 全局单例 _executor: StreamingToolExecutor | None = None def get_streaming_executor() -> StreamingToolExecutor: """获取全局流式执行器 Returns: 全局 StreamingToolExecutor 实例 """ global _executor if _executor is None: _executor = StreamingToolExecutor() return _executor