"""Hook 执行器 - Phase 6.2 执行 Hook 拦截逻辑。 """ import time from typing import Any from app.agents.tools.hooks.manager import get_hook_manager from app.agents.tools.hooks.types import ( HookDefinition, HookResult, HookType, ExecutionContext, ) class HookExecutor: """Hook 执行器 负责在工具执行前后执行 Hook 逻辑。 """ def __init__(self): self._manager = get_hook_manager() async def execute_pre_hooks( self, context: ExecutionContext ) -> tuple[bool, dict[str, Any] | None]: """执行 pre-tool Hook Args: context: 执行上下文 Returns: (是否继续执行, 修改后的输入) """ hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name) modified_input = context.tool_input for hook in hooks: try: # 调用 hook handler handler = hook.handler if callable(handler): result = await self._call_hook(handler, context) if result and not result.continue_execution: # Hook 决定中断执行 return False, modified_input if result.modified_input is not None: modified_input = result.modified_input except Exception as e: # Hook 出错,默认继续执行 pass return True, modified_input async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any: """执行 post-tool Hook Args: context: 执行上下文 result: 工具执行结果 Returns: 修改后的结果 """ hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name) modified_result = result for hook in hooks: try: handler = hook.handler if callable(handler): hook_result = await self._call_hook(handler, context, modified_result) if hook_result and hook_result.modified_output is not None: modified_result = hook_result.modified_output except Exception: # Hook 出错,默认保留原结果 pass return modified_result async def execute_error_hooks( self, context: ExecutionContext, error: Exception ) -> HookResult | None: """执行 error Hook Args: context: 执行上下文 error: 异常 Returns: Hook 结果,如果返回 None 则继续传播错误 """ hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name) for hook in hooks: try: handler = hook.handler if callable(handler): result = await self._call_hook(handler, context, error) if result is not None and result.continue_execution: return result except Exception: # Hook 出错,继续执行其他 error hooks pass return None async def execute_skip_check(self, context: ExecutionContext) -> bool: """检查是否应跳过工具执行 Args: context: 执行上下文 Returns: True 表示跳过,False 表示执行 """ hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name) for hook in hooks: try: handler = hook.handler if callable(handler): result = await self._call_hook(handler, context) if result is not None and isinstance(result, bool): return result except Exception: # Hook 出错,默认不跳过 pass return False async def _call_hook( self, handler: Any, context: ExecutionContext, *args: Any ) -> HookResult | None: """调用 Hook 处理函数 Args: handler: Hook 处理函数 context: 执行上下文 *args: 额外参数 Returns: Hook 结果 """ import asyncio # 如果是普通函数,直接调用 if asyncio.iscoroutinefunction(handler): return await handler(context, *args) else: return handler(context, *args) # 全局单例 _executor: HookExecutor | None = None def get_hook_executor() -> HookExecutor: """获取全局 Hook 执行器 Returns: 全局 HookExecutor 实例 """ global _executor if _executor is None: _executor = HookExecutor() return _executor