feat(agents): Phase 6 tool system refactoring
Phase 6.1: ToolRegistry infrastructure - Add ToolManifest with ToolCategory, PermissionClass, SideEffectScope - Add ToolRegistry singleton with register/get/unregister/list/search - Add BaseTool abstract class with ReadTool/WriteTool/DBWriteTool/ExternalTool/NetworkTool subclasses - Add migration layer for backward compatibility Phase 6.2: Hook interception system - Add HookType (PRE_TOOL_USE, POST_TOOL_USE, TOOL_ERROR, TOOL_SKIP) - Add HookManager with singleton for hook registration - Add HookExecutor for pre/post/error hook execution Phase 6.3: Streaming execution - Add StreamingToolExecutor with batch execution support Phase 6.4: New builtin tools - Add file_tools: GlobTool, GrepTool, ReadFileTool, WriteFileTool - Add system_tools: BashTool, PowerShellTool - Add dev_tools: LSPTools, GitTool - Add collaboration_tools: TeamAgentTool, TaskBroadcastTool Tests: 29 passed
This commit is contained in:
210
backend/app/agents/tools/streaming.py
Normal file
210
backend/app/agents/tools/streaming.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""流式工具执行器 - Phase 6.3
|
||||
|
||||
支持流式输出的工具执行器。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from app.agents.tools.hooks.executor import get_hook_executor
|
||||
from app.agents.tools.hooks.types import ExecutionContext
|
||||
from app.agents.tools.registry import get_tool_registry
|
||||
|
||||
|
||||
class StreamingToolExecutor:
|
||||
"""流式工具执行器
|
||||
|
||||
支持:
|
||||
- 普通工具的同步/异步执行
|
||||
- 流式工具的流式输出
|
||||
- Hook 拦截(pre/post/error)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._registry = get_tool_registry()
|
||||
self._hook_executor = get_hook_executor()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
) -> Any:
|
||||
"""执行工具(非流式)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_input: 工具输入参数
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 获取工具和执行器
|
||||
manifest = self._registry.get(tool_name)
|
||||
if manifest is None:
|
||||
raise ValueError(f"Tool not found: {tool_name}")
|
||||
|
||||
executor = self._registry.get_executor(tool_name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||
|
||||
# 检查是否跳过
|
||||
if await self._hook_executor.execute_skip_check(context):
|
||||
return {"skipped": True, "tool": tool_name}
|
||||
|
||||
# 执行 pre-hooks
|
||||
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||
if not continue_execution:
|
||||
return {"pre_hook_aborted": True, "tool": tool_name}
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
context.start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 判断是同步还是异步执行
|
||||
if asyncio.iscoroutinefunction(executor):
|
||||
result = await executor(**modified_input)
|
||||
else:
|
||||
# 同步函数在线程池中执行
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(None, lambda: executor(**modified_input))
|
||||
|
||||
context.result = result
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 执行 post-hooks
|
||||
result = await self._hook_executor.execute_post_hooks(context, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
context.error = e
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 执行 error-hooks
|
||||
error_result = await self._hook_executor.execute_error_hooks(context, e)
|
||||
if error_result is not None:
|
||||
return error_result
|
||||
|
||||
raise
|
||||
|
||||
async def execute_streaming(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""执行流式工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_input: 工具输入参数
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Yields:
|
||||
流式输出片段
|
||||
"""
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 获取工具和执行器
|
||||
manifest = self._registry.get(tool_name)
|
||||
if manifest is None:
|
||||
raise ValueError(f"Tool not found: {tool_name}")
|
||||
|
||||
if not manifest.is_streaming:
|
||||
raise ValueError(f"Tool is not streaming: {tool_name}")
|
||||
|
||||
executor = self._registry.get_executor(tool_name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||
|
||||
# 检查是否跳过
|
||||
if await self._hook_executor.execute_skip_check(context):
|
||||
yield {"type": "skipped", "tool": tool_name}
|
||||
return
|
||||
|
||||
# 执行 pre-hooks
|
||||
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||
if not continue_execution:
|
||||
yield {"type": "pre_hook_aborted", "tool": tool_name}
|
||||
return
|
||||
|
||||
# 执行流式工具
|
||||
try:
|
||||
context.start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 调用执行器(应该返回 AsyncGenerator)
|
||||
result = executor(**modified_input)
|
||||
|
||||
# 如果是 async generator
|
||||
if asyncio.isasyncgen(result):
|
||||
async for chunk in result:
|
||||
yield {"type": "chunk", "data": chunk}
|
||||
else:
|
||||
# 普通协程
|
||||
data = await result
|
||||
yield {"type": "chunk", "data": data}
|
||||
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
yield {"type": "done", "tool": tool_name}
|
||||
|
||||
except Exception as e:
|
||||
context.error = e
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
yield {"type": "error", "error": str(e), "tool": tool_name}
|
||||
|
||||
# 执行 error-hooks
|
||||
await self._hook_executor.execute_error_hooks(context, e)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
user_id: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""批量执行工具
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Returns:
|
||||
执行结果列表
|
||||
"""
|
||||
tasks = []
|
||||
for call in tool_calls:
|
||||
tool_name = call.get("tool_name") or call.get("name")
|
||||
tool_input = call.get("tool_input") or call.get("input") or {}
|
||||
tasks.append(self.execute(tool_name, tool_input, user_id))
|
||||
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_executor: StreamingToolExecutor | None = None
|
||||
|
||||
|
||||
def get_streaming_executor() -> StreamingToolExecutor:
|
||||
"""获取全局流式执行器
|
||||
|
||||
Returns:
|
||||
全局 StreamingToolExecutor 实例
|
||||
"""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = StreamingToolExecutor()
|
||||
return _executor
|
||||
Reference in New Issue
Block a user