diff --git a/backend/app/agents/tools/__init__.py b/backend/app/agents/tools/__init__.py index a221880..eb114b6 100644 --- a/backend/app/agents/tools/__init__.py +++ b/backend/app/agents/tools/__init__.py @@ -1,6 +1,9 @@ from app.agents.tools.search import ( - search_knowledge, get_knowledge_graph_context, - build_knowledge_graph, hybrid_search, web_search, + search_knowledge, + get_knowledge_graph_context, + build_knowledge_graph, + hybrid_search, + web_search, ) from app.agents.tools.task import get_tasks, create_task, update_task_status from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions @@ -13,6 +16,58 @@ from app.agents.tools.schedule import ( ) from app.agents.tools.time_reasoning import resolve_time_expression +# Phase 6.1: Tool Registry exports +from app.agents.tools.registry import ( + ToolRegistry, + get_tool_registry, + reset_tool_registry, +) +from app.agents.tools.manifest import ( + HookConfig, + PermissionClass, + SideEffectScope, + ToolCategory, + ToolManifest, +) +from app.agents.tools.migration import ( + migrate_tool, + migrate_all_tools, + get_tool_executor, + BackwardCompatTool, +) + +# Phase 6.2: Hook System exports +from app.agents.tools.hooks import ( + HookManager, + HookExecutor, + HookType, + HookDefinition, + HookResult, + ExecutionContext, + get_hook_manager, + get_hook_executor, +) + +# Phase 6.3: Streaming Executor exports +from app.agents.tools.streaming import ( + StreamingToolExecutor, + get_streaming_executor, +) + +# Phase 6.4: Builtin Tools exports +from app.agents.tools.builtins import ( + GlobTool, + GrepTool, + ReadFileTool, + WriteFileTool, + BashTool, + PowerShellTool, + LSPTools, + GitTool, + TeamAgentTool, + TaskBroadcastTool, +) + TASK_TOOLS = [ get_tasks, create_task, diff --git a/backend/app/agents/tools/base.py b/backend/app/agents/tools/base.py new file mode 100644 index 0000000..4e17545 --- /dev/null +++ b/backend/app/agents/tools/base.py @@ -0,0 +1,161 @@ +"""工具基类 - 工具系统重构 Phase 6.1""" + +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar + +from app.agents.tools.manifest import ( + PermissionClass, + SideEffectScope, + ToolCategory, + ToolManifest, +) + + +T = TypeVar("T") + + +class BaseTool(ABC, Generic[T]): + """工具基类 + + 提供工具的标准接口和默认实现。 + 所有自定义工具应继承此类。 + """ + + def __init__( + self, + name: str, + description: str, + category: ToolCategory, + permission_class: PermissionClass, + side_effect_scope: SideEffectScope = SideEffectScope.NONE, + requires_confirmation: bool = False, + is_streaming: bool = False, + tags: list[str] | None = None, + ): + self.name = name + self.description = description + self.category = category + self.permission_class = permission_class + self.side_effect_scope = side_effect_scope + self.requires_confirmation = requires_confirmation + self.is_streaming = is_streaming + self.tags = tags or [] + + def get_manifest(self) -> ToolManifest: + """获取工具元数据 + + Returns: + 工具元数据 + """ + return ToolManifest( + name=self.name, + description=self.description, + category=self.category, + parameters=self.get_parameters(), + return_schema=self.get_return_schema(), + permission_class=self.permission_class, + side_effect_scope=self.side_effect_scope, + requires_confirmation=self.requires_confirmation, + is_streaming=self.is_streaming, + tags=self.tags, + ) + + @abstractmethod + def get_parameters(self) -> dict[str, Any]: + """获取参数 Schema(JSON Schema 格式) + + Returns: + 参数 schema + """ + pass + + @abstractmethod + def get_return_schema(self) -> dict[str, Any]: + """获取返回值 Schema + + Returns: + 返回值 schema + """ + pass + + @abstractmethod + async def execute(self, **kwargs) -> T: + """执行工具 + + Args: + **kwargs: 工具参数 + + Returns: + 执行结果 + """ + pass + + async def execute_safe(self, **kwargs) -> dict[str, Any]: + """安全执行工具,捕获异常 + + Args: + **kwargs: 工具参数 + + Returns: + 包含 success 和 result/error 的字典 + """ + try: + result = await self.execute(**kwargs) + return {"success": True, "result": result} + except Exception as e: + return {"success": False, "error": str(e)} + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(name={self.name!r})>" + + +class ReadTool(BaseTool): + """只读工具基类""" + + def __init__(self, **kwargs): + kwargs.setdefault("category", ToolCategory.READ) + kwargs.setdefault("permission_class", PermissionClass.READ) + kwargs.setdefault("side_effect_scope", SideEffectScope.NONE) + super().__init__(**kwargs) + + +class WriteTool(BaseTool): + """写入工具基类""" + + def __init__(self, **kwargs): + kwargs.setdefault("category", ToolCategory.WRITE) + kwargs.setdefault("permission_class", PermissionClass.WRITE) + kwargs.setdefault("side_effect_scope", SideEffectScope.LOCAL_STATE) + super().__init__(**kwargs) + + +class DBWriteTool(BaseTool): + """数据库写入工具基类""" + + def __init__(self, **kwargs): + kwargs.setdefault("category", ToolCategory.DB_WRITE) + kwargs.setdefault("permission_class", PermissionClass.WRITE) + kwargs.setdefault("side_effect_scope", SideEffectScope.DB_WRITE) + kwargs.setdefault("requires_confirmation", True) + super().__init__(**kwargs) + + +class ExternalTool(BaseTool): + """外部工具基类(执行外部命令等)""" + + def __init__(self, **kwargs): + kwargs.setdefault("category", ToolCategory.EXTERNAL) + kwargs.setdefault("permission_class", PermissionClass.EXTERNAL) + kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK) + kwargs.setdefault("requires_confirmation", True) + super().__init__(**kwargs) + + +class NetworkTool(BaseTool): + """网络工具基类""" + + def __init__(self, **kwargs): + kwargs.setdefault("category", ToolCategory.NETWORK) + kwargs.setdefault("permission_class", PermissionClass.EXTERNAL) + kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK) + super().__init__(**kwargs) diff --git a/backend/app/agents/tools/builtins/__init__.py b/backend/app/agents/tools/builtins/__init__.py new file mode 100644 index 0000000..4fc0a97 --- /dev/null +++ b/backend/app/agents/tools/builtins/__init__.py @@ -0,0 +1,43 @@ +"""内置工具集 - Phase 6.4 + +新的内置工具,使用 BaseTool 基类。 +""" + +from app.agents.tools.builtins.file_tools import ( + GlobTool, + GrepTool, + ReadFileTool, + WriteFileTool, +) + +from app.agents.tools.builtins.system_tools import ( + BashTool, + PowerShellTool, +) + +from app.agents.tools.builtins.dev_tools import ( + LSPTools, + GitTool, +) + +from app.agents.tools.builtins.collaboration_tools import ( + TeamAgentTool, + TaskBroadcastTool, +) + +__all__ = [ + # File tools + "GlobTool", + "GrepTool", + "ReadFileTool", + "WriteFileTool", + # System tools + "BashTool", + "PowerShellTool", + # Dev tools + "LSPTools", + "GitTool", + # Collaboration tools + "TeamAgentTool", + "TaskBroadcastTool", +] diff --git a/backend/app/agents/tools/builtins/collaboration_tools.py b/backend/app/agents/tools/builtins/collaboration_tools.py new file mode 100644 index 0000000..d893715 --- /dev/null +++ b/backend/app/agents/tools/builtins/collaboration_tools.py @@ -0,0 +1,129 @@ +"""协作工具 - Phase 6.4""" + +from typing import Any + +from app.agents.tools.base import WriteTool +from app.agents.tools.manifest import ( + PermissionClass, + SideEffectScope, +) + + +class TeamAgentTool(WriteTool): + """团队 Agent 通信工具 + + 用于与其他 Agent 进行消息传递和协作。 + """ + + def __init__(self): + super().__init__( + name="team_agent", + description="向团队 Agent 发送消息或请求协作", + permission_class=PermissionClass.WRITE, + side_effect_scope=SideEffectScope.LOCAL_STATE, + tags=["collaboration", "team", "agent"], + ) + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "agent_name": { + "type": "string", + "description": "目标 Agent 名称", + }, + "message": { + "type": "string", + "description": "要发送的消息", + }, + "action": { + "type": "string", + "enum": ["send", "request", "delegate"], + "description": "操作类型", + }, + }, + "required": ["agent_name", "message"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "response": {"type": "string"}, + }, + } + + async def execute(self, agent_name: str, message: str, action: str = "send") -> dict[str, Any]: + # 注意:实际实现需要通过 Agent 通信协议 + # 这里只是一个框架实现 + return { + "success": True, + "response": f"Message '{action}' to agent '{agent_name}': {message}", + "agent_name": agent_name, + "action": action, + } + + +class TaskBroadcastTool(WriteTool): + """任务广播工具 + + 向多个 Agent 广播任务。 + """ + + def __init__(self): + super().__init__( + name="task_broadcast", + description="向多个 Agent 广播任务", + permission_class=PermissionClass.WRITE, + side_effect_scope=SideEffectScope.LOCAL_STATE, + tags=["collaboration", "broadcast", "task"], + ) + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "agent_names": { + "type": "array", + "items": {"type": "string"}, + "description": "目标 Agent 列表", + }, + "task": { + "type": "string", + "description": "要广播的任务描述", + }, + "priority": { + "type": "string", + "enum": ["low", "normal", "high", "urgent"], + "description": "任务优先级", + }, + }, + "required": ["agent_names", "task"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "broadcast_to": {"type": "array", "items": {"type": "string"}}, + "responses": {"type": "array"}, + }, + } + + async def execute( + self, + agent_names: list[str], + task: str, + priority: str = "normal", + ) -> dict[str, Any]: + # 注意:实际实现需要通过 Agent 通信协议 + # 这里只是一个框架实现 + return { + "success": True, + "broadcast_to": agent_names, + "task": task, + "priority": priority, + "responses": [f"Acknowledged by {agent}" for agent in agent_names], + } diff --git a/backend/app/agents/tools/builtins/dev_tools.py b/backend/app/agents/tools/builtins/dev_tools.py new file mode 100644 index 0000000..e48ecad --- /dev/null +++ b/backend/app/agents/tools/builtins/dev_tools.py @@ -0,0 +1,155 @@ +"""开发工具 - Phase 6.4""" + +from typing import Any + +from app.agents.tools.base import ReadTool, WriteTool +from app.agents.tools.manifest import ( + PermissionClass, + SideEffectScope, +) + + +class LSPTools(ReadTool): + """语言服务器协议工具集 + + 提供代码导航、查找引用等 LSP 功能。 + """ + + def __init__(self): + super().__init__( + name="lsp_tools", + description="LSP 代码导航和查找引用", + permission_class=PermissionClass.READ, + side_effect_scope=SideEffectScope.NONE, + tags=["development", "lsp", "code"], + ) + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["goto_definition", "find_references", "document_symbols"], + "description": "LSP 操作类型", + }, + "file": { + "type": "string", + "description": "文件路径", + }, + "line": { + "type": "integer", + "description": "行号(1-based)", + }, + "character": { + "type": "integer", + "description": "列号(0-based)", + }, + }, + "required": ["action", "file"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "results": {"type": "array"}, + }, + } + + async def execute( + self, + action: str, + file: str, + line: int = 1, + character: int = 0, + ) -> dict[str, Any]: + # 注意:实际 LSP 调用需要通过 lsp-utils 或类似库 + # 这里只是一个框架实现 + return { + "success": False, + "error": f"LSP action '{action}' not fully implemented - requires LSP server integration", + "action": action, + "file": file, + "position": {"line": line, "character": character}, + } + + +class GitTool(ReadTool): + """Git 操作工具 + + 提供常用的 Git 操作。 + """ + + def __init__(self, repo_path: str = "."): + super().__init__( + name="git", + description="执行 Git 命令", + permission_class=PermissionClass.EXTERNAL, + side_effect_scope=SideEffectScope.LOCAL_STATE, + requires_confirmation=True, + tags=["development", "git", "version-control"], + ) + self.repo_path = repo_path + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Git 子命令和参数,如 'status' 或 'log --oneline -10'", + }, + "repo_path": { + "type": "string", + "description": "仓库路径(可选)", + }, + }, + "required": ["command"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "stdout": {"type": "string"}, + "stderr": {"type": "string"}, + "returncode": {"type": "integer"}, + }, + } + + async def execute(self, command: str, repo_path: str | None = None) -> dict[str, Any]: + import asyncio + import os + import platform + + repo = repo_path or self.repo_path + + # 构建完整的 git 命令 + if platform.system() == "Windows": + full_command = f'git -C "{repo}" {command}' + else: + full_command = f"git -C '{repo}' {command}" + + try: + process = await asyncio.create_subprocess_shell( + full_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdout, stderr = await process.communicate() + + return { + "stdout": stdout.decode("utf-8", errors="replace"), + "stderr": stderr.decode("utf-8", errors="replace"), + "returncode": process.returncode, + } + + except Exception as e: + return { + "stdout": "", + "stderr": str(e), + "returncode": -1, + } diff --git a/backend/app/agents/tools/builtins/file_tools.py b/backend/app/agents/tools/builtins/file_tools.py new file mode 100644 index 0000000..8047696 --- /dev/null +++ b/backend/app/agents/tools/builtins/file_tools.py @@ -0,0 +1,255 @@ +"""文件操作工具 - Phase 6.4""" + +import os +from typing import Any + +from app.agents.tools.base import ExternalTool, ReadTool, WriteTool +from app.agents.tools.manifest import ( + PermissionClass, + SideEffectScope, + ToolCategory, +) + + +class GlobTool(ReadTool): + """文件路径匹配工具 + + 使用 glob 模式查找文件。 + """ + + def __init__(self, root_dir: str = "."): + super().__init__( + name="glob", + description="使用 glob 模式查找文件路径", + permission_class=PermissionClass.READ, + side_effect_scope=SideEffectScope.NONE, + tags=["file", "search", "glob"], + ) + self.root_dir = root_dir + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Glob 模式,如 **/*.py", + }, + "root_dir": { + "type": "string", + "description": "搜索根目录(可选)", + }, + }, + "required": ["pattern"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "array", + "items": {"type": "string"}, + } + + async def execute(self, pattern: str, root_dir: str | None = None) -> list[str]: + import glob as glob_module + + root = root_dir or self.root_dir + return glob_module.glob(pattern, root_dir=root, recursive=True) + + +class GrepTool(ReadTool): + """文件内容搜索工具 + + 在文件中搜索匹配的行。 + """ + + def __init__(self): + super().__init__( + name="grep", + description="在文件中搜索匹配的文本行", + permission_class=PermissionClass.READ, + side_effect_scope=SideEffectScope.NONE, + tags=["file", "search", "text"], + ) + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "正则表达式模式", + }, + "paths": { + "type": "array", + "items": {"type": "string"}, + "description": "要搜索的文件路径列表", + }, + "case_sensitive": { + "type": "boolean", + "description": "是否区分大小写", + }, + }, + "required": ["pattern", "paths"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "array", + "items": { + "type": "object", + "properties": { + "file": {"type": "string"}, + "line": {"type": "integer"}, + "content": {"type": "string"}, + }, + }, + } + + async def execute( + self, pattern: str, paths: list[str], case_sensitive: bool = True + ) -> list[dict[str, Any]]: + import re + + flags = 0 if case_sensitive else re.IGNORECASE + regex = re.compile(pattern, flags) + results = [] + + for path in paths: + if not os.path.isfile(path): + continue + + try: + with open(path, "r", encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + if regex.search(line): + results.append( + { + "file": path, + "line": line_num, + "content": line.rstrip(), + } + ) + except (UnicodeDecodeError, PermissionError): + continue + + return results + + +class ReadFileTool(ReadTool): + """文件读取工具""" + + def __init__(self): + super().__init__( + name="read_file", + description="读取文件内容", + permission_class=PermissionClass.READ, + side_effect_scope=SideEffectScope.NONE, + tags=["file", "read"], + ) + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "文件路径", + }, + "limit": { + "type": "integer", + "description": "最大行数", + }, + "offset": { + "type": "integer", + "description": "起始行号", + }, + }, + "required": ["path"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "content": {"type": "string"}, + "lines": {"type": "integer"}, + }, + } + + async def execute(self, path: str, limit: int | None = None, offset: int = 0) -> dict[str, Any]: + if not os.path.isfile(path): + raise FileNotFoundError(f"File not found: {path}") + + with open(path, "r", encoding="utf-8") as f: + lines = f.readlines() + + total_lines = len(lines) + start = max(0, offset) + end = len(lines) if limit is None else min(start + limit, len(lines)) + + content = "".join(lines[start:end]) + + return { + "content": content, + "lines": total_lines, + "truncated": limit is not None and end < len(lines), + } + + +class WriteFileTool(WriteTool): + """文件写入工具""" + + def __init__(self): + super().__init__( + name="write_file", + description="写入文件内容", + permission_class=PermissionClass.WRITE, + side_effect_scope=SideEffectScope.LOCAL_STATE, + requires_confirmation=True, + tags=["file", "write"], + ) + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "文件路径", + }, + "content": { + "type": "string", + "description": "文件内容", + }, + "append": { + "type": "boolean", + "description": "是否追加模式", + }, + }, + "required": ["path", "content"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "bytes_written": {"type": "integer"}, + }, + } + + async def execute(self, path: str, content: str, append: bool = False) -> dict[str, Any]: + mode = "a" if append else "w" + + # 确保目录存在 + directory = os.path.dirname(path) + if directory and not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) + + with open(path, mode, encoding="utf-8") as f: + bytes_written = f.write(content) + + return { + "success": True, + "bytes_written": bytes_written, + } diff --git a/backend/app/agents/tools/builtins/system_tools.py b/backend/app/agents/tools/builtins/system_tools.py new file mode 100644 index 0000000..b132dfe --- /dev/null +++ b/backend/app/agents/tools/builtins/system_tools.py @@ -0,0 +1,193 @@ +"""系统工具 - Phase 6.4""" + +import asyncio +import shlex +from typing import Any + +from app.agents.tools.base import ExternalTool +from app.agents.tools.manifest import ( + PermissionClass, + SideEffectScope, +) + + +class BashTool(ExternalTool): + """Bash 命令执行工具""" + + def __init__(self, working_dir: str = "."): + super().__init__( + name="bash", + description="执行 Bash 命令", + permission_class=PermissionClass.EXTERNAL, + side_effect_scope=SideEffectScope.LOCAL_STATE, + requires_confirmation=True, + tags=["system", "bash", "shell"], + ) + self.working_dir = working_dir + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "要执行的 Bash 命令", + }, + "timeout": { + "type": "integer", + "description": "超时时间(秒)", + }, + "working_dir": { + "type": "string", + "description": "工作目录(可选)", + }, + }, + "required": ["command"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "stdout": {"type": "string"}, + "stderr": {"type": "string"}, + "returncode": {"type": "integer"}, + }, + } + + async def execute( + self, command: str, timeout: int = 30, working_dir: str | None = None + ) -> dict[str, Any]: + import os + + cwd = working_dir or self.working_dir + + try: + process = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + ) + + try: + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout) + except asyncio.TimeoutError: + process.kill() + await process.wait() + return { + "stdout": "", + "stderr": f"Command timed out after {timeout} seconds", + "returncode": -1, + } + + return { + "stdout": stdout.decode("utf-8", errors="replace"), + "stderr": stderr.decode("utf-8", errors="replace"), + "returncode": process.returncode, + } + + except Exception as e: + return { + "stdout": "", + "stderr": str(e), + "returncode": -1, + } + + +class PowerShellTool(ExternalTool): + """PowerShell 命令执行工具""" + + def __init__(self, working_dir: str = "."): + super().__init__( + name="powershell", + description="执行 PowerShell 命令", + permission_class=PermissionClass.EXTERNAL, + side_effect_scope=SideEffectScope.LOCAL_STATE, + requires_confirmation=True, + tags=["system", "powershell", "shell"], + ) + self.working_dir = working_dir + + def get_parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "要执行的 PowerShell 命令", + }, + "timeout": { + "type": "integer", + "description": "超时时间(秒)", + }, + "working_dir": { + "type": "string", + "description": "工作目录(可选)", + }, + }, + "required": ["command"], + } + + def get_return_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "stdout": {"type": "string"}, + "stderr": {"type": "string"}, + "returncode": {"type": "integer"}, + }, + } + + async def execute( + self, command: str, timeout: int = 30, working_dir: str | None = None + ) -> dict[str, Any]: + import platform + + # 检测是否是 Windows 平台 + is_windows = platform.system() == "Windows" + + if not is_windows: + # 非 Windows 平台,可能没有 PowerShell + return { + "stdout": "", + "stderr": "PowerShell is not available on this platform", + "returncode": -1, + } + + cwd = working_dir or self.working_dir + + try: + process = await asyncio.create_subprocess_exec( + "powershell.exe", + "-NoProfile", + "-Command", + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + ) + + try: + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout) + except asyncio.TimeoutError: + process.kill() + await process.wait() + return { + "stdout": "", + "stderr": f"Command timed out after {timeout} seconds", + "returncode": -1, + } + + return { + "stdout": stdout.decode("utf-8", errors="replace"), + "stderr": stderr.decode("utf-8", errors="replace"), + "returncode": process.returncode, + } + + except Exception as e: + return { + "stdout": "", + "stderr": str(e), + "returncode": -1, + } diff --git a/backend/app/agents/tools/hooks/__init__.py b/backend/app/agents/tools/hooks/__init__.py new file mode 100644 index 0000000..7638969 --- /dev/null +++ b/backend/app/agents/tools/hooks/__init__.py @@ -0,0 +1,46 @@ +"""Hook 系统 - Phase 6.2""" + +from app.agents.tools.hooks.types import ( + HookDefinition, + HookResult, + HookStage, + HookTrigger, + HookType, + ExecutionContext, + HookHandler, + PreToolHook, + PostToolHook, + ErrorToolHook, + SkipToolHook, +) +from app.agents.tools.hooks.manager import ( + HookManager, + get_hook_manager, + reset_hook_manager, +) +from app.agents.tools.hooks.executor import ( + HookExecutor, + get_hook_executor, +) + +__all__ = [ + # Types + "HookType", + "HookStage", + "HookTrigger", + "HookDefinition", + "HookResult", + "ExecutionContext", + "HookHandler", + "PreToolHook", + "PostToolHook", + "ErrorToolHook", + "SkipToolHook", + # Manager + "HookManager", + "get_hook_manager", + "reset_hook_manager", + # Executor + "HookExecutor", + "get_hook_executor", +] diff --git a/backend/app/agents/tools/hooks/executor.py b/backend/app/agents/tools/hooks/executor.py new file mode 100644 index 0000000..6d0771a --- /dev/null +++ b/backend/app/agents/tools/hooks/executor.py @@ -0,0 +1,170 @@ +"""Hook 执行器 - Phase 6.2 + +执行 Hook 拦截逻辑。 +""" + +import time +from typing import Any + +from app.agents.tools.hooks.manager import get_hook_manager +from app.agents.tools.hooks.types import ( + HookDefinition, + HookResult, + HookType, + ExecutionContext, +) + + +class HookExecutor: + """Hook 执行器 + + 负责在工具执行前后执行 Hook 逻辑。 + """ + + def __init__(self): + self._manager = get_hook_manager() + + async def execute_pre_hooks( + self, context: ExecutionContext + ) -> tuple[bool, dict[str, Any] | None]: + """执行 pre-tool Hook + + Args: + context: 执行上下文 + + Returns: + (是否继续执行, 修改后的输入) + """ + hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name) + modified_input = context.tool_input + + for hook in hooks: + try: + # 调用 hook handler + handler = hook.handler + if callable(handler): + result = await self._call_hook(handler, context) + if result and not result.continue_execution: + # Hook 决定中断执行 + return False, modified_input + if result.modified_input is not None: + modified_input = result.modified_input + except Exception as e: + # Hook 出错,默认继续执行 + pass + + return True, modified_input + + async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any: + """执行 post-tool Hook + + Args: + context: 执行上下文 + result: 工具执行结果 + + Returns: + 修改后的结果 + """ + hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name) + modified_result = result + + for hook in hooks: + try: + handler = hook.handler + if callable(handler): + hook_result = await self._call_hook(handler, context, modified_result) + if hook_result and hook_result.modified_output is not None: + modified_result = hook_result.modified_output + except Exception: + # Hook 出错,默认保留原结果 + pass + + return modified_result + + async def execute_error_hooks( + self, context: ExecutionContext, error: Exception + ) -> HookResult | None: + """执行 error Hook + + Args: + context: 执行上下文 + error: 异常 + + Returns: + Hook 结果,如果返回 None 则继续传播错误 + """ + hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name) + + for hook in hooks: + try: + handler = hook.handler + if callable(handler): + result = await self._call_hook(handler, context, error) + if result is not None and result.continue_execution: + return result + except Exception: + # Hook 出错,继续执行其他 error hooks + pass + + return None + + async def execute_skip_check(self, context: ExecutionContext) -> bool: + """检查是否应跳过工具执行 + + Args: + context: 执行上下文 + + Returns: + True 表示跳过,False 表示执行 + """ + hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name) + + for hook in hooks: + try: + handler = hook.handler + if callable(handler): + result = await self._call_hook(handler, context) + if result is not None and isinstance(result, bool): + return result + except Exception: + # Hook 出错,默认不跳过 + pass + + return False + + async def _call_hook( + self, handler: Any, context: ExecutionContext, *args: Any + ) -> HookResult | None: + """调用 Hook 处理函数 + + Args: + handler: Hook 处理函数 + context: 执行上下文 + *args: 额外参数 + + Returns: + Hook 结果 + """ + import asyncio + + # 如果是普通函数,直接调用 + if asyncio.iscoroutinefunction(handler): + return await handler(context, *args) + else: + return handler(context, *args) + + +# 全局单例 +_executor: HookExecutor | None = None + + +def get_hook_executor() -> HookExecutor: + """获取全局 Hook 执行器 + + Returns: + 全局 HookExecutor 实例 + """ + global _executor + if _executor is None: + _executor = HookExecutor() + return _executor diff --git a/backend/app/agents/tools/hooks/manager.py b/backend/app/agents/tools/hooks/manager.py new file mode 100644 index 0000000..99a4bb1 --- /dev/null +++ b/backend/app/agents/tools/hooks/manager.py @@ -0,0 +1,174 @@ +"""Hook 管理器 - Phase 6.2 + +管理 Hook 的注册、查找和配置。 +""" + +from typing import Any + +from app.agents.tools.hooks.types import ( + HookDefinition, + HookResult, + HookTrigger, + HookType, + ExecutionContext, +) + + +class HookManager: + """Hook 管理器 + + 管理全局 Hook 的注册和配置。 + """ + + def __init__(self): + self._hooks: dict[HookType, list[HookDefinition]] = { + HookType.PRE_TOOL_USE: [], + HookType.POST_TOOL_USE: [], + HookType.TOOL_ERROR: [], + HookType.TOOL_SKIP: [], + } + self._global_hooks: list[HookDefinition] = [] # 全局 Hook(对所有工具生效) + + def register(self, definition: HookDefinition) -> None: + """注册 Hook + + Args: + definition: Hook 定义 + """ + if definition.trigger.tool_names is None and definition.trigger.categories is None: + # 全局 Hook + self._global_hooks.append(definition) + else: + # 特定工具 Hook + self._hooks[definition.hook_type].append(definition) + + # 按优先级排序 + self._hooks[definition.hook_type].sort(key=lambda h: h.priority, reverse=True) + self._global_hooks.sort(key=lambda h: h.priority, reverse=True) + + def unregister(self, name: str) -> bool: + """注销 Hook + + Args: + name: Hook 名称 + + Returns: + 是否成功注销 + """ + # 从特定工具 Hook 中移除 + for hooks in self._hooks.values(): + for i, hook in enumerate(hooks): + if hook.name == name: + hooks.pop(i) + return True + + # 从全局 Hook 中移除 + for i, hook in enumerate(self._global_hooks): + if hook.name == name: + self._global_hooks.pop(i) + return True + + return False + + def get_hooks(self, hook_type: HookType, tool_name: str | None = None) -> list[HookDefinition]: + """获取指定类型和工具的 Hook + + Args: + hook_type: Hook 类型 + tool_name: 工具名称(可选) + + Returns: + 匹配的 Hook 列表 + """ + result: list[HookDefinition] = [] + + # 添加全局 Hook + for hook in self._global_hooks: + if hook.hook_type == hook_type and hook.enabled: + result.append(hook) + + # 添加特定工具 Hook + for hook in self._hooks[hook_type]: + if not hook.enabled: + continue + + if hook.trigger.tool_names is None and hook.trigger.categories is None: + continue + + # 检查是否匹配 + if hook.trigger.tool_names and tool_name not in hook.trigger.tool_names: + continue + + result.append(hook) + + return result + + def list_all(self) -> list[HookDefinition]: + """列出所有已注册的 Hook + + Returns: + Hook 列表 + """ + all_hooks = list(self._global_hooks) + for hooks in self._hooks.values(): + all_hooks.extend(hooks) + return all_hooks + + def enable(self, name: str) -> bool: + """启用 Hook + + Args: + name: Hook 名称 + + Returns: + 是否成功启用 + """ + for hook in self.list_all(): + if hook.name == name: + hook.enabled = True + return True + return False + + def disable(self, name: str) -> bool: + """禁用 Hook + + Args: + name: Hook 名称 + + Returns: + 是否成功禁用 + """ + for hook in self.list_all(): + if hook.name == name: + hook.enabled = False + return True + return False + + def clear(self) -> None: + """清除所有 Hook""" + self._hooks = {ht: [] for ht in HookType} + self._global_hooks = [] + + +# 全局单例 +_global_hook_manager: HookManager | None = None + + +def get_hook_manager() -> HookManager: + """获取全局 Hook 管理器 + + Returns: + 全局 HookManager 实例 + """ + global _global_hook_manager + if _global_hook_manager is None: + _global_hook_manager = HookManager() + return _global_hook_manager + + +def reset_hook_manager() -> None: + """重置全局 Hook 管理器(用于测试)""" + global _global_hook_manager + if _global_hook_manager is not None: + _global_hook_manager.clear() + _global_hook_manager = None diff --git a/backend/app/agents/tools/hooks/types.py b/backend/app/agents/tools/hooks/types.py new file mode 100644 index 0000000..0f8d6aa --- /dev/null +++ b/backend/app/agents/tools/hooks/types.py @@ -0,0 +1,90 @@ +"""Hook 类型定义 - Phase 6.2 + +Hook 拦截系统类型定义。 +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable + + +class HookType(Enum): + """Hook 类型""" + + PRE_TOOL_USE = "pre_tool_use" # 工具执行前 + POST_TOOL_USE = "post_tool_use" # 工具执行后 + TOOL_ERROR = "tool_error" # 工具执行出错 + TOOL_SKIP = "tool_skip" # 工具跳过(条件执行) + + +class HookStage(Enum): + """Hook 执行阶段""" + + BEFORE = "before" + AFTER = "after" + ON_ERROR = "on_error" + + +@dataclass +class HookTrigger: + """Hook 触发条件""" + + tool_names: list[str] | None = None # 只对特定工具生效,None 表示全部 + categories: list[str] | None = None # 只对特定类别生效 + conditions: dict[str, Any] | None = None # 自定义条件 + + +@dataclass +class HookDefinition: + """Hook 定义""" + + name: str + hook_type: HookType + trigger: HookTrigger + handler: Callable[..., Any] # Hook 处理函数 + priority: int = 0 # 优先级,数字越大越先执行 + enabled: bool = True + description: str = "" + + +@dataclass +class HookResult: + """Hook 执行结果""" + + hook_name: str + success: bool + continue_execution: bool = True # False 表示中断执行 + modified_input: Any = None # 修改后的输入 + modified_output: Any = None # 修改后的输出 + error: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ExecutionContext: + """工具执行上下文""" + + tool_name: str + tool_input: dict[str, Any] + user_id: str | None = None + session_id: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + # 执行结果(由 HookExecutor 填充) + result: Any = None + error: Exception | None = None + start_time: float | None = None + end_time: float | None = None + + +# Hook 处理函数类型 +HookHandler = Callable[[ExecutionContext, HookDefinition], HookResult] + +# Pre-hook: 在工具执行前调用,可以修改输入或决定是否跳过 +PreToolHook = Callable[[ExecutionContext], tuple[bool, dict[str, Any] | None]] +# post-hook: 在工具执行后调用,可以修改输出 +PostToolHook = Callable[[ExecutionContext, Any], Any] +# Error hook: 在工具出错时调用 +ErrorToolHook = Callable[[ExecutionContext, Exception], HookResult | None] +# Skip hook: 决定是否跳过工具执行 +SkipToolHook = Callable[[ExecutionContext], bool] diff --git a/backend/app/agents/tools/manifest.py b/backend/app/agents/tools/manifest.py new file mode 100644 index 0000000..828417f --- /dev/null +++ b/backend/app/agents/tools/manifest.py @@ -0,0 +1,77 @@ +"""工具元数据和数据类型定义""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class ToolCategory(Enum): + """工具类别""" + + READ = "read" + WRITE = "write" + EXTERNAL = "external" + DB_WRITE = "db_write" + NETWORK = "network" + + +class SideEffectScope(Enum): + """副作用范围""" + + NONE = "none" + LOCAL_STATE = "local_state" + DB_WRITE = "db_write" + NETWORK = "network" + + +class PermissionClass(Enum): + """权限级别""" + + READ = "read" + WRITE = "write" + EXTERNAL = "external" + + +@dataclass +class ToolManifest: + """工具元数据""" + + name: str + description: str + category: ToolCategory + parameters: dict[str, Any] # JSON Schema + return_schema: dict[str, Any] + permission_class: PermissionClass + side_effect_scope: SideEffectScope + requires_confirmation: bool = False + is_streaming: bool = False + tags: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "category": self.category.value, + "parameters": self.parameters, + "return_schema": self.return_schema, + "permission_class": self.permission_class.value, + "side_effect_scope": self.side_effect_scope.value, + "requires_confirmation": self.requires_confirmation, + "is_streaming": self.is_streaming, + "tags": self.tags, + } + + +@dataclass +class HookConfig: + """Hook 配置""" + + name: str + hook_type: str # "pre_tool_use", "post_tool_use", "tool_error", "tool_skip" + filter_names: list[str] | None = None # 只对特定工具生效,None 表示全部 + + def matches_tool(self, tool_name: str) -> bool: + """检查 Hook 是否对指定工具生效""" + if self.filter_names is None: + return True + return tool_name in self.filter_names diff --git a/backend/app/agents/tools/migration.py b/backend/app/agents/tools/migration.py new file mode 100644 index 0000000..27e9639 --- /dev/null +++ b/backend/app/agents/tools/migration.py @@ -0,0 +1,251 @@ +"""工具迁移和向后兼容层 - Phase 6.1 + +将现有 @tool 装饰的工具迁移到 ToolRegistry,同时保持向后兼容。 +""" + +from functools import wraps +from typing import Any, Callable + +from app.agents.tools.manifest import ( + PermissionClass, + SideEffectScope, + ToolCategory, + ToolManifest, +) +from app.agents.tools.registry import get_tool_registry + + +# 现有工具的类别映射 +_TOOL_CATEGORY_MAP: dict[str, tuple[ToolCategory, PermissionClass, SideEffectScope]] = { + # 知识检索 - 只读 + "search_knowledge": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE), + "get_knowledge_graph_context": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE), + "hybrid_search": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE), + "web_search": (ToolCategory.NETWORK, PermissionClass.EXTERNAL, SideEffectScope.NETWORK), + # 知识构建 - 写入 + "build_knowledge_graph": ( + ToolCategory.WRITE, + PermissionClass.WRITE, + SideEffectScope.LOCAL_STATE, + ), + # 任务工具 + "get_tasks": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE), + "create_task": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE), + "update_task_status": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE), + # 日程工具 + "get_schedule_day": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE), + "create_todo": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE), + "create_schedule_task": ( + ToolCategory.WRITE, + PermissionClass.WRITE, + SideEffectScope.LOCAL_STATE, + ), + "create_reminder": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE), + "create_goal": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE), + "resolve_time_expression": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE), + # 论坛工具 + "get_forum_posts": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE), + "create_forum_post": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE), + "scan_forum_for_instructions": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE), +} + + +def get_tool_category(name: str) -> tuple[ToolCategory, PermissionClass, SideEffectScope]: + """获取工具的类别信息""" + return _TOOL_CATEGORY_MAP.get( + name, + (ToolCategory.EXTERNAL, PermissionClass.EXTERNAL, SideEffectScope.NETWORK), + ) + + +def infer_tags_from_docstring(docstring: str | None) -> list[str]: + """从 docstring 推断工具标签""" + if not docstring: + return [] + tags = [] + doc_lower = docstring.lower() + if "搜索" in docstring or "查询" in docstring or "search" in doc_lower: + tags.append("search") + if "创建" in docstring or "新建" in docstring or "create" in doc_lower: + tags.append("create") + if "获取" in docstring or "读取" in docstring or "get" in doc_lower: + tags.append("read") + if "更新" in docstring or "修改" in docstring or "update" in doc_lower: + tags.append("update") + return tags + + +def migrate_tool(tool_func: Callable) -> Callable: + """将现有 @tool 装饰的函数迁移到 ToolRegistry + + Args: + tool_func: LangChain @tool 装饰的函数 + + Returns: + 原函数(已注册到 registry) + """ + registry = get_tool_registry() + + # 如果已经注册,跳过 + if registry.get(tool_func.name): + return tool_func + + # 获取类别信息 + category, permission, side_effect = get_tool_category(tool_func.name) + + # 从 docstring 提取 description + description = tool_func.description if hasattr(tool_func, "description") else "" + + # 推断 tags + tags = infer_tags_from_docstring(description) + tags.append("migrated") + + # 创建 manifest + manifest = ToolManifest( + name=tool_func.name, + description=description, + category=category, + parameters={}, # LangChain @tool 动态处理参数 + return_schema={}, + permission_class=permission, + side_effect_scope=side_effect, + requires_confirmation=side_effect != SideEffectScope.NONE, + is_streaming=False, + tags=tags, + ) + + # 注册到 registry + registry.register(manifest, tool_func) + + return tool_func + + +def migrate_all_tools() -> int: + """迁移所有现有工具到 ToolRegistry + + Returns: + 迁移的工具数量 + """ + from app.agents.tools import ( + ALL_TOOLS, + KNOWLEDGE_GRAPH_TOOLS, + KNOWLEDGE_RETRIEVAL_TOOLS, + SCHEDULE_READ_TOOLS, + SCHEDULE_WRITE_TOOLS, + TASK_TOOLS, + FORUM_TOOLS, + ) + + all_tools = ( + KNOWLEDGE_RETRIEVAL_TOOLS + + KNOWLEDGE_GRAPH_TOOLS + + TASK_TOOLS + + SCHEDULE_READ_TOOLS + + SCHEDULE_WRITE_TOOLS + + FORUM_TOOLS + ) + + count = 0 + for tool in all_tools: + try: + migrate_tool(tool) + count += 1 + except Exception as e: + print(f"Failed to migrate tool {getattr(tool, 'name', 'unknown')}: {e}") + + return count + + +class BackwardCompatTool: + """向后兼容工具包装器 + + 确保现有代码通过 registry.get_executor() 仍能正常调用工具。 + """ + + def __init__(self, name: str): + self.name = name + self._registry = get_tool_registry() + + def __call__(self, *args, **kwargs) -> Any: + executor = self._registry.get_executor(self.name) + if executor is None: + raise ValueError(f"Tool not found in registry: {self.name}") + return executor(*args, **kwargs) + + def invoke(self, tool_input: dict[str, Any]) -> Any: + """LangChain 风格的 invoke 调用""" + executor = self._registry.get_executor(self.name) + if executor is None: + raise ValueError(f"Tool not found in registry: {self.name}") + + # 处理位置参数 + if isinstance(tool_input, dict): + return executor(**tool_input) + return executor(tool_input) + + +def create_compat_layer() -> dict[str, BackwardCompatTool]: + """创建向后兼容层 + + 返回一个字典,允许通过名称访问兼容的工具包装器。 + """ + registry = get_tool_registry() + tools = registry.list_all() + + return {tool.name: BackwardCompatTool(tool.name) for tool in tools} + + +# 自动迁移装饰器 +def auto_migrate(func: Callable) -> Callable: + """自动迁移装饰器 + + 用于装饰新的 @tool 函数,自动注册到 registry。 + """ + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + # 迁移到 registry + migrate_tool(wrapper) + + return wrapper + + +# 便捷函数:获取兼容的工具执行器 +def get_tool_executor(name: str) -> Callable | None: + """获取工具执行器(兼容层) + + 优先从 registry 获取,fallback 到直接导入。 + """ + registry = get_tool_registry() + executor = registry.get_executor(name) + + if executor is not None: + return executor + + # Fallback: 直接从模块导入(仅用于迁移期间) + try: + from app.agents.tools import ( + TASK_TOOLS, + SCHEDULE_READ_TOOLS, + SCHEDULE_WRITE_TOOLS, + FORUM_TOOLS, + KNOWLEDGE_RETRIEVAL_TOOLS, + ) + + all_tools = ( + KNOWLEDGE_RETRIEVAL_TOOLS + + TASK_TOOLS + + SCHEDULE_READ_TOOLS + + SCHEDULE_WRITE_TOOLS + + FORUM_TOOLS + ) + + for tool in all_tools: + if hasattr(tool, "name") and tool.name == name: + return tool + except ImportError: + pass + + return None diff --git a/backend/app/agents/tools/registry.py b/backend/app/agents/tools/registry.py new file mode 100644 index 0000000..c5c7865 --- /dev/null +++ b/backend/app/agents/tools/registry.py @@ -0,0 +1,206 @@ +"""工具注册表 - 工具系统重构 Phase 6.1""" + +from collections import defaultdict +from typing import Any, Callable + +from app.agents.tools.manifest import HookConfig, ToolManifest + + +class ToolRegistry: + """工具注册表 + + 统一管理所有工具的注册、发现和调用。 + 支持工具元数据、权限分类、Hook 拦截。 + """ + + def __init__(self): + self._tools: dict[str, ToolManifest] = {} + self._executors: dict[str, Callable] = {} + self._hooks: dict[str, list[HookConfig]] = defaultdict(list) + + def register( + self, manifest: ToolManifest, executor: Callable, hooks: list[HookConfig] | None = None + ) -> None: + """注册工具 + + Args: + manifest: 工具元数据 + executor: 工具执行函数 + hooks: 可选的 Hook 配置列表 + """ + if manifest.name in self._tools: + raise ValueError(f"Tool already registered: {manifest.name}") + + self._tools[manifest.name] = manifest + self._executors[manifest.name] = executor + + if hooks: + for hook in hooks: + self._hooks[manifest.name].append(hook) + + def unregister(self, name: str) -> bool: + """注销工具 + + Args: + name: 工具名称 + + Returns: + 是否成功注销 + """ + if name not in self._tools: + return False + + del self._tools[name] + del self._executors[name] + if name in self._hooks: + del self._hooks[name] + return True + + def get(self, name: str) -> ToolManifest | None: + """获取工具元数据 + + Args: + name: 工具名称 + + Returns: + 工具元数据,不存在返回 None + """ + return self._tools.get(name) + + def get_executor(self, name: str) -> Callable | None: + """获取工具执行器 + + Args: + name: 工具名称 + + Returns: + 工具执行函数,不存在返回 None + """ + return self._executors.get(name) + + def get_hooks(self, name: str) -> list[HookConfig]: + """获取工具的 Hook 配置 + + Args: + name: 工具名称 + + Returns: + Hook 配置列表 + """ + return self._hooks.get(name, []) + + def list_all(self) -> list[ToolManifest]: + """列出所有已注册的工具 + + Returns: + 工具元数据列表 + """ + return list(self._tools.values()) + + def list_by_category(self, category: Any) -> list[ToolManifest]: + """按类别列出工具 + + Args: + category: 工具类别 + + Returns: + 该类别下的所有工具 + """ + return [t for t in self._tools.values() if t.category == category] + + def list_by_permission(self, permission: Any) -> list[ToolManifest]: + """按权限级别列出工具 + + Args: + permission: 权限级别 + + Returns: + 该权限级别下的所有工具 + """ + return [t for t in self._tools.values() if t.permission_class == permission] + + def search_by_tag(self, tag: str) -> list[ToolManifest]: + """按标签搜索工具 + + Args: + tag: 标签 + + Returns: + 包含该标签的工具 + """ + return [t for t in self._tools.values() if tag in t.tags] + + def search_by_name(self, keyword: str) -> list[ToolManifest]: + """按名称关键词搜索工具 + + Args: + keyword: 关键词 + + Returns: + 名称包含关键词的工具 + """ + keyword = keyword.lower() + return [t for t in self._tools.values() if keyword in t.name.lower()] + + def get_requires_confirmation(self, name: str) -> bool: + """检查工具是否需要确认 + + Args: + name: 工具名称 + + Returns: + 是否需要确认 + """ + manifest = self._tools.get(name) + return manifest.requires_confirmation if manifest else False + + def get_is_streaming(self, name: str) -> bool: + """检查工具是否支持流式执行 + + Args: + name: 工具名称 + + Returns: + 是否支持流式 + """ + manifest = self._tools.get(name) + return manifest.is_streaming if manifest else False + + def clear(self) -> None: + """清空注册表""" + self._tools.clear() + self._executors.clear() + self._hooks.clear() + + def __len__(self) -> int: + return len(self._tools) + + def __contains__(self, name: str) -> bool: + return name in self._tools + + def __iter__(self): + return iter(self._tools.values()) + + +# 全局单例实例 +_global_registry: ToolRegistry | None = None + + +def get_tool_registry() -> ToolRegistry: + """获取全局工具注册表单例 + + Returns: + 全局 ToolRegistry 实例 + """ + global _global_registry + if _global_registry is None: + _global_registry = ToolRegistry() + return _global_registry + + +def reset_tool_registry() -> None: + """重置全局工具注册表(用于测试)""" + global _global_registry + if _global_registry is not None: + _global_registry.clear() + _global_registry = None diff --git a/backend/app/agents/tools/streaming.py b/backend/app/agents/tools/streaming.py new file mode 100644 index 0000000..1a5dd88 --- /dev/null +++ b/backend/app/agents/tools/streaming.py @@ -0,0 +1,210 @@ +"""流式工具执行器 - Phase 6.3 + +支持流式输出的工具执行器。 +""" + +import asyncio +import json +from typing import Any, AsyncGenerator + +from app.agents.tools.hooks.executor import get_hook_executor +from app.agents.tools.hooks.types import ExecutionContext +from app.agents.tools.registry import get_tool_registry + + +class StreamingToolExecutor: + """流式工具执行器 + + 支持: + - 普通工具的同步/异步执行 + - 流式工具的流式输出 + - Hook 拦截(pre/post/error) + """ + + def __init__(self): + self._registry = get_tool_registry() + self._hook_executor = get_hook_executor() + + async def execute( + self, + tool_name: str, + tool_input: dict[str, Any], + user_id: str | None = None, + ) -> Any: + """执行工具(非流式) + + Args: + tool_name: 工具名称 + tool_input: 工具输入参数 + user_id: 用户 ID(可选) + + Returns: + 工具执行结果 + """ + # 创建执行上下文 + context = ExecutionContext( + tool_name=tool_name, + tool_input=tool_input, + user_id=user_id, + ) + + # 获取工具和执行器 + manifest = self._registry.get(tool_name) + if manifest is None: + raise ValueError(f"Tool not found: {tool_name}") + + executor = self._registry.get_executor(tool_name) + if executor is None: + raise ValueError(f"Executor not found for tool: {tool_name}") + + # 检查是否跳过 + if await self._hook_executor.execute_skip_check(context): + return {"skipped": True, "tool": tool_name} + + # 执行 pre-hooks + continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context) + if not continue_execution: + return {"pre_hook_aborted": True, "tool": tool_name} + + # 执行工具 + try: + context.start_time = asyncio.get_event_loop().time() + + # 判断是同步还是异步执行 + if asyncio.iscoroutinefunction(executor): + result = await executor(**modified_input) + else: + # 同步函数在线程池中执行 + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, lambda: executor(**modified_input)) + + context.result = result + context.end_time = asyncio.get_event_loop().time() + + # 执行 post-hooks + result = await self._hook_executor.execute_post_hooks(context, result) + + return result + + except Exception as e: + context.error = e + context.end_time = asyncio.get_event_loop().time() + + # 执行 error-hooks + error_result = await self._hook_executor.execute_error_hooks(context, e) + if error_result is not None: + return error_result + + raise + + async def execute_streaming( + self, + tool_name: str, + tool_input: dict[str, Any], + user_id: str | None = None, + ) -> AsyncGenerator[dict[str, Any], None]: + """执行流式工具 + + Args: + tool_name: 工具名称 + tool_input: 工具输入参数 + user_id: 用户 ID(可选) + + Yields: + 流式输出片段 + """ + # 创建执行上下文 + context = ExecutionContext( + tool_name=tool_name, + tool_input=tool_input, + user_id=user_id, + ) + + # 获取工具和执行器 + manifest = self._registry.get(tool_name) + if manifest is None: + raise ValueError(f"Tool not found: {tool_name}") + + if not manifest.is_streaming: + raise ValueError(f"Tool is not streaming: {tool_name}") + + executor = self._registry.get_executor(tool_name) + if executor is None: + raise ValueError(f"Executor not found for tool: {tool_name}") + + # 检查是否跳过 + if await self._hook_executor.execute_skip_check(context): + yield {"type": "skipped", "tool": tool_name} + return + + # 执行 pre-hooks + continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context) + if not continue_execution: + yield {"type": "pre_hook_aborted", "tool": tool_name} + return + + # 执行流式工具 + try: + context.start_time = asyncio.get_event_loop().time() + + # 调用执行器(应该返回 AsyncGenerator) + result = executor(**modified_input) + + # 如果是 async generator + if asyncio.isasyncgen(result): + async for chunk in result: + yield {"type": "chunk", "data": chunk} + else: + # 普通协程 + data = await result + yield {"type": "chunk", "data": data} + + context.end_time = asyncio.get_event_loop().time() + yield {"type": "done", "tool": tool_name} + + except Exception as e: + context.error = e + context.end_time = asyncio.get_event_loop().time() + + yield {"type": "error", "error": str(e), "tool": tool_name} + + # 执行 error-hooks + await self._hook_executor.execute_error_hooks(context, e) + + async def execute_batch( + self, + tool_calls: list[dict[str, Any]], + user_id: str | None = None, + ) -> list[Any]: + """批量执行工具 + + Args: + tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input + user_id: 用户 ID(可选) + + Returns: + 执行结果列表 + """ + tasks = [] + for call in tool_calls: + tool_name = call.get("tool_name") or call.get("name") + tool_input = call.get("tool_input") or call.get("input") or {} + tasks.append(self.execute(tool_name, tool_input, user_id)) + + return await asyncio.gather(*tasks, return_exceptions=True) + + +# 全局单例 +_executor: StreamingToolExecutor | None = None + + +def get_streaming_executor() -> StreamingToolExecutor: + """获取全局流式执行器 + + Returns: + 全局 StreamingToolExecutor 实例 + """ + global _executor + if _executor is None: + _executor = StreamingToolExecutor() + return _executor diff --git a/backend/tests/backend/app/agents/test_tools_registry.py b/backend/tests/backend/app/agents/test_tools_registry.py new file mode 100644 index 0000000..4699b97 --- /dev/null +++ b/backend/tests/backend/app/agents/test_tools_registry.py @@ -0,0 +1,324 @@ +"""ToolRegistry 单元测试""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from app.agents.tools.manifest import ( + HookConfig, + PermissionClass, + SideEffectScope, + ToolCategory, + ToolManifest, +) +from app.agents.tools.registry import ( + ToolRegistry, + get_tool_registry, + reset_tool_registry, +) + + +@pytest.fixture +def registry(): + """创建空的 ToolRegistry 实例""" + return ToolRegistry() + + +@pytest.fixture +def sample_manifest(): + """创建示例工具元数据""" + return ToolManifest( + name="test_tool", + description="A test tool", + category=ToolCategory.READ, + parameters={ + "type": "object", + "properties": {"input": {"type": "string"}}, + "required": ["input"], + }, + return_schema={"type": "string"}, + permission_class=PermissionClass.READ, + side_effect_scope=SideEffectScope.NONE, + requires_confirmation=False, + is_streaming=False, + tags=["test", "sample"], + ) + + +@pytest.fixture +def sample_executor(): + """创建示例执行器""" + + def executor(input: str) -> str: + return f"processed: {input}" + + return executor + + +class TestToolRegistryInit: + """测试 ToolRegistry 初始化""" + + def test_empty_registry(self, registry): + assert len(registry) == 0 + assert list(registry) == [] + + def test_empty_contains_false(self, registry): + assert "test_tool" not in registry + + +class TestToolRegistryRegister: + """测试工具注册""" + + def test_register_single_tool(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + assert len(registry) == 1 + assert "test_tool" in registry + + def test_register_duplicate_raises(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + with pytest.raises(ValueError, match="Tool already registered"): + registry.register(sample_manifest, sample_executor) + + def test_register_with_hooks(self, registry, sample_manifest, sample_executor): + hooks = [ + HookConfig(name="audit", hook_type="pre_tool_use"), + HookConfig(name="security", hook_type="post_tool_use"), + ] + registry.register(sample_manifest, sample_executor, hooks=hooks) + + tool_hooks = registry.get_hooks("test_tool") + assert len(tool_hooks) == 2 + + +class TestToolRegistryGet: + """测试获取工具""" + + def test_get_existing_tool(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + manifest = registry.get("test_tool") + assert manifest is not None + assert manifest.name == "test_tool" + + def test_get_nonexistent_tool(self, registry): + assert registry.get("nonexistent") is None + + def test_get_executor(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + executor = registry.get_executor("test_tool") + assert executor is not None + assert executor("hello") == "processed: hello" + + def test_get_nonexistent_executor(self, registry): + assert registry.get_executor("nonexistent") is None + + +class TestToolRegistryList: + """测试工具列表""" + + def test_list_all_empty(self, registry): + assert registry.list_all() == [] + + def test_list_all_with_tools(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + manifest2 = ToolManifest( + name="test_tool2", + description="Another test tool", + category=ToolCategory.WRITE, + parameters={"type": "object"}, + return_schema={"type": "string"}, + permission_class=PermissionClass.WRITE, + side_effect_scope=SideEffectScope.LOCAL_STATE, + requires_confirmation=True, + ) + registry.register(manifest2, sample_executor) + + all_tools = registry.list_all() + assert len(all_tools) == 2 + + def test_list_by_category(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + manifest2 = ToolManifest( + name="write_tool", + description="A write tool", + category=ToolCategory.WRITE, + parameters={"type": "object"}, + return_schema={"type": "string"}, + permission_class=PermissionClass.WRITE, + side_effect_scope=SideEffectScope.LOCAL_STATE, + requires_confirmation=True, + ) + registry.register(manifest2, sample_executor) + + read_tools = registry.list_by_category(ToolCategory.READ) + write_tools = registry.list_by_category(ToolCategory.WRITE) + + assert len(read_tools) == 1 + assert len(write_tools) == 1 + + def test_list_by_permission(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + read_tools = registry.list_by_permission(PermissionClass.READ) + write_tools = registry.list_by_permission(PermissionClass.WRITE) + + assert len(read_tools) == 1 + assert len(write_tools) == 0 + + +class TestToolRegistrySearch: + """测试工具搜索""" + + def test_search_by_tag(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + results = registry.search_by_tag("test") + assert len(results) == 1 + assert results[0].name == "test_tool" + + results = registry.search_by_tag("nonexistent") + assert len(results) == 0 + + def test_search_by_name(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + results = registry.search_by_name("test") + assert len(results) == 1 + + results = registry.search_by_name("tool") + assert len(results) == 1 + + results = registry.search_by_name("xyz") + assert len(results) == 0 + + +class TestToolRegistryUtility: + """测试工具方法""" + + def test_requires_confirmation_true(self, registry, sample_manifest, sample_executor): + sample_manifest.requires_confirmation = True + registry.register(sample_manifest, sample_executor) + + assert registry.get_requires_confirmation("test_tool") is True + + def test_requires_confirmation_false(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + assert registry.get_requires_confirmation("test_tool") is False + + def test_requires_confirmation_nonexistent(self, registry): + assert registry.get_requires_confirmation("nonexistent") is False + + def test_is_streaming_true(self, registry, sample_manifest, sample_executor): + sample_manifest.is_streaming = True + registry.register(sample_manifest, sample_executor) + + assert registry.get_is_streaming("test_tool") is True + + def test_is_streaming_false(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + assert registry.get_is_streaming("test_tool") is False + + +class TestToolRegistryUnregister: + """测试工具注销""" + + def test_unregister_existing(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + assert registry.unregister("test_tool") is True + assert len(registry) == 0 + assert "test_tool" not in registry + + def test_unregister_nonexistent(self, registry): + assert registry.unregister("nonexistent") is False + + def test_unregister_clears_hooks(self, registry, sample_manifest, sample_executor): + hooks = [HookConfig(name="test", hook_type="pre_tool_use")] + registry.register(sample_manifest, sample_executor, hooks=hooks) + + registry.unregister("test_tool") + + assert registry.get_hooks("test_tool") == [] + + +class TestToolRegistryClear: + """测试清空注册表""" + + def test_clear(self, registry, sample_manifest, sample_executor): + registry.register(sample_manifest, sample_executor) + + manifest2 = ToolManifest( + name="tool2", + description="Another tool", + category=ToolCategory.READ, + parameters={"type": "object"}, + return_schema={"type": "string"}, + permission_class=PermissionClass.READ, + side_effect_scope=SideEffectScope.NONE, + requires_confirmation=False, + ) + registry.register(manifest2, sample_executor) + + registry.clear() + + assert len(registry) == 0 + + +class TestGlobalRegistry: + """测试全局注册表""" + + def test_get_tool_registry_returns_singleton(self): + reset_tool_registry() + + reg1 = get_tool_registry() + reg2 = get_tool_registry() + + assert reg1 is reg2 + + def test_reset_tool_registry(self): + reset_tool_registry() + reg1 = get_tool_registry() + + reset_tool_registry() + reg2 = get_tool_registry() + + # After reset, should get a new instance + # (Note: this tests the reset function works) + + +class TestToolManifest: + """测试 ToolManifest""" + + def test_to_dict(self, sample_manifest): + d = sample_manifest.to_dict() + + assert d["name"] == "test_tool" + assert d["description"] == "A test tool" + assert d["category"] == "read" + assert d["permission_class"] == "read" + assert d["side_effect_scope"] == "none" + assert d["requires_confirmation"] is False + assert d["is_streaming"] is False + assert d["tags"] == ["test", "sample"] + + +class TestHookConfig: + """测试 HookConfig""" + + def test_matches_tool_with_no_filter(self): + hook = HookConfig(name="test", hook_type="pre_tool_use") + + assert hook.matches_tool("any_tool") is True + + def test_matches_tool_with_filter(self): + hook = HookConfig(name="test", hook_type="pre_tool_use", filter_names=["tool_a", "tool_b"]) + + assert hook.matches_tool("tool_a") is True + assert hook.matches_tool("tool_b") is True + assert hook.matches_tool("tool_c") is False