Files
JARVIS/backend/app/agents/tools/hooks/manager.py

175 lines
4.4 KiB
Python
Raw Normal View History

"""Hook 管理器 - Phase 6.2
管理 Hook 的注册、查找和配置
"""
from typing import Any
from app.agents.tools.hooks.types import (
HookDefinition,
HookResult,
HookTrigger,
HookType,
ExecutionContext,
)
class HookManager:
    """Registry for tool hooks.

    Hooks are kept in two places:

    * ``_global_hooks`` -- hooks whose trigger sets neither ``tool_names``
      nor ``categories``; they are candidates for every tool.
    * ``_hooks`` -- all other hooks, bucketed by ``HookType``.

    Each list is kept sorted by ``priority`` in descending order.
    """

    def __init__(self) -> None:
        # One bucket per HookType member. Deriving the keys from the enum (the
        # same way clear() does) means register()/get_hooks() can never raise
        # KeyError for a hook type that was missing from a hand-written list.
        self._hooks: dict[HookType, list[HookDefinition]] = {ht: [] for ht in HookType}
        # Global hooks: apply to every tool.
        self._global_hooks: list[HookDefinition] = []

    def register(self, definition: HookDefinition) -> None:
        """Register a hook.

        A hook whose trigger has neither ``tool_names`` nor ``categories`` is
        stored as a global hook; otherwise it goes into the bucket for its
        ``hook_type``.

        Args:
            definition: The hook definition to register.
        """
        trigger = definition.trigger
        if trigger.tool_names is None and trigger.categories is None:
            self._global_hooks.append(definition)
        else:
            self._hooks[definition.hook_type].append(definition)
        # Keep both lists in descending priority order. list.sort() is stable,
        # so hooks with equal priority keep their registration order.
        self._hooks[definition.hook_type].sort(key=lambda h: h.priority, reverse=True)
        self._global_hooks.sort(key=lambda h: h.priority, reverse=True)

    def unregister(self, name: str) -> bool:
        """Remove the first hook registered under ``name``.

        Tool-specific buckets are searched before the global list, and only
        the first match is removed.

        Args:
            name: Hook name.

        Returns:
            True if a hook was removed, False if no hook has that name.
        """
        for bucket in (*self._hooks.values(), self._global_hooks):
            for index, hook in enumerate(bucket):
                if hook.name == name:
                    del bucket[index]
                    return True
        return False

    def get_hooks(self, hook_type: HookType, tool_name: str | None = None) -> list[HookDefinition]:
        """Return the enabled hooks of ``hook_type`` that apply to ``tool_name``.

        Global hooks come first, followed by tool-specific hooks. Each group is
        in descending priority order; the two groups are not merged by
        priority.

        Args:
            hook_type: The hook type to look up.
            tool_name: Name of the tool being invoked, if any. When ``None``,
                hooks restricted to a non-empty ``tool_names`` list are
                excluded.

        Returns:
            The matching hooks.
        """
        # Global hooks match any tool; only the type and enabled flag matter.
        result: list[HookDefinition] = [
            hook
            for hook in self._global_hooks
            if hook.hook_type == hook_type and hook.enabled
        ]
        for hook in self._hooks[hook_type]:
            if not hook.enabled:
                continue
            trigger = hook.trigger
            # register() never puts an unrestricted hook in these buckets, so
            # this can only trigger if the trigger was changed afterwards;
            # such a hook is skipped.
            if trigger.tool_names is None and trigger.categories is None:
                continue
            # A None or empty tool_names list does not filter by name.
            if trigger.tool_names and tool_name not in trigger.tool_names:
                continue
            # NOTE(review): trigger.categories is never checked here, so a hook
            # restricted only by category is returned for every tool. Confirm
            # whether category filtering is expected to happen in the caller.
            result.append(hook)
        return result

    def list_all(self) -> list[HookDefinition]:
        """Return every registered hook: global hooks first, then each bucket.

        Returns:
            A new list; mutating it does not affect the registry (the hook
            objects themselves are shared).
        """
        all_hooks = list(self._global_hooks)
        for bucket in self._hooks.values():
            all_hooks.extend(bucket)
        return all_hooks

    def _set_enabled(self, name: str, enabled: bool) -> bool:
        """Set ``enabled`` on the first hook named ``name``; return whether one was found."""
        for hook in self.list_all():
            if hook.name == name:
                hook.enabled = enabled
                return True
        return False

    def enable(self, name: str) -> bool:
        """Enable the first hook registered under ``name``.

        Args:
            name: Hook name.

        Returns:
            True if a matching hook was found and enabled.
        """
        return self._set_enabled(name, True)

    def disable(self, name: str) -> bool:
        """Disable the first hook registered under ``name``.

        Args:
            name: Hook name.

        Returns:
            True if a matching hook was found and disabled.
        """
        return self._set_enabled(name, False)

    def clear(self) -> None:
        """Remove every registered hook, both global and tool-specific."""
        self._hooks = {ht: [] for ht in HookType}
        self._global_hooks = []
# Module-level HookManager instance, created lazily by get_hook_manager().
_global_hook_manager: HookManager | None = None


def get_hook_manager() -> HookManager:
    """Return the module-level HookManager, creating it on first call.

    Returns:
        The shared HookManager instance.
    """
    global _global_hook_manager
    manager = _global_hook_manager
    if manager is None:
        manager = HookManager()
        _global_hook_manager = manager
    return manager
def reset_hook_manager() -> None:
    """Reset the module-level HookManager (intended for tests).

    If an instance exists, all of its hooks are cleared. The module-level
    reference is then dropped so that the next get_hook_manager() call
    builds a new instance.
    """
    global _global_hook_manager
    existing = _global_hook_manager
    if existing is not None:
        existing.clear()
    _global_hook_manager = None