Files
JARVIS/backend/app/agents/tools/registry.py

207 lines
5.2 KiB
Python
Raw Normal View History

"""工具注册表 - 工具系统重构 Phase 6.1"""
from collections import defaultdict
from typing import Any, Callable
from app.agents.tools.manifest import HookConfig, ToolManifest
class ToolRegistry:
"""工具注册表
统一管理所有工具的注册发现和调用
支持工具元数据权限分类Hook 拦截
"""
def __init__(self):
self._tools: dict[str, ToolManifest] = {}
self._executors: dict[str, Callable] = {}
self._hooks: dict[str, list[HookConfig]] = defaultdict(list)
def register(
self, manifest: ToolManifest, executor: Callable, hooks: list[HookConfig] | None = None
) -> None:
"""注册工具
Args:
manifest: 工具元数据
executor: 工具执行函数
hooks: 可选的 Hook 配置列表
"""
if manifest.name in self._tools:
raise ValueError(f"Tool already registered: {manifest.name}")
self._tools[manifest.name] = manifest
self._executors[manifest.name] = executor
if hooks:
for hook in hooks:
self._hooks[manifest.name].append(hook)
def unregister(self, name: str) -> bool:
"""注销工具
Args:
name: 工具名称
Returns:
是否成功注销
"""
if name not in self._tools:
return False
del self._tools[name]
del self._executors[name]
if name in self._hooks:
del self._hooks[name]
return True
def get(self, name: str) -> ToolManifest | None:
"""获取工具元数据
Args:
name: 工具名称
Returns:
工具元数据不存在返回 None
"""
return self._tools.get(name)
def get_executor(self, name: str) -> Callable | None:
"""获取工具执行器
Args:
name: 工具名称
Returns:
工具执行函数不存在返回 None
"""
return self._executors.get(name)
def get_hooks(self, name: str) -> list[HookConfig]:
"""获取工具的 Hook 配置
Args:
name: 工具名称
Returns:
Hook 配置列表
"""
return self._hooks.get(name, [])
def list_all(self) -> list[ToolManifest]:
"""列出所有已注册的工具
Returns:
工具元数据列表
"""
return list(self._tools.values())
def list_by_category(self, category: Any) -> list[ToolManifest]:
"""按类别列出工具
Args:
category: 工具类别
Returns:
该类别下的所有工具
"""
return [t for t in self._tools.values() if t.category == category]
def list_by_permission(self, permission: Any) -> list[ToolManifest]:
"""按权限级别列出工具
Args:
permission: 权限级别
Returns:
该权限级别下的所有工具
"""
return [t for t in self._tools.values() if t.permission_class == permission]
def search_by_tag(self, tag: str) -> list[ToolManifest]:
"""按标签搜索工具
Args:
tag: 标签
Returns:
包含该标签的工具
"""
return [t for t in self._tools.values() if tag in t.tags]
def search_by_name(self, keyword: str) -> list[ToolManifest]:
"""按名称关键词搜索工具
Args:
keyword: 关键词
Returns:
名称包含关键词的工具
"""
keyword = keyword.lower()
return [t for t in self._tools.values() if keyword in t.name.lower()]
def get_requires_confirmation(self, name: str) -> bool:
"""检查工具是否需要确认
Args:
name: 工具名称
Returns:
是否需要确认
"""
manifest = self._tools.get(name)
return manifest.requires_confirmation if manifest else False
def get_is_streaming(self, name: str) -> bool:
"""检查工具是否支持流式执行
Args:
name: 工具名称
Returns:
是否支持流式
"""
manifest = self._tools.get(name)
return manifest.is_streaming if manifest else False
def clear(self) -> None:
"""清空注册表"""
self._tools.clear()
self._executors.clear()
self._hooks.clear()
def __len__(self) -> int:
return len(self._tools)
def __contains__(self, name: str) -> bool:
return name in self._tools
def __iter__(self):
return iter(self._tools.values())
# 全局单例实例
_global_registry: ToolRegistry | None = None
def get_tool_registry() -> ToolRegistry:
"""获取全局工具注册表单例
Returns:
全局 ToolRegistry 实例
"""
global _global_registry
if _global_registry is None:
_global_registry = ToolRegistry()
return _global_registry
def reset_tool_registry() -> None:
"""重置全局工具注册表(用于测试)"""
global _global_registry
if _global_registry is not None:
_global_registry.clear()
_global_registry = None