Phase 6.1: ToolRegistry infrastructure - Add ToolManifest with ToolCategory, PermissionClass, SideEffectScope - Add ToolRegistry singleton with register/get/unregister/list/search - Add BaseTool abstract class with ReadTool/WriteTool/DBWriteTool/ExternalTool/NetworkTool subclasses - Add migration layer for backward compatibility Phase 6.2: Hook interception system - Add HookType (PRE_TOOL_USE, POST_TOOL_USE, TOOL_ERROR, TOOL_SKIP) - Add HookManager with singleton for hook registration - Add HookExecutor for pre/post/error hook execution Phase 6.3: Streaming execution - Add StreamingToolExecutor with batch execution support Phase 6.4: New builtin tools - Add file_tools: GlobTool, GrepTool, ReadFileTool, WriteFileTool - Add system_tools: BashTool, PowerShellTool - Add dev_tools: LSPTools, GitTool - Add collaboration_tools: TeamAgentTool, TaskBroadcastTool Tests: 29 passed
207 lines
5.2 KiB
Python
207 lines
5.2 KiB
Python
"""工具注册表 - 工具系统重构 Phase 6.1"""
|
|
|
|
from collections import defaultdict
|
|
from typing import Any, Callable
|
|
|
|
from app.agents.tools.manifest import HookConfig, ToolManifest
|
|
|
|
|
|
class ToolRegistry:
    """Tool registry.

    Central store for tool registration, discovery and lookup. Each tool is
    keyed by ``manifest.name`` and has a manifest (tool metadata such as
    category, permission class and tags), an executor callable, and an
    optional list of hook configurations.
    """

    def __init__(self):
        # name -> manifest; the authoritative set of registered tools.
        self._tools: dict[str, ToolManifest] = {}
        # name -> executor callable; kept in lockstep with ``_tools``.
        self._executors: dict[str, Callable] = {}
        # name -> hook configs; only populated for tools registered with hooks.
        self._hooks: dict[str, list[HookConfig]] = defaultdict(list)

    def register(
        self, manifest: ToolManifest, executor: Callable, hooks: list[HookConfig] | None = None
    ) -> None:
        """Register a tool.

        Args:
            manifest: Tool metadata; ``manifest.name`` is the registry key.
            executor: Callable that executes the tool.
            hooks: Optional hook configurations attached to the tool.

        Raises:
            ValueError: If a tool with the same name is already registered.
        """
        name = manifest.name
        if name in self._tools:
            raise ValueError(f"Tool already registered: {name}")

        # Copy the hooks before touching any state: the registry never aliases
        # the caller's list, and a bad ``hooks`` argument fails before the tool
        # is half-registered.
        hook_list = list(hooks) if hooks else []

        self._tools[name] = manifest
        self._executors[name] = executor
        if hook_list:
            self._hooks[name].extend(hook_list)

    def unregister(self, name: str) -> bool:
        """Unregister a tool together with its executor and hooks.

        Args:
            name: Tool name.

        Returns:
            True if the tool was registered and has been removed, else False.
        """
        if name not in self._tools:
            return False

        del self._tools[name]
        del self._executors[name]
        # Tools registered without hooks have no entry here.
        self._hooks.pop(name, None)
        return True

    def get(self, name: str) -> ToolManifest | None:
        """Return the manifest of a tool.

        Args:
            name: Tool name.

        Returns:
            The tool's manifest, or None if no such tool is registered.
        """
        return self._tools.get(name)

    def get_executor(self, name: str) -> Callable | None:
        """Return the executor of a tool.

        Args:
            name: Tool name.

        Returns:
            The tool's executor callable, or None if no such tool is registered.
        """
        return self._executors.get(name)

    def get_hooks(self, name: str) -> list[HookConfig]:
        """Return the hook configurations of a tool.

        Args:
            name: Tool name.

        Returns:
            A new list of the tool's hook configs (empty if it has none or is
            not registered). Mutating the result does not affect the registry.
        """
        # ``.get`` (not ``[]``) so the defaultdict does not grow an entry for
        # unknown names; ``list(...)`` so callers cannot mutate internal state.
        return list(self._hooks.get(name, ()))

    def list_all(self) -> list[ToolManifest]:
        """List all registered tools.

        Returns:
            Manifests of all registered tools, in registration order.
        """
        return list(self._tools.values())

    def list_by_category(self, category: Any) -> list[ToolManifest]:
        """List tools whose ``category`` equals *category*.

        Args:
            category: Tool category to match.

        Returns:
            Manifests of all tools in that category.
        """
        return [t for t in self._tools.values() if t.category == category]

    def list_by_permission(self, permission: Any) -> list[ToolManifest]:
        """List tools whose ``permission_class`` equals *permission*.

        Args:
            permission: Permission class to match.

        Returns:
            Manifests of all tools with that permission class.
        """
        return [t for t in self._tools.values() if t.permission_class == permission]

    def search_by_tag(self, tag: str) -> list[ToolManifest]:
        """Search tools by tag.

        Args:
            tag: Tag to look for (exact membership in ``manifest.tags``).

        Returns:
            Manifests of all tools carrying that tag.
        """
        return [t for t in self._tools.values() if tag in t.tags]

    def search_by_name(self, keyword: str) -> list[ToolManifest]:
        """Search tools whose name contains *keyword*, ignoring case.

        Args:
            keyword: Substring to look for in tool names.

        Returns:
            Manifests of all tools whose name contains the keyword.
        """
        # casefold() is the recommended form for caseless comparison.
        needle = keyword.casefold()
        return [t for t in self._tools.values() if needle in t.name.casefold()]

    def get_requires_confirmation(self, name: str) -> bool:
        """Check whether a tool requires confirmation before running.

        Args:
            name: Tool name.

        Returns:
            The manifest's ``requires_confirmation`` flag, or False if the tool
            is not registered.
        """
        manifest = self._tools.get(name)
        return manifest.requires_confirmation if manifest is not None else False

    def get_is_streaming(self, name: str) -> bool:
        """Check whether a tool supports streaming execution.

        Args:
            name: Tool name.

        Returns:
            The manifest's ``is_streaming`` flag, or False if the tool is not
            registered.
        """
        manifest = self._tools.get(name)
        return manifest.is_streaming if manifest is not None else False

    def clear(self) -> None:
        """Remove every tool, executor and hook from the registry."""
        self._tools.clear()
        self._executors.clear()
        self._hooks.clear()

    def __len__(self) -> int:
        return len(self._tools)

    def __contains__(self, name: str) -> bool:
        return name in self._tools

    def __iter__(self):
        # Iterates manifests (not names), in registration order.
        return iter(self._tools.values())
|
|
|
|
|
|
# Module-level singleton slot: created lazily by get_tool_registry() and
# reset to None by reset_tool_registry().
_global_registry: ToolRegistry | None = None
|
|
|
|
|
|
def get_tool_registry() -> ToolRegistry:
    """Return the module-level ToolRegistry singleton.

    The instance is created on the first call and reused by every later
    call until reset_tool_registry() discards it.

    Returns:
        The shared global ToolRegistry instance.
    """
    global _global_registry
    registry = _global_registry
    if registry is None:
        registry = ToolRegistry()
        _global_registry = registry
    return registry
|
|
|
|
|
|
def reset_tool_registry() -> None:
    """Discard the global tool registry (intended for tests).

    If a registry exists, it is cleared first, so any code still holding a
    reference sees it empty. The module-level slot is then set to None, so
    the next get_tool_registry() call builds a fresh instance.
    """
    global _global_registry
    current = _global_registry
    if current is None:
        return
    current.clear()
    _global_registry = None
|