154 lines
4.3 KiB
Python
154 lines
4.3 KiB
Python
"""自定义 Hook 加载器 - Phase 7.4
|
||
|
||
支持动态加载用户自定义的 Hook。
|
||
"""
|
||
|
||
import importlib.util
|
||
import os
|
||
from typing import Any
|
||
|
||
from app.agents.tools.hooks.types import HookDefinition, HookType, HookTrigger, HookResult
|
||
|
||
|
||
class CustomHookLoader:
    """Loader for user-defined hooks.

    Scans a directory for Python files, executes each one as a module and
    turns it into a ``HookDefinition`` - either one the module declares
    explicitly or one inferred from a well-known handler function name.

    NOTE(security): every matching ``.py`` file in the hooks directory is
    executed on load. The directory must only contain trusted code.
    """

    # Function names that trigger automatic inference of a hook definition
    # when a module declares no explicit definition.
    _INFERABLE_HANDLER_NAMES = ("pre_tool_hook", "post_tool_hook", "error_tool_hook")

    def __init__(self, hooks_dir: str | None = None):
        """Create a loader.

        Args:
            hooks_dir: Directory containing hook files. When None, defaults
                to ``../../../data/custom_hooks`` relative to this file's
                directory.
        """
        if hooks_dir is None:
            hooks_dir = os.path.join(
                os.path.dirname(__file__), "..", "..", "..", "data", "custom_hooks"
            )
        self.hooks_dir = hooks_dir
        # Hooks from the most recent load, keyed by hook name.
        self._loaded_hooks: dict[str, HookDefinition] = {}

    def load_all(self) -> list[HookDefinition]:
        """Load every custom hook found in the hooks directory.

        Files are considered when they end in ``.py`` and do not start with
        an underscore. They are processed in sorted filename order so that
        the load order (and which hook wins when two share a name) is
        deterministic.

        Returns:
            The hook definitions that loaded successfully; an empty list when
            the hooks directory does not exist or is not a directory.
        """
        hooks: list[HookDefinition] = []

        # isdir (rather than exists) so a stray file at this path does not
        # make os.listdir raise NotADirectoryError.
        if not os.path.isdir(self.hooks_dir):
            return hooks

        # os.listdir returns entries in arbitrary order; sort for stability.
        for filename in sorted(os.listdir(self.hooks_dir)):
            if not filename.endswith(".py") or filename.startswith("_"):
                continue

            hook_path = os.path.join(self.hooks_dir, filename)
            hook_def = self._load_hook_from_file(hook_path, filename[:-3])
            if hook_def:
                hooks.append(hook_def)
                self._loaded_hooks[hook_def.name] = hook_def

        return hooks

    def _load_hook_from_file(self, hook_path: str, module_name: str) -> HookDefinition | None:
        """Load a single hook from a Python file.

        The file is executed as a module. A module-level ``HOOK_DEFINITION``
        (or ``hook_definition``) that is a ``HookDefinition`` instance is
        returned as-is; otherwise a definition is inferred when the module
        exposes one of the known handler function names.

        Args:
            hook_path: Path of the hook file.
            module_name: Name to give the module (also used as the hook name
                when the definition is inferred).

        Returns:
            The hook definition, or None when the file yields no hook or
            fails to load. Failures are logged, never raised.
        """
        import logging

        try:
            spec = importlib.util.spec_from_file_location(module_name, hook_path)
            if not spec or not spec.loader:
                return None

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # Prefer an explicit definition: HOOK_DEFINITION, then hook_definition.
            hook_def = getattr(module, "HOOK_DEFINITION", None) or getattr(
                module, "hook_definition", None
            )
            if hook_def and isinstance(hook_def, HookDefinition):
                return hook_def

            # No explicit definition: try to infer one from a handler function.
            # All names handled by _infer_hook_definition are checked here so
            # that error-hook-only modules are not silently skipped.
            if any(hasattr(module, name) for name in self._INFERABLE_HANDLER_NAMES):
                return self._infer_hook_definition(module, module_name)

        except Exception:
            # User hook code can fail in arbitrary ways; one broken hook must
            # not prevent the others from loading, but the failure should be
            # visible rather than silently discarded.
            logging.getLogger(__name__).exception(
                "Failed to load custom hook %r from %s", module_name, hook_path
            )

        return None

    def _infer_hook_definition(self, module: Any, module_name: str) -> HookDefinition | None:
        """Infer a hook definition from a module's handler function.

        The first attribute present, in the order ``pre_tool_hook``,
        ``post_tool_hook``, ``error_tool_hook``, decides the handler and the
        hook type; later ones are ignored.

        Args:
            module: The loaded module object.
            module_name: Module name, used as the hook name.

        Returns:
            The inferred hook definition, or None when no usable handler is
            found.
        """
        hook_type = None
        handler = None

        if hasattr(module, "pre_tool_hook"):
            handler = module.pre_tool_hook
            hook_type = HookType.PRE_TOOL_USE
        elif hasattr(module, "post_tool_hook"):
            handler = module.post_tool_hook
            hook_type = HookType.POST_TOOL_USE
        elif hasattr(module, "error_tool_hook"):
            handler = module.error_tool_hook
            hook_type = HookType.TOOL_ERROR

        if not handler or not hook_type:
            return None

        return HookDefinition(
            name=module_name,
            hook_type=hook_type,
            trigger=HookTrigger(),
            handler=handler,
            priority=0,
            enabled=True,
            description=f"Auto-loaded hook from {module_name}",
        )

    def get_hook(self, name: str) -> HookDefinition | None:
        """Return a previously loaded hook by name.

        Args:
            name: Hook name.

        Returns:
            The hook definition, or None when no hook with that name is loaded.
        """
        return self._loaded_hooks.get(name)

    def reload(self) -> list[HookDefinition]:
        """Discard all loaded hooks and load them again from disk.

        Returns:
            The freshly loaded hook definitions.
        """
        self._loaded_hooks.clear()
        return self.load_all()
|
||
|
||
|
||
# Module-wide loader instance, created lazily by get_custom_hook_loader().
_loader: CustomHookLoader | None = None


def get_custom_hook_loader() -> CustomHookLoader:
    """Return the module-wide custom hook loader, creating it on first use."""
    global _loader
    if _loader is not None:
        return _loader
    _loader = CustomHookLoader()
    return _loader
|