# Phase 6.1: ToolRegistry infrastructure
#   - Add ToolManifest with ToolCategory, PermissionClass, SideEffectScope
#   - Add ToolRegistry singleton with register/get/unregister/list/search
#   - Add BaseTool abstract class with ReadTool/WriteTool/DBWriteTool/ExternalTool/NetworkTool subclasses
#   - Add migration layer for backward compatibility
# Phase 6.2: Hook interception system
#   - Add HookType (PRE_TOOL_USE, POST_TOOL_USE, TOOL_ERROR, TOOL_SKIP)
#   - Add HookManager with singleton for hook registration
#   - Add HookExecutor for pre/post/error hook execution
# Phase 6.3: Streaming execution
#   - Add StreamingToolExecutor with batch execution support
# Phase 6.4: New builtin tools
#   - Add file_tools: GlobTool, GrepTool, ReadFileTool, WriteFileTool
#   - Add system_tools: BashTool, PowerShellTool
#   - Add dev_tools: LSPTools, GitTool
#   - Add collaboration_tools: TeamAgentTool, TaskBroadcastTool
# Tests: 29 passed
211 lines
6.3 KiB
Python
211 lines
6.3 KiB
Python
"""Streaming tool executor - Phase 6.3

Tool executor with support for streaming output.
"""
|
||
|
||
import asyncio
import inspect
import json
from typing import Any, AsyncGenerator

