Phase 6.1: ToolRegistry infrastructure - Add ToolManifest with ToolCategory, PermissionClass, SideEffectScope - Add ToolRegistry singleton with register/get/unregister/list/search - Add BaseTool abstract class with ReadTool/WriteTool/DBWriteTool/ExternalTool/NetworkTool subclasses - Add migration layer for backward compatibility Phase 6.2: Hook interception system - Add HookType (PRE_TOOL_USE, POST_TOOL_USE, TOOL_ERROR, TOOL_SKIP) - Add HookManager with singleton for hook registration - Add HookExecutor for pre/post/error hook execution Phase 6.3: Streaming execution - Add StreamingToolExecutor with batch execution support Phase 6.4: New builtin tools - Add file_tools: GlobTool, GrepTool, ReadFileTool, WriteFileTool - Add system_tools: BashTool, PowerShellTool - Add dev_tools: LSPTools, GitTool - Add collaboration_tools: TeamAgentTool, TaskBroadcastTool Tests: 29 passed
171 lines
4.8 KiB
Python
171 lines
4.8 KiB
Python
"""Hook 执行器 - Phase 6.2
|
||
|
||
执行 Hook 拦截逻辑。
|
||
"""
|
||
|
||
import time
|
||
from typing import Any
|
||
|
||
from app.agents.tools.hooks.manager import get_hook_manager
|
||
from app.agents.tools.hooks.types import (
|
||
HookDefinition,
|
||
HookResult,
|
||
HookType,
|
||
ExecutionContext,
|
||
)
|
||
|
||
|
||
class HookExecutor:
    """Run registered hooks around a tool invocation.

    Hooks are fetched from the hook manager by hook type and tool name and
    are invoked in the order the manager returns them. A hook that raises
    never aborts the surrounding tool call: the failure is logged and the
    remaining hooks still run.
    """

    def __init__(self):
        """Bind the executor to the hook manager returned by get_hook_manager()."""
        self._manager = get_hook_manager()

    @staticmethod
    def _log_hook_failure(stage: str, context: Any) -> None:
        """Log the exception currently being handled for a failed hook.

        Must be called from inside an ``except`` block so the active
        traceback is attached to the log record.

        Args:
            stage: Short label of the hook phase (used only in the message).
            context: Execution context; only its ``tool_name`` is read.
        """
        # Function-scope import so this block does not depend on the
        # module's top-level import list.
        import logging

        logging.getLogger(__name__).exception(
            "%s hook raised for tool %r; ignoring this hook",
            stage,
            getattr(context, "tool_name", None),
        )

    async def execute_pre_hooks(
        self, context: ExecutionContext
    ) -> tuple[bool, dict[str, Any] | None]:
        """Run the PRE_TOOL_USE hooks for ``context.tool_name``.

        Every hook receives the same, unmodified ``context``; an input
        override produced by one hook is not fed into the next hook, the
        last non-None override simply wins.

        Args:
            context: Execution context of the pending tool call.

        Returns:
            ``(should_continue, tool_input)``. ``should_continue`` is False
            as soon as a hook result has a falsy ``continue_execution``; in
            that case ``tool_input`` holds whatever earlier hooks produced.
        """
        hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name)
        modified_input = context.tool_input

        for hook in hooks:
            try:
                handler = hook.handler
                if not callable(handler):
                    continue
                result = await self._call_hook(handler, context)
                if result is None:
                    # The hook expressed no opinion; nothing to inspect.
                    continue
                if not result.continue_execution:
                    # The hook vetoed the tool call.
                    return False, modified_input
                if result.modified_input is not None:
                    modified_input = result.modified_input
            except Exception:
                # Best effort: a broken hook must not block the tool call.
                self._log_hook_failure("pre-tool", context)

        return True, modified_input

    async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any:
        """Run the POST_TOOL_USE hooks for ``context.tool_name``.

        Hooks are chained: each one is called with the result as rewritten
        by the hooks before it.

        Args:
            context: Execution context of the finished tool call.
            result: Value returned by the tool.

        Returns:
            The tool result after all hook rewrites (the original value if
            no hook supplied a non-None ``modified_output``).
        """
        hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name)
        modified_result = result

        for hook in hooks:
            try:
                handler = hook.handler
                if not callable(handler):
                    continue
                hook_result = await self._call_hook(handler, context, modified_result)
                if hook_result and hook_result.modified_output is not None:
                    modified_result = hook_result.modified_output
            except Exception:
                # Best effort: keep the result accumulated so far.
                self._log_hook_failure("post-tool", context)

        return modified_result

    async def execute_error_hooks(
        self, context: ExecutionContext, error: Exception
    ) -> HookResult | None:
        """Run the TOOL_ERROR hooks for ``context.tool_name``.

        Args:
            context: Execution context of the failed tool call.
            error: Exception raised by the tool.

        Returns:
            The first hook result whose ``continue_execution`` is truthy, or
            None when no hook produced one (the caller is then expected to
            keep propagating the error).
        """
        hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name)

        for hook in hooks:
            try:
                handler = hook.handler
                if not callable(handler):
                    continue
                result = await self._call_hook(handler, context, error)
                if result is not None and result.continue_execution:
                    return result
            except Exception:
                # A failing error hook must not mask the remaining ones.
                self._log_hook_failure("error", context)

        return None

    async def execute_skip_check(self, context: ExecutionContext) -> bool:
        """Ask the TOOL_SKIP hooks whether the tool call should be skipped.

        Only hooks that return an actual ``bool`` are honoured; the first
        such answer is final and later hooks are not consulted.

        Args:
            context: Execution context of the pending tool call.

        Returns:
            True to skip the tool, False to execute it (also the default
            when no hook returns a bool).
        """
        hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name)

        for hook in hooks:
            try:
                handler = hook.handler
                if not callable(handler):
                    continue
                result = await self._call_hook(handler, context)
                if isinstance(result, bool):
                    return result
            except Exception:
                # Best effort: a broken skip hook means "do not skip".
                self._log_hook_failure("skip-check", context)

        return False

    async def _call_hook(
        self, handler: Any, context: ExecutionContext, *args: Any
    ) -> HookResult | None:
        """Invoke a hook handler, awaiting it when it is asynchronous.

        The handler is called first and its return value is inspected, so
        that sync callables which hand back an awaitable (partials of async
        functions, objects with an async ``__call__``, wrapping lambdas) are
        awaited too instead of leaking an un-awaited coroutine.

        Args:
            handler: Hook callable, sync or async.
            context: Execution context, passed as the first argument.
            *args: Extra positional arguments forwarded to the handler.

        Returns:
            Whatever the handler returned (after awaiting, if applicable).
        """
        import inspect

        outcome = handler(context, *args)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome
|
||
|
||
|
||
# Module-wide singleton, created lazily on first access.
_executor: HookExecutor | None = None


def get_hook_executor() -> HookExecutor:
    """Return the shared HookExecutor instance.

    The instance is constructed on the first call and reused afterwards.

    Returns:
        The module-wide HookExecutor.
    """
    global _executor
    instance = _executor
    if instance is None:
        instance = HookExecutor()
        _executor = instance
    return instance
|