from app.agents.tools.hooks.executor import get_hook_executor
from app.agents.tools.hooks.types import ExecutionContext
from app.agents.tools.registry import get_tool_registry
|
||
|
||
|
||
class StreamingToolExecutor:
|
||
"""流式工具执行器
|
||
|
||
支持:
|
||
- 普通工具的同步/异步执行
|
||
- 流式工具的流式输出
|
||
- Hook 拦截(pre/post/error)
|
||
"""
|
||
|
||
def __init__(self):
|
||
self._registry = get_tool_registry()
|
||
self._hook_executor = get_hook_executor()
|
||
|
||
async def execute(
|
||
self,
|
||
tool_name: str,
|
||
tool_input: dict[str, Any],
|
||
user_id: str | None = None,
|
||
) -> Any:
|
||
"""执行工具(非流式)
|
||
|
||
Args:
|
||
tool_name: 工具名称
|
||
tool_input: 工具输入参数
|
||
user_id: 用户 ID(可选)
|
||
|
||
Returns:
|
||
工具执行结果
|
||
"""
|
||
# 创建执行上下文
|
||
context = ExecutionContext(
|
||
tool_name=tool_name,
|
||
tool_input=tool_input,
|
||
user_id=user_id,
|
||
)
|
||
|
||
# 获取工具和执行器
|
||
manifest = self._registry.get(tool_name)
|
||
if manifest is None:
|
||
raise ValueError(f"Tool not found: {tool_name}")
|
||
|
||
executor = self._registry.get_executor(tool_name)
|
||
if executor is None:
|
||
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||
|
||
# 检查是否跳过
|
||
if await self._hook_executor.execute_skip_check(context):
|
||
return {"skipped": True, "tool": tool_name}
|
||
|
||
# 执行 pre-hooks
|
||
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||
if not continue_execution:
|
||
return {"pre_hook_aborted": True, "tool": tool_name}
|
||
|
||
# 执行工具
|
||
try:
|
||
context.start_time = asyncio.get_event_loop().time()
|
||
|
||
# 判断是同步还是异步执行
|
||
if asyncio.iscoroutinefunction(executor):
|
||
result = await executor(**modified_input)
|
||
else:
|
||
# 同步函数在线程池中执行
|
||
loop = asyncio.get_event_loop()
|
||
result = await loop.run_in_executor(None, lambda: executor(**modified_input))
|
||
|
||
context.result = result
|
||
context.end_time = asyncio.get_event_loop().time()
|
||
|
||
# 执行 post-hooks
|
||
result = await self._hook_executor.execute_post_hooks(context, result)
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
context.error = e
|
||
context.end_time = asyncio.get_event_loop().time()
|
||
|
||
# 执行 error-hooks
|
||
error_result = await self._hook_executor.execute_error_hooks(context, e)
|
||
if error_result is not None:
|
||
return error_result
|
||
|
||
raise
|
||
|
||
async def execute_streaming(
|
||
self,
|
||
tool_name: str,
|
||
tool_input: dict[str, Any],
|
||
user_id: str | None = None,
|
||
) -> AsyncGenerator[dict[str, Any], None]:
|
||
"""执行流式工具
|
||
|
||
Args:
|
||
tool_name: 工具名称
|
||
tool_input: 工具输入参数
|
||
user_id: 用户 ID(可选)
|
||
|
||
Yields:
|
||
流式输出片段
|
||
"""
|
||
# 创建执行上下文
|
||
context = ExecutionContext(
|
||
tool_name=tool_name,
|
||
tool_input=tool_input,
|
||
user_id=user_id,
|
||
)
|
||
|
||
# 获取工具和执行器
|
||
manifest = self._registry.get(tool_name)
|
||
if manifest is None:
|
||
raise ValueError(f"Tool not found: {tool_name}")
|
||
|
||
if not manifest.is_streaming:
|
||
raise ValueError(f"Tool is not streaming: {tool_name}")
|
||
|
||
executor = self._registry.get_executor(tool_name)
|
||
if executor is None:
|
||
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||
|
||
# 检查是否跳过
|
||
if await self._hook_executor.execute_skip_check(context):
|
||
yield {"type": "skipped", "tool": tool_name}
|
||
return
|
||
|
||
# 执行 pre-hooks
|
||
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||
if not continue_execution:
|
||
yield {"type": "pre_hook_aborted", "tool": tool_name}
|
||
return
|
||
|
||
# 执行流式工具
|
||
try:
|
||
context.start_time = asyncio.get_event_loop().time()
|
||
|
||
# 调用执行器(应该返回 AsyncGenerator)
|
||
result = executor(**modified_input)
|
||
|
||
# 如果是 async generator
|
||
if asyncio.isasyncgen(result):
|
||
async for chunk in result:
|
||
yield {"type": "chunk", "data": chunk}
|
||
else:
|
||
# 普通协程
|
||
data = await result
|
||
yield {"type": "chunk", "data": data}
|
||
|
||
context.end_time = asyncio.get_event_loop().time()
|
||
yield {"type": "done", "tool": tool_name}
|
||
|
||
except Exception as e:
|
||
context.error = e
|
||
context.end_time = asyncio.get_event_loop().time()
|
||
|
||
yield {"type": "error", "error": str(e), "tool": tool_name}
|
||
|
||
# 执行 error-hooks
|
||
await self._hook_executor.execute_error_hooks(context, e)
|
||
|
||
async def execute_batch(
|
||
self,
|
||
tool_calls: list[dict[str, Any]],
|
||
user_id: str | None = None,
|
||
) -> list[Any]:
|
||
"""批量执行工具
|
||
|
||
Args:
|
||
tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input
|
||
user_id: 用户 ID(可选)
|
||
|
||
Returns:
|
||
执行结果列表
|
||
"""
|
||
tasks = []
|
||
for call in tool_calls:
|
||
tool_name = call.get("tool_name") or call.get("name")
|
||
tool_input = call.get("tool_input") or call.get("input") or {}
|
||
tasks.append(self.execute(tool_name, tool_input, user_id))
|
||
|
||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||
|
||
|
||
# Process-wide singleton, created lazily on first access.
_executor: StreamingToolExecutor | None = None


def get_streaming_executor() -> StreamingToolExecutor:
    """Return the shared ``StreamingToolExecutor``, creating it on first use.

    Returns:
        The module-level ``StreamingToolExecutor`` instance.
    """
    global _executor
    instance = _executor
    if instance is None:
        # First call: build the executor and remember it for later callers.
        instance = _executor = StreamingToolExecutor()
    return instance
|