feat(agents): Phase 6 tool system refactoring
Phase 6.1: ToolRegistry infrastructure - Add ToolManifest with ToolCategory, PermissionClass, SideEffectScope - Add ToolRegistry singleton with register/get/unregister/list/search - Add BaseTool abstract class with ReadTool/WriteTool/DBWriteTool/ExternalTool/NetworkTool subclasses - Add migration layer for backward compatibility Phase 6.2: Hook interception system - Add HookType (PRE_TOOL_USE, POST_TOOL_USE, TOOL_ERROR, TOOL_SKIP) - Add HookManager with singleton for hook registration - Add HookExecutor for pre/post/error hook execution Phase 6.3: Streaming execution - Add StreamingToolExecutor with batch execution support Phase 6.4: New builtin tools - Add file_tools: GlobTool, GrepTool, ReadFileTool, WriteFileTool - Add system_tools: BashTool, PowerShellTool - Add dev_tools: LSPTools, GitTool - Add collaboration_tools: TeamAgentTool, TaskBroadcastTool Tests: 29 passed
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search, web_search,
|
||||
search_knowledge,
|
||||
get_knowledge_graph_context,
|
||||
build_knowledge_graph,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
@@ -13,6 +16,58 @@ from app.agents.tools.schedule import (
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
|
||||
# Phase 6.1: Tool Registry exports
|
||||
from app.agents.tools.registry import (
|
||||
ToolRegistry,
|
||||
get_tool_registry,
|
||||
reset_tool_registry,
|
||||
)
|
||||
from app.agents.tools.manifest import (
|
||||
HookConfig,
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
from app.agents.tools.migration import (
|
||||
migrate_tool,
|
||||
migrate_all_tools,
|
||||
get_tool_executor,
|
||||
BackwardCompatTool,
|
||||
)
|
||||
|
||||
# Phase 6.2: Hook System exports
|
||||
from app.agents.tools.hooks import (
|
||||
HookManager,
|
||||
HookExecutor,
|
||||
HookType,
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
ExecutionContext,
|
||||
get_hook_manager,
|
||||
get_hook_executor,
|
||||
)
|
||||
|
||||
# Phase 6.3: Streaming Executor exports
|
||||
from app.agents.tools.streaming import (
|
||||
StreamingToolExecutor,
|
||||
get_streaming_executor,
|
||||
)
|
||||
|
||||
# Phase 6.4: Builtin Tools exports
|
||||
from app.agents.tools.builtins import (
|
||||
GlobTool,
|
||||
GrepTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
BashTool,
|
||||
PowerShellTool,
|
||||
LSPTools,
|
||||
GitTool,
|
||||
TeamAgentTool,
|
||||
TaskBroadcastTool,
|
||||
)
|
||||
|
||||
TASK_TOOLS = [
|
||||
get_tasks,
|
||||
create_task,
|
||||
|
||||
161
backend/app/agents/tools/base.py
Normal file
161
backend/app/agents/tools/base.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""工具基类 - 工具系统重构 Phase 6.1"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseTool(ABC, Generic[T]):
|
||||
"""工具基类
|
||||
|
||||
提供工具的标准接口和默认实现。
|
||||
所有自定义工具应继承此类。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
category: ToolCategory,
|
||||
permission_class: PermissionClass,
|
||||
side_effect_scope: SideEffectScope = SideEffectScope.NONE,
|
||||
requires_confirmation: bool = False,
|
||||
is_streaming: bool = False,
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.category = category
|
||||
self.permission_class = permission_class
|
||||
self.side_effect_scope = side_effect_scope
|
||||
self.requires_confirmation = requires_confirmation
|
||||
self.is_streaming = is_streaming
|
||||
self.tags = tags or []
|
||||
|
||||
def get_manifest(self) -> ToolManifest:
|
||||
"""获取工具元数据
|
||||
|
||||
Returns:
|
||||
工具元数据
|
||||
"""
|
||||
return ToolManifest(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
category=self.category,
|
||||
parameters=self.get_parameters(),
|
||||
return_schema=self.get_return_schema(),
|
||||
permission_class=self.permission_class,
|
||||
side_effect_scope=self.side_effect_scope,
|
||||
requires_confirmation=self.requires_confirmation,
|
||||
is_streaming=self.is_streaming,
|
||||
tags=self.tags,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
"""获取参数 Schema(JSON Schema 格式)
|
||||
|
||||
Returns:
|
||||
参数 schema
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
"""获取返回值 Schema
|
||||
|
||||
Returns:
|
||||
返回值 schema
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> T:
|
||||
"""执行工具
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
pass
|
||||
|
||||
async def execute_safe(self, **kwargs) -> dict[str, Any]:
|
||||
"""安全执行工具,捕获异常
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
包含 success 和 result/error 的字典
|
||||
"""
|
||||
try:
|
||||
result = await self.execute(**kwargs)
|
||||
return {"success": True, "result": result}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self.name!r})>"
|
||||
|
||||
|
||||
class ReadTool(BaseTool):
|
||||
"""只读工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.READ)
|
||||
kwargs.setdefault("permission_class", PermissionClass.READ)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NONE)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class WriteTool(BaseTool):
|
||||
"""写入工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.WRITE)
|
||||
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.LOCAL_STATE)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class DBWriteTool(BaseTool):
|
||||
"""数据库写入工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.DB_WRITE)
|
||||
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.DB_WRITE)
|
||||
kwargs.setdefault("requires_confirmation", True)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class ExternalTool(BaseTool):
|
||||
"""外部工具基类(执行外部命令等)"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.EXTERNAL)
|
||||
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||
kwargs.setdefault("requires_confirmation", True)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class NetworkTool(BaseTool):
|
||||
"""网络工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.NETWORK)
|
||||
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||
super().__init__(**kwargs)
|
||||
43
backend/app/agents/tools/builtins/__init__.py
Normal file
43
backend/app/agents/tools/builtins/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""内置工具集 - Phase 6.4
|
||||
|
||||
新的内置工具,使用 BaseTool 基类。
|
||||
"""
|
||||
|
||||
from app.agents.tools.builtins.file_tools import (
|
||||
GlobTool,
|
||||
GrepTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.system_tools import (
|
||||
BashTool,
|
||||
PowerShellTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.dev_tools import (
|
||||
LSPTools,
|
||||
GitTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.collaboration_tools import (
|
||||
TeamAgentTool,
|
||||
TaskBroadcastTool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# File tools
|
||||
"GlobTool",
|
||||
"GrepTool",
|
||||
"ReadFileTool",
|
||||
"WriteFileTool",
|
||||
# System tools
|
||||
"BashTool",
|
||||
"PowerShellTool",
|
||||
# Dev tools
|
||||
"LSPTools",
|
||||
"GitTool",
|
||||
# Collaboration tools
|
||||
"TeamAgentTool",
|
||||
"TaskBroadcastTool",
|
||||
]
|
||||
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""协作工具 - Phase 6.4"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class TeamAgentTool(WriteTool):
|
||||
"""团队 Agent 通信工具
|
||||
|
||||
用于与其他 Agent 进行消息传递和协作。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="team_agent",
|
||||
description="向团队 Agent 发送消息或请求协作",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
tags=["collaboration", "team", "agent"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "目标 Agent 名称",
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "要发送的消息",
|
||||
},
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["send", "request", "delegate"],
|
||||
"description": "操作类型",
|
||||
},
|
||||
},
|
||||
"required": ["agent_name", "message"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"response": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, agent_name: str, message: str, action: str = "send") -> dict[str, Any]:
|
||||
# 注意:实际实现需要通过 Agent 通信协议
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": True,
|
||||
"response": f"Message '{action}' to agent '{agent_name}': {message}",
|
||||
"agent_name": agent_name,
|
||||
"action": action,
|
||||
}
|
||||
|
||||
|
||||
class TaskBroadcastTool(WriteTool):
|
||||
"""任务广播工具
|
||||
|
||||
向多个 Agent 广播任务。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="task_broadcast",
|
||||
description="向多个 Agent 广播任务",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
tags=["collaboration", "broadcast", "task"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_names": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "目标 Agent 列表",
|
||||
},
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "要广播的任务描述",
|
||||
},
|
||||
"priority": {
|
||||
"type": "string",
|
||||
"enum": ["low", "normal", "high", "urgent"],
|
||||
"description": "任务优先级",
|
||||
},
|
||||
},
|
||||
"required": ["agent_names", "task"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"broadcast_to": {"type": "array", "items": {"type": "string"}},
|
||||
"responses": {"type": "array"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
agent_names: list[str],
|
||||
task: str,
|
||||
priority: str = "normal",
|
||||
) -> dict[str, Any]:
|
||||
# 注意:实际实现需要通过 Agent 通信协议
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": True,
|
||||
"broadcast_to": agent_names,
|
||||
"task": task,
|
||||
"priority": priority,
|
||||
"responses": [f"Acknowledged by {agent}" for agent in agent_names],
|
||||
}
|
||||
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""开发工具 - Phase 6.4"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ReadTool, WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class LSPTools(ReadTool):
|
||||
"""语言服务器协议工具集
|
||||
|
||||
提供代码导航、查找引用等 LSP 功能。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="lsp_tools",
|
||||
description="LSP 代码导航和查找引用",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["development", "lsp", "code"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["goto_definition", "find_references", "document_symbols"],
|
||||
"description": "LSP 操作类型",
|
||||
},
|
||||
"file": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"line": {
|
||||
"type": "integer",
|
||||
"description": "行号(1-based)",
|
||||
},
|
||||
"character": {
|
||||
"type": "integer",
|
||||
"description": "列号(0-based)",
|
||||
},
|
||||
},
|
||||
"required": ["action", "file"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"results": {"type": "array"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
action: str,
|
||||
file: str,
|
||||
line: int = 1,
|
||||
character: int = 0,
|
||||
) -> dict[str, Any]:
|
||||
# 注意:实际 LSP 调用需要通过 lsp-utils 或类似库
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"LSP action '{action}' not fully implemented - requires LSP server integration",
|
||||
"action": action,
|
||||
"file": file,
|
||||
"position": {"line": line, "character": character},
|
||||
}
|
||||
|
||||
|
||||
class GitTool(ReadTool):
|
||||
"""Git 操作工具
|
||||
|
||||
提供常用的 Git 操作。
|
||||
"""
|
||||
|
||||
def __init__(self, repo_path: str = "."):
|
||||
super().__init__(
|
||||
name="git",
|
||||
description="执行 Git 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["development", "git", "version-control"],
|
||||
)
|
||||
self.repo_path = repo_path
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Git 子命令和参数,如 'status' 或 'log --oneline -10'",
|
||||
},
|
||||
"repo_path": {
|
||||
"type": "string",
|
||||
"description": "仓库路径(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, command: str, repo_path: str | None = None) -> dict[str, Any]:
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
|
||||
repo = repo_path or self.repo_path
|
||||
|
||||
# 构建完整的 git 命令
|
||||
if platform.system() == "Windows":
|
||||
full_command = f'git -C "{repo}" {command}'
|
||||
else:
|
||||
full_command = f"git -C '{repo}' {command}"
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
full_command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""文件操作工具 - Phase 6.4"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ExternalTool, ReadTool, WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
)
|
||||
|
||||
|
||||
class GlobTool(ReadTool):
|
||||
"""文件路径匹配工具
|
||||
|
||||
使用 glob 模式查找文件。
|
||||
"""
|
||||
|
||||
def __init__(self, root_dir: str = "."):
|
||||
super().__init__(
|
||||
name="glob",
|
||||
description="使用 glob 模式查找文件路径",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "search", "glob"],
|
||||
)
|
||||
self.root_dir = root_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob 模式,如 **/*.py",
|
||||
},
|
||||
"root_dir": {
|
||||
"type": "string",
|
||||
"description": "搜索根目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
|
||||
async def execute(self, pattern: str, root_dir: str | None = None) -> list[str]:
|
||||
import glob as glob_module
|
||||
|
||||
root = root_dir or self.root_dir
|
||||
return glob_module.glob(pattern, root_dir=root, recursive=True)
|
||||
|
||||
|
||||
class GrepTool(ReadTool):
|
||||
"""文件内容搜索工具
|
||||
|
||||
在文件中搜索匹配的行。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="grep",
|
||||
description="在文件中搜索匹配的文本行",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "search", "text"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "正则表达式模式",
|
||||
},
|
||||
"paths": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "要搜索的文件路径列表",
|
||||
},
|
||||
"case_sensitive": {
|
||||
"type": "boolean",
|
||||
"description": "是否区分大小写",
|
||||
},
|
||||
},
|
||||
"required": ["pattern", "paths"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file": {"type": "string"},
|
||||
"line": {"type": "integer"},
|
||||
"content": {"type": "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, pattern: str, paths: list[str], case_sensitive: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
import re
|
||||
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
regex = re.compile(pattern, flags)
|
||||
results = []
|
||||
|
||||
for path in paths:
|
||||
if not os.path.isfile(path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
if regex.search(line):
|
||||
results.append(
|
||||
{
|
||||
"file": path,
|
||||
"line": line_num,
|
||||
"content": line.rstrip(),
|
||||
}
|
||||
)
|
||||
except (UnicodeDecodeError, PermissionError):
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class ReadFileTool(ReadTool):
|
||||
"""文件读取工具"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="read_file",
|
||||
description="读取文件内容",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "read"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "最大行数",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "起始行号",
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string"},
|
||||
"lines": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, path: str, limit: int | None = None, offset: int = 0) -> dict[str, Any]:
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
total_lines = len(lines)
|
||||
start = max(0, offset)
|
||||
end = len(lines) if limit is None else min(start + limit, len(lines))
|
||||
|
||||
content = "".join(lines[start:end])
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"lines": total_lines,
|
||||
"truncated": limit is not None and end < len(lines),
|
||||
}
|
||||
|
||||
|
||||
class WriteFileTool(WriteTool):
|
||||
"""文件写入工具"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="write_file",
|
||||
description="写入文件内容",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["file", "write"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "文件内容",
|
||||
},
|
||||
"append": {
|
||||
"type": "boolean",
|
||||
"description": "是否追加模式",
|
||||
},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"bytes_written": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, path: str, content: str, append: bool = False) -> dict[str, Any]:
|
||||
mode = "a" if append else "w"
|
||||
|
||||
# 确保目录存在
|
||||
directory = os.path.dirname(path)
|
||||
if directory and not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
with open(path, mode, encoding="utf-8") as f:
|
||||
bytes_written = f.write(content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"bytes_written": bytes_written,
|
||||
}
|
||||
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""系统工具 - Phase 6.4"""
|
||||
|
||||
import asyncio
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ExternalTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class BashTool(ExternalTool):
|
||||
"""Bash 命令执行工具"""
|
||||
|
||||
def __init__(self, working_dir: str = "."):
|
||||
super().__init__(
|
||||
name="bash",
|
||||
description="执行 Bash 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["system", "bash", "shell"],
|
||||
)
|
||||
self.working_dir = working_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的 Bash 命令",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "超时时间(秒)",
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "工作目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
import os
|
||||
|
||||
cwd = working_dir or self.working_dir
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout} seconds",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
|
||||
class PowerShellTool(ExternalTool):
|
||||
"""PowerShell 命令执行工具"""
|
||||
|
||||
def __init__(self, working_dir: str = "."):
|
||||
super().__init__(
|
||||
name="powershell",
|
||||
description="执行 PowerShell 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["system", "powershell", "shell"],
|
||||
)
|
||||
self.working_dir = working_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的 PowerShell 命令",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "超时时间(秒)",
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "工作目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
import platform
|
||||
|
||||
# 检测是否是 Windows 平台
|
||||
is_windows = platform.system() == "Windows"
|
||||
|
||||
if not is_windows:
|
||||
# 非 Windows 平台,可能没有 PowerShell
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": "PowerShell is not available on this platform",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
cwd = working_dir or self.working_dir
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"powershell.exe",
|
||||
"-NoProfile",
|
||||
"-Command",
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout} seconds",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
46
backend/app/agents/tools/hooks/__init__.py
Normal file
46
backend/app/agents/tools/hooks/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Hook 系统 - Phase 6.2"""
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookStage,
|
||||
HookTrigger,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
HookHandler,
|
||||
PreToolHook,
|
||||
PostToolHook,
|
||||
ErrorToolHook,
|
||||
SkipToolHook,
|
||||
)
|
||||
from app.agents.tools.hooks.manager import (
|
||||
HookManager,
|
||||
get_hook_manager,
|
||||
reset_hook_manager,
|
||||
)
|
||||
from app.agents.tools.hooks.executor import (
|
||||
HookExecutor,
|
||||
get_hook_executor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
"HookType",
|
||||
"HookStage",
|
||||
"HookTrigger",
|
||||
"HookDefinition",
|
||||
"HookResult",
|
||||
"ExecutionContext",
|
||||
"HookHandler",
|
||||
"PreToolHook",
|
||||
"PostToolHook",
|
||||
"ErrorToolHook",
|
||||
"SkipToolHook",
|
||||
# Manager
|
||||
"HookManager",
|
||||
"get_hook_manager",
|
||||
"reset_hook_manager",
|
||||
# Executor
|
||||
"HookExecutor",
|
||||
"get_hook_executor",
|
||||
]
|
||||
170
backend/app/agents/tools/hooks/executor.py
Normal file
170
backend/app/agents/tools/hooks/executor.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Hook 执行器 - Phase 6.2
|
||||
|
||||
执行 Hook 拦截逻辑。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.manager import get_hook_manager
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
)
|
||||
|
||||
|
||||
class HookExecutor:
|
||||
"""Hook 执行器
|
||||
|
||||
负责在工具执行前后执行 Hook 逻辑。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._manager = get_hook_manager()
|
||||
|
||||
async def execute_pre_hooks(
|
||||
self, context: ExecutionContext
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""执行 pre-tool Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
(是否继续执行, 修改后的输入)
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name)
|
||||
modified_input = context.tool_input
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
# 调用 hook handler
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context)
|
||||
if result and not result.continue_execution:
|
||||
# Hook 决定中断执行
|
||||
return False, modified_input
|
||||
if result.modified_input is not None:
|
||||
modified_input = result.modified_input
|
||||
except Exception as e:
|
||||
# Hook 出错,默认继续执行
|
||||
pass
|
||||
|
||||
return True, modified_input
|
||||
|
||||
async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any:
|
||||
"""执行 post-tool Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
result: 工具执行结果
|
||||
|
||||
Returns:
|
||||
修改后的结果
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name)
|
||||
modified_result = result
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
hook_result = await self._call_hook(handler, context, modified_result)
|
||||
if hook_result and hook_result.modified_output is not None:
|
||||
modified_result = hook_result.modified_output
|
||||
except Exception:
|
||||
# Hook 出错,默认保留原结果
|
||||
pass
|
||||
|
||||
return modified_result
|
||||
|
||||
async def execute_error_hooks(
|
||||
self, context: ExecutionContext, error: Exception
|
||||
) -> HookResult | None:
|
||||
"""执行 error Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
error: 异常
|
||||
|
||||
Returns:
|
||||
Hook 结果,如果返回 None 则继续传播错误
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name)
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context, error)
|
||||
if result is not None and result.continue_execution:
|
||||
return result
|
||||
except Exception:
|
||||
# Hook 出错,继续执行其他 error hooks
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def execute_skip_check(self, context: ExecutionContext) -> bool:
|
||||
"""检查是否应跳过工具执行
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
True 表示跳过,False 表示执行
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name)
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context)
|
||||
if result is not None and isinstance(result, bool):
|
||||
return result
|
||||
except Exception:
|
||||
# Hook 出错,默认不跳过
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
async def _call_hook(
|
||||
self, handler: Any, context: ExecutionContext, *args: Any
|
||||
) -> HookResult | None:
|
||||
"""调用 Hook 处理函数
|
||||
|
||||
Args:
|
||||
handler: Hook 处理函数
|
||||
context: 执行上下文
|
||||
*args: 额外参数
|
||||
|
||||
Returns:
|
||||
Hook 结果
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 如果是普通函数,直接调用
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
return await handler(context, *args)
|
||||
else:
|
||||
return handler(context, *args)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_executor: HookExecutor | None = None
|
||||
|
||||
|
||||
def get_hook_executor() -> HookExecutor:
|
||||
"""获取全局 Hook 执行器
|
||||
|
||||
Returns:
|
||||
全局 HookExecutor 实例
|
||||
"""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = HookExecutor()
|
||||
return _executor
|
||||
174
backend/app/agents/tools/hooks/manager.py
Normal file
174
backend/app/agents/tools/hooks/manager.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Hook 管理器 - Phase 6.2
|
||||
|
||||
管理 Hook 的注册、查找和配置。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookTrigger,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
)
|
||||
|
||||
|
||||
class HookManager:
|
||||
"""Hook 管理器
|
||||
|
||||
管理全局 Hook 的注册和配置。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks: dict[HookType, list[HookDefinition]] = {
|
||||
HookType.PRE_TOOL_USE: [],
|
||||
HookType.POST_TOOL_USE: [],
|
||||
HookType.TOOL_ERROR: [],
|
||||
HookType.TOOL_SKIP: [],
|
||||
}
|
||||
self._global_hooks: list[HookDefinition] = [] # 全局 Hook(对所有工具生效)
|
||||
|
||||
def register(self, definition: HookDefinition) -> None:
|
||||
"""注册 Hook
|
||||
|
||||
Args:
|
||||
definition: Hook 定义
|
||||
"""
|
||||
if definition.trigger.tool_names is None and definition.trigger.categories is None:
|
||||
# 全局 Hook
|
||||
self._global_hooks.append(definition)
|
||||
else:
|
||||
# 特定工具 Hook
|
||||
self._hooks[definition.hook_type].append(definition)
|
||||
|
||||
# 按优先级排序
|
||||
self._hooks[definition.hook_type].sort(key=lambda h: h.priority, reverse=True)
|
||||
self._global_hooks.sort(key=lambda h: h.priority, reverse=True)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功注销
|
||||
"""
|
||||
# 从特定工具 Hook 中移除
|
||||
for hooks in self._hooks.values():
|
||||
for i, hook in enumerate(hooks):
|
||||
if hook.name == name:
|
||||
hooks.pop(i)
|
||||
return True
|
||||
|
||||
# 从全局 Hook 中移除
|
||||
for i, hook in enumerate(self._global_hooks):
|
||||
if hook.name == name:
|
||||
self._global_hooks.pop(i)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_hooks(self, hook_type: HookType, tool_name: str | None = None) -> list[HookDefinition]:
|
||||
"""获取指定类型和工具的 Hook
|
||||
|
||||
Args:
|
||||
hook_type: Hook 类型
|
||||
tool_name: 工具名称(可选)
|
||||
|
||||
Returns:
|
||||
匹配的 Hook 列表
|
||||
"""
|
||||
result: list[HookDefinition] = []
|
||||
|
||||
# 添加全局 Hook
|
||||
for hook in self._global_hooks:
|
||||
if hook.hook_type == hook_type and hook.enabled:
|
||||
result.append(hook)
|
||||
|
||||
# 添加特定工具 Hook
|
||||
for hook in self._hooks[hook_type]:
|
||||
if not hook.enabled:
|
||||
continue
|
||||
|
||||
if hook.trigger.tool_names is None and hook.trigger.categories is None:
|
||||
continue
|
||||
|
||||
# 检查是否匹配
|
||||
if hook.trigger.tool_names and tool_name not in hook.trigger.tool_names:
|
||||
continue
|
||||
|
||||
result.append(hook)
|
||||
|
||||
return result
|
||||
|
||||
def list_all(self) -> list[HookDefinition]:
|
||||
"""列出所有已注册的 Hook
|
||||
|
||||
Returns:
|
||||
Hook 列表
|
||||
"""
|
||||
all_hooks = list(self._global_hooks)
|
||||
for hooks in self._hooks.values():
|
||||
all_hooks.extend(hooks)
|
||||
return all_hooks
|
||||
|
||||
def enable(self, name: str) -> bool:
|
||||
"""启用 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功启用
|
||||
"""
|
||||
for hook in self.list_all():
|
||||
if hook.name == name:
|
||||
hook.enabled = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable(self, name: str) -> bool:
|
||||
"""禁用 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功禁用
|
||||
"""
|
||||
for hook in self.list_all():
|
||||
if hook.name == name:
|
||||
hook.enabled = False
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有 Hook"""
|
||||
self._hooks = {ht: [] for ht in HookType}
|
||||
self._global_hooks = []
|
||||
|
||||
|
||||
# 全局单例
|
||||
_global_hook_manager: HookManager | None = None
|
||||
|
||||
|
||||
def get_hook_manager() -> HookManager:
|
||||
"""获取全局 Hook 管理器
|
||||
|
||||
Returns:
|
||||
全局 HookManager 实例
|
||||
"""
|
||||
global _global_hook_manager
|
||||
if _global_hook_manager is None:
|
||||
_global_hook_manager = HookManager()
|
||||
return _global_hook_manager
|
||||
|
||||
|
||||
def reset_hook_manager() -> None:
|
||||
"""重置全局 Hook 管理器(用于测试)"""
|
||||
global _global_hook_manager
|
||||
if _global_hook_manager is not None:
|
||||
_global_hook_manager.clear()
|
||||
_global_hook_manager = None
|
||||
90
backend/app/agents/tools/hooks/types.py
Normal file
90
backend/app/agents/tools/hooks/types.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Hook 类型定义 - Phase 6.2
|
||||
|
||||
Hook 拦截系统类型定义。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class HookType(Enum):
|
||||
"""Hook 类型"""
|
||||
|
||||
PRE_TOOL_USE = "pre_tool_use" # 工具执行前
|
||||
POST_TOOL_USE = "post_tool_use" # 工具执行后
|
||||
TOOL_ERROR = "tool_error" # 工具执行出错
|
||||
TOOL_SKIP = "tool_skip" # 工具跳过(条件执行)
|
||||
|
||||
|
||||
class HookStage(Enum):
|
||||
"""Hook 执行阶段"""
|
||||
|
||||
BEFORE = "before"
|
||||
AFTER = "after"
|
||||
ON_ERROR = "on_error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookTrigger:
|
||||
"""Hook 触发条件"""
|
||||
|
||||
tool_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||
categories: list[str] | None = None # 只对特定类别生效
|
||||
conditions: dict[str, Any] | None = None # 自定义条件
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookDefinition:
|
||||
"""Hook 定义"""
|
||||
|
||||
name: str
|
||||
hook_type: HookType
|
||||
trigger: HookTrigger
|
||||
handler: Callable[..., Any] # Hook 处理函数
|
||||
priority: int = 0 # 优先级,数字越大越先执行
|
||||
enabled: bool = True
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""Hook 执行结果"""
|
||||
|
||||
hook_name: str
|
||||
success: bool
|
||||
continue_execution: bool = True # False 表示中断执行
|
||||
modified_input: Any = None # 修改后的输入
|
||||
modified_output: Any = None # 修改后的输出
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""工具执行上下文"""
|
||||
|
||||
tool_name: str
|
||||
tool_input: dict[str, Any]
|
||||
user_id: str | None = None
|
||||
session_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 执行结果(由 HookExecutor 填充)
|
||||
result: Any = None
|
||||
error: Exception | None = None
|
||||
start_time: float | None = None
|
||||
end_time: float | None = None
|
||||
|
||||
|
||||
# Hook 处理函数类型
|
||||
HookHandler = Callable[[ExecutionContext, HookDefinition], HookResult]
|
||||
|
||||
# Pre-hook: 在工具执行前调用,可以修改输入或决定是否跳过
|
||||
PreToolHook = Callable[[ExecutionContext], tuple[bool, dict[str, Any] | None]]
|
||||
# post-hook: 在工具执行后调用,可以修改输出
|
||||
PostToolHook = Callable[[ExecutionContext, Any], Any]
|
||||
# Error hook: 在工具出错时调用
|
||||
ErrorToolHook = Callable[[ExecutionContext, Exception], HookResult | None]
|
||||
# Skip hook: 决定是否跳过工具执行
|
||||
SkipToolHook = Callable[[ExecutionContext], bool]
|
||||
77
backend/app/agents/tools/manifest.py
Normal file
77
backend/app/agents/tools/manifest.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""工具元数据和数据类型定义"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolCategory(Enum):
|
||||
"""工具类别"""
|
||||
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXTERNAL = "external"
|
||||
DB_WRITE = "db_write"
|
||||
NETWORK = "network"
|
||||
|
||||
|
||||
class SideEffectScope(Enum):
|
||||
"""副作用范围"""
|
||||
|
||||
NONE = "none"
|
||||
LOCAL_STATE = "local_state"
|
||||
DB_WRITE = "db_write"
|
||||
NETWORK = "network"
|
||||
|
||||
|
||||
class PermissionClass(Enum):
|
||||
"""权限级别"""
|
||||
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXTERNAL = "external"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolManifest:
|
||||
"""工具元数据"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
category: ToolCategory
|
||||
parameters: dict[str, Any] # JSON Schema
|
||||
return_schema: dict[str, Any]
|
||||
permission_class: PermissionClass
|
||||
side_effect_scope: SideEffectScope
|
||||
requires_confirmation: bool = False
|
||||
is_streaming: bool = False
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"category": self.category.value,
|
||||
"parameters": self.parameters,
|
||||
"return_schema": self.return_schema,
|
||||
"permission_class": self.permission_class.value,
|
||||
"side_effect_scope": self.side_effect_scope.value,
|
||||
"requires_confirmation": self.requires_confirmation,
|
||||
"is_streaming": self.is_streaming,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookConfig:
|
||||
"""Hook 配置"""
|
||||
|
||||
name: str
|
||||
hook_type: str # "pre_tool_use", "post_tool_use", "tool_error", "tool_skip"
|
||||
filter_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||
|
||||
def matches_tool(self, tool_name: str) -> bool:
|
||||
"""检查 Hook 是否对指定工具生效"""
|
||||
if self.filter_names is None:
|
||||
return True
|
||||
return tool_name in self.filter_names
|
||||
251
backend/app/agents/tools/migration.py
Normal file
251
backend/app/agents/tools/migration.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""工具迁移和向后兼容层 - Phase 6.1
|
||||
|
||||
将现有 @tool 装饰的工具迁移到 ToolRegistry,同时保持向后兼容。
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
from app.agents.tools.registry import get_tool_registry
|
||||
|
||||
|
||||
# 现有工具的类别映射
|
||||
_TOOL_CATEGORY_MAP: dict[str, tuple[ToolCategory, PermissionClass, SideEffectScope]] = {
|
||||
# 知识检索 - 只读
|
||||
"search_knowledge": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"get_knowledge_graph_context": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"hybrid_search": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"web_search": (ToolCategory.NETWORK, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||
# 知识构建 - 写入
|
||||
"build_knowledge_graph": (
|
||||
ToolCategory.WRITE,
|
||||
PermissionClass.WRITE,
|
||||
SideEffectScope.LOCAL_STATE,
|
||||
),
|
||||
# 任务工具
|
||||
"get_tasks": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_task": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"update_task_status": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
# 日程工具
|
||||
"get_schedule_day": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_todo": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"create_schedule_task": (
|
||||
ToolCategory.WRITE,
|
||||
PermissionClass.WRITE,
|
||||
SideEffectScope.LOCAL_STATE,
|
||||
),
|
||||
"create_reminder": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"create_goal": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"resolve_time_expression": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
# 论坛工具
|
||||
"get_forum_posts": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_forum_post": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"scan_forum_for_instructions": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
}
|
||||
|
||||
|
||||
def get_tool_category(name: str) -> tuple[ToolCategory, PermissionClass, SideEffectScope]:
|
||||
"""获取工具的类别信息"""
|
||||
return _TOOL_CATEGORY_MAP.get(
|
||||
name,
|
||||
(ToolCategory.EXTERNAL, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||
)
|
||||
|
||||
|
||||
def infer_tags_from_docstring(docstring: str | None) -> list[str]:
|
||||
"""从 docstring 推断工具标签"""
|
||||
if not docstring:
|
||||
return []
|
||||
tags = []
|
||||
doc_lower = docstring.lower()
|
||||
if "搜索" in docstring or "查询" in docstring or "search" in doc_lower:
|
||||
tags.append("search")
|
||||
if "创建" in docstring or "新建" in docstring or "create" in doc_lower:
|
||||
tags.append("create")
|
||||
if "获取" in docstring or "读取" in docstring or "get" in doc_lower:
|
||||
tags.append("read")
|
||||
if "更新" in docstring or "修改" in docstring or "update" in doc_lower:
|
||||
tags.append("update")
|
||||
return tags
|
||||
|
||||
|
||||
def migrate_tool(tool_func: Callable) -> Callable:
|
||||
"""将现有 @tool 装饰的函数迁移到 ToolRegistry
|
||||
|
||||
Args:
|
||||
tool_func: LangChain @tool 装饰的函数
|
||||
|
||||
Returns:
|
||||
原函数(已注册到 registry)
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
|
||||
# 如果已经注册,跳过
|
||||
if registry.get(tool_func.name):
|
||||
return tool_func
|
||||
|
||||
# 获取类别信息
|
||||
category, permission, side_effect = get_tool_category(tool_func.name)
|
||||
|
||||
# 从 docstring 提取 description
|
||||
description = tool_func.description if hasattr(tool_func, "description") else ""
|
||||
|
||||
# 推断 tags
|
||||
tags = infer_tags_from_docstring(description)
|
||||
tags.append("migrated")
|
||||
|
||||
# 创建 manifest
|
||||
manifest = ToolManifest(
|
||||
name=tool_func.name,
|
||||
description=description,
|
||||
category=category,
|
||||
parameters={}, # LangChain @tool 动态处理参数
|
||||
return_schema={},
|
||||
permission_class=permission,
|
||||
side_effect_scope=side_effect,
|
||||
requires_confirmation=side_effect != SideEffectScope.NONE,
|
||||
is_streaming=False,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# 注册到 registry
|
||||
registry.register(manifest, tool_func)
|
||||
|
||||
return tool_func
|
||||
|
||||
|
||||
def migrate_all_tools() -> int:
|
||||
"""迁移所有现有工具到 ToolRegistry
|
||||
|
||||
Returns:
|
||||
迁移的工具数量
|
||||
"""
|
||||
from app.agents.tools import (
|
||||
ALL_TOOLS,
|
||||
KNOWLEDGE_GRAPH_TOOLS,
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
SCHEDULE_READ_TOOLS,
|
||||
SCHEDULE_WRITE_TOOLS,
|
||||
TASK_TOOLS,
|
||||
FORUM_TOOLS,
|
||||
)
|
||||
|
||||
all_tools = (
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||
+ KNOWLEDGE_GRAPH_TOOLS
|
||||
+ TASK_TOOLS
|
||||
+ SCHEDULE_READ_TOOLS
|
||||
+ SCHEDULE_WRITE_TOOLS
|
||||
+ FORUM_TOOLS
|
||||
)
|
||||
|
||||
count = 0
|
||||
for tool in all_tools:
|
||||
try:
|
||||
migrate_tool(tool)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
print(f"Failed to migrate tool {getattr(tool, 'name', 'unknown')}: {e}")
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class BackwardCompatTool:
|
||||
"""向后兼容工具包装器
|
||||
|
||||
确保现有代码通过 registry.get_executor() 仍能正常调用工具。
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self._registry = get_tool_registry()
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
executor = self._registry.get_executor(self.name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
def invoke(self, tool_input: dict[str, Any]) -> Any:
|
||||
"""LangChain 风格的 invoke 调用"""
|
||||
executor = self._registry.get_executor(self.name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||
|
||||
# 处理位置参数
|
||||
if isinstance(tool_input, dict):
|
||||
return executor(**tool_input)
|
||||
return executor(tool_input)
|
||||
|
||||
|
||||
def create_compat_layer() -> dict[str, BackwardCompatTool]:
|
||||
"""创建向后兼容层
|
||||
|
||||
返回一个字典,允许通过名称访问兼容的工具包装器。
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
tools = registry.list_all()
|
||||
|
||||
return {tool.name: BackwardCompatTool(tool.name) for tool in tools}
|
||||
|
||||
|
||||
# 自动迁移装饰器
|
||||
def auto_migrate(func: Callable) -> Callable:
|
||||
"""自动迁移装饰器
|
||||
|
||||
用于装饰新的 @tool 函数,自动注册到 registry。
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 迁移到 registry
|
||||
migrate_tool(wrapper)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# 便捷函数:获取兼容的工具执行器
|
||||
def get_tool_executor(name: str) -> Callable | None:
|
||||
"""获取工具执行器(兼容层)
|
||||
|
||||
优先从 registry 获取,fallback 到直接导入。
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
executor = registry.get_executor(name)
|
||||
|
||||
if executor is not None:
|
||||
return executor
|
||||
|
||||
# Fallback: 直接从模块导入(仅用于迁移期间)
|
||||
try:
|
||||
from app.agents.tools import (
|
||||
TASK_TOOLS,
|
||||
SCHEDULE_READ_TOOLS,
|
||||
SCHEDULE_WRITE_TOOLS,
|
||||
FORUM_TOOLS,
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
)
|
||||
|
||||
all_tools = (
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||
+ TASK_TOOLS
|
||||
+ SCHEDULE_READ_TOOLS
|
||||
+ SCHEDULE_WRITE_TOOLS
|
||||
+ FORUM_TOOLS
|
||||
)
|
||||
|
||||
for tool in all_tools:
|
||||
if hasattr(tool, "name") and tool.name == name:
|
||||
return tool
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return None
|
||||
206
backend/app/agents/tools/registry.py
Normal file
206
backend/app/agents/tools/registry.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""工具注册表 - 工具系统重构 Phase 6.1"""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.agents.tools.manifest import HookConfig, ToolManifest
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表
|
||||
|
||||
统一管理所有工具的注册、发现和调用。
|
||||
支持工具元数据、权限分类、Hook 拦截。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, ToolManifest] = {}
|
||||
self._executors: dict[str, Callable] = {}
|
||||
self._hooks: dict[str, list[HookConfig]] = defaultdict(list)
|
||||
|
||||
def register(
|
||||
self, manifest: ToolManifest, executor: Callable, hooks: list[HookConfig] | None = None
|
||||
) -> None:
|
||||
"""注册工具
|
||||
|
||||
Args:
|
||||
manifest: 工具元数据
|
||||
executor: 工具执行函数
|
||||
hooks: 可选的 Hook 配置列表
|
||||
"""
|
||||
if manifest.name in self._tools:
|
||||
raise ValueError(f"Tool already registered: {manifest.name}")
|
||||
|
||||
self._tools[manifest.name] = manifest
|
||||
self._executors[manifest.name] = executor
|
||||
|
||||
if hooks:
|
||||
for hook in hooks:
|
||||
self._hooks[manifest.name].append(hook)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销工具
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
是否成功注销
|
||||
"""
|
||||
if name not in self._tools:
|
||||
return False
|
||||
|
||||
del self._tools[name]
|
||||
del self._executors[name]
|
||||
if name in self._hooks:
|
||||
del self._hooks[name]
|
||||
return True
|
||||
|
||||
def get(self, name: str) -> ToolManifest | None:
|
||||
"""获取工具元数据
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具元数据,不存在返回 None
|
||||
"""
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_executor(self, name: str) -> Callable | None:
|
||||
"""获取工具执行器
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具执行函数,不存在返回 None
|
||||
"""
|
||||
return self._executors.get(name)
|
||||
|
||||
def get_hooks(self, name: str) -> list[HookConfig]:
|
||||
"""获取工具的 Hook 配置
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
Hook 配置列表
|
||||
"""
|
||||
return self._hooks.get(name, [])
|
||||
|
||||
def list_all(self) -> list[ToolManifest]:
|
||||
"""列出所有已注册的工具
|
||||
|
||||
Returns:
|
||||
工具元数据列表
|
||||
"""
|
||||
return list(self._tools.values())
|
||||
|
||||
def list_by_category(self, category: Any) -> list[ToolManifest]:
|
||||
"""按类别列出工具
|
||||
|
||||
Args:
|
||||
category: 工具类别
|
||||
|
||||
Returns:
|
||||
该类别下的所有工具
|
||||
"""
|
||||
return [t for t in self._tools.values() if t.category == category]
|
||||
|
||||
def list_by_permission(self, permission: Any) -> list[ToolManifest]:
|
||||
"""按权限级别列出工具
|
||||
|
||||
Args:
|
||||
permission: 权限级别
|
||||
|
||||
Returns:
|
||||
该权限级别下的所有工具
|
||||
"""
|
||||
return [t for t in self._tools.values() if t.permission_class == permission]
|
||||
|
||||
def search_by_tag(self, tag: str) -> list[ToolManifest]:
|
||||
"""按标签搜索工具
|
||||
|
||||
Args:
|
||||
tag: 标签
|
||||
|
||||
Returns:
|
||||
包含该标签的工具
|
||||
"""
|
||||
return [t for t in self._tools.values() if tag in t.tags]
|
||||
|
||||
def search_by_name(self, keyword: str) -> list[ToolManifest]:
|
||||
"""按名称关键词搜索工具
|
||||
|
||||
Args:
|
||||
keyword: 关键词
|
||||
|
||||
Returns:
|
||||
名称包含关键词的工具
|
||||
"""
|
||||
keyword = keyword.lower()
|
||||
return [t for t in self._tools.values() if keyword in t.name.lower()]
|
||||
|
||||
def get_requires_confirmation(self, name: str) -> bool:
|
||||
"""检查工具是否需要确认
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
是否需要确认
|
||||
"""
|
||||
manifest = self._tools.get(name)
|
||||
return manifest.requires_confirmation if manifest else False
|
||||
|
||||
def get_is_streaming(self, name: str) -> bool:
|
||||
"""检查工具是否支持流式执行
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
是否支持流式
|
||||
"""
|
||||
manifest = self._tools.get(name)
|
||||
return manifest.is_streaming if manifest else False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空注册表"""
|
||||
self._tools.clear()
|
||||
self._executors.clear()
|
||||
self._hooks.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._tools)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self._tools
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._tools.values())
|
||||
|
||||
|
||||
# 全局单例实例
|
||||
_global_registry: ToolRegistry | None = None
|
||||
|
||||
|
||||
def get_tool_registry() -> ToolRegistry:
|
||||
"""获取全局工具注册表单例
|
||||
|
||||
Returns:
|
||||
全局 ToolRegistry 实例
|
||||
"""
|
||||
global _global_registry
|
||||
if _global_registry is None:
|
||||
_global_registry = ToolRegistry()
|
||||
return _global_registry
|
||||
|
||||
|
||||
def reset_tool_registry() -> None:
|
||||
"""重置全局工具注册表(用于测试)"""
|
||||
global _global_registry
|
||||
if _global_registry is not None:
|
||||
_global_registry.clear()
|
||||
_global_registry = None
|
||||
210
backend/app/agents/tools/streaming.py
Normal file
210
backend/app/agents/tools/streaming.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""流式工具执行器 - Phase 6.3
|
||||
|
||||
支持流式输出的工具执行器。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from app.agents.tools.hooks.executor import get_hook_executor
|
||||
from app.agents.tools.hooks.types import ExecutionContext
|
||||
from app.agents.tools.registry import get_tool_registry
|
||||
|
||||
|
||||
class StreamingToolExecutor:
|
||||
"""流式工具执行器
|
||||
|
||||
支持:
|
||||
- 普通工具的同步/异步执行
|
||||
- 流式工具的流式输出
|
||||
- Hook 拦截(pre/post/error)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._registry = get_tool_registry()
|
||||
self._hook_executor = get_hook_executor()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
) -> Any:
|
||||
"""执行工具(非流式)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_input: 工具输入参数
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 获取工具和执行器
|
||||
manifest = self._registry.get(tool_name)
|
||||
if manifest is None:
|
||||
raise ValueError(f"Tool not found: {tool_name}")
|
||||
|
||||
executor = self._registry.get_executor(tool_name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||
|
||||
# 检查是否跳过
|
||||
if await self._hook_executor.execute_skip_check(context):
|
||||
return {"skipped": True, "tool": tool_name}
|
||||
|
||||
# 执行 pre-hooks
|
||||
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||
if not continue_execution:
|
||||
return {"pre_hook_aborted": True, "tool": tool_name}
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
context.start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 判断是同步还是异步执行
|
||||
if asyncio.iscoroutinefunction(executor):
|
||||
result = await executor(**modified_input)
|
||||
else:
|
||||
# 同步函数在线程池中执行
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(None, lambda: executor(**modified_input))
|
||||
|
||||
context.result = result
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 执行 post-hooks
|
||||
result = await self._hook_executor.execute_post_hooks(context, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
context.error = e
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 执行 error-hooks
|
||||
error_result = await self._hook_executor.execute_error_hooks(context, e)
|
||||
if error_result is not None:
|
||||
return error_result
|
||||
|
||||
raise
|
||||
|
||||
async def execute_streaming(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""执行流式工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_input: 工具输入参数
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Yields:
|
||||
流式输出片段
|
||||
"""
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 获取工具和执行器
|
||||
manifest = self._registry.get(tool_name)
|
||||
if manifest is None:
|
||||
raise ValueError(f"Tool not found: {tool_name}")
|
||||
|
||||
if not manifest.is_streaming:
|
||||
raise ValueError(f"Tool is not streaming: {tool_name}")
|
||||
|
||||
executor = self._registry.get_executor(tool_name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||
|
||||
# 检查是否跳过
|
||||
if await self._hook_executor.execute_skip_check(context):
|
||||
yield {"type": "skipped", "tool": tool_name}
|
||||
return
|
||||
|
||||
# 执行 pre-hooks
|
||||
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||
if not continue_execution:
|
||||
yield {"type": "pre_hook_aborted", "tool": tool_name}
|
||||
return
|
||||
|
||||
# 执行流式工具
|
||||
try:
|
||||
context.start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 调用执行器(应该返回 AsyncGenerator)
|
||||
result = executor(**modified_input)
|
||||
|
||||
# 如果是 async generator
|
||||
if asyncio.isasyncgen(result):
|
||||
async for chunk in result:
|
||||
yield {"type": "chunk", "data": chunk}
|
||||
else:
|
||||
# 普通协程
|
||||
data = await result
|
||||
yield {"type": "chunk", "data": data}
|
||||
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
yield {"type": "done", "tool": tool_name}
|
||||
|
||||
except Exception as e:
|
||||
context.error = e
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
yield {"type": "error", "error": str(e), "tool": tool_name}
|
||||
|
||||
# 执行 error-hooks
|
||||
await self._hook_executor.execute_error_hooks(context, e)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
user_id: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""批量执行工具
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Returns:
|
||||
执行结果列表
|
||||
"""
|
||||
tasks = []
|
||||
for call in tool_calls:
|
||||
tool_name = call.get("tool_name") or call.get("name")
|
||||
tool_input = call.get("tool_input") or call.get("input") or {}
|
||||
tasks.append(self.execute(tool_name, tool_input, user_id))
|
||||
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_executor: StreamingToolExecutor | None = None
|
||||
|
||||
|
||||
def get_streaming_executor() -> StreamingToolExecutor:
|
||||
"""获取全局流式执行器
|
||||
|
||||
Returns:
|
||||
全局 StreamingToolExecutor 实例
|
||||
"""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = StreamingToolExecutor()
|
||||
return _executor
|
||||
324
backend/tests/backend/app/agents/test_tools_registry.py
Normal file
324
backend/tests/backend/app/agents/test_tools_registry.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""ToolRegistry 单元测试"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.agents.tools.manifest import (
|
||||
HookConfig,
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
from app.agents.tools.registry import (
|
||||
ToolRegistry,
|
||||
get_tool_registry,
|
||||
reset_tool_registry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry():
|
||||
"""创建空的 ToolRegistry 实例"""
|
||||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_manifest():
|
||||
"""创建示例工具元数据"""
|
||||
return ToolManifest(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
category=ToolCategory.READ,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string"}},
|
||||
"required": ["input"],
|
||||
},
|
||||
return_schema={"type": "string"},
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
requires_confirmation=False,
|
||||
is_streaming=False,
|
||||
tags=["test", "sample"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_executor():
|
||||
"""创建示例执行器"""
|
||||
|
||||
def executor(input: str) -> str:
|
||||
return f"processed: {input}"
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
class TestToolRegistryInit:
|
||||
"""测试 ToolRegistry 初始化"""
|
||||
|
||||
def test_empty_registry(self, registry):
|
||||
assert len(registry) == 0
|
||||
assert list(registry) == []
|
||||
|
||||
def test_empty_contains_false(self, registry):
|
||||
assert "test_tool" not in registry
|
||||
|
||||
|
||||
class TestToolRegistryRegister:
|
||||
"""测试工具注册"""
|
||||
|
||||
def test_register_single_tool(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
assert len(registry) == 1
|
||||
assert "test_tool" in registry
|
||||
|
||||
def test_register_duplicate_raises(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
with pytest.raises(ValueError, match="Tool already registered"):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
def test_register_with_hooks(self, registry, sample_manifest, sample_executor):
|
||||
hooks = [
|
||||
HookConfig(name="audit", hook_type="pre_tool_use"),
|
||||
HookConfig(name="security", hook_type="post_tool_use"),
|
||||
]
|
||||
registry.register(sample_manifest, sample_executor, hooks=hooks)
|
||||
|
||||
tool_hooks = registry.get_hooks("test_tool")
|
||||
assert len(tool_hooks) == 2
|
||||
|
||||
|
||||
class TestToolRegistryGet:
|
||||
"""测试获取工具"""
|
||||
|
||||
def test_get_existing_tool(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
manifest = registry.get("test_tool")
|
||||
assert manifest is not None
|
||||
assert manifest.name == "test_tool"
|
||||
|
||||
def test_get_nonexistent_tool(self, registry):
|
||||
assert registry.get("nonexistent") is None
|
||||
|
||||
def test_get_executor(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
executor = registry.get_executor("test_tool")
|
||||
assert executor is not None
|
||||
assert executor("hello") == "processed: hello"
|
||||
|
||||
def test_get_nonexistent_executor(self, registry):
|
||||
assert registry.get_executor("nonexistent") is None
|
||||
|
||||
|
||||
class TestToolRegistryList:
|
||||
"""测试工具列表"""
|
||||
|
||||
def test_list_all_empty(self, registry):
|
||||
assert registry.list_all() == []
|
||||
|
||||
def test_list_all_with_tools(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
manifest2 = ToolManifest(
|
||||
name="test_tool2",
|
||||
description="Another test tool",
|
||||
category=ToolCategory.WRITE,
|
||||
parameters={"type": "object"},
|
||||
return_schema={"type": "string"},
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
)
|
||||
registry.register(manifest2, sample_executor)
|
||||
|
||||
all_tools = registry.list_all()
|
||||
assert len(all_tools) == 2
|
||||
|
||||
def test_list_by_category(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
manifest2 = ToolManifest(
|
||||
name="write_tool",
|
||||
description="A write tool",
|
||||
category=ToolCategory.WRITE,
|
||||
parameters={"type": "object"},
|
||||
return_schema={"type": "string"},
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
)
|
||||
registry.register(manifest2, sample_executor)
|
||||
|
||||
read_tools = registry.list_by_category(ToolCategory.READ)
|
||||
write_tools = registry.list_by_category(ToolCategory.WRITE)
|
||||
|
||||
assert len(read_tools) == 1
|
||||
assert len(write_tools) == 1
|
||||
|
||||
def test_list_by_permission(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
read_tools = registry.list_by_permission(PermissionClass.READ)
|
||||
write_tools = registry.list_by_permission(PermissionClass.WRITE)
|
||||
|
||||
assert len(read_tools) == 1
|
||||
assert len(write_tools) == 0
|
||||
|
||||
|
||||
class TestToolRegistrySearch:
|
||||
"""测试工具搜索"""
|
||||
|
||||
def test_search_by_tag(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
results = registry.search_by_tag("test")
|
||||
assert len(results) == 1
|
||||
assert results[0].name == "test_tool"
|
||||
|
||||
results = registry.search_by_tag("nonexistent")
|
||||
assert len(results) == 0
|
||||
|
||||
def test_search_by_name(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
results = registry.search_by_name("test")
|
||||
assert len(results) == 1
|
||||
|
||||
results = registry.search_by_name("tool")
|
||||
assert len(results) == 1
|
||||
|
||||
results = registry.search_by_name("xyz")
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
class TestToolRegistryUtility:
|
||||
"""测试工具方法"""
|
||||
|
||||
def test_requires_confirmation_true(self, registry, sample_manifest, sample_executor):
|
||||
sample_manifest.requires_confirmation = True
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
assert registry.get_requires_confirmation("test_tool") is True
|
||||
|
||||
def test_requires_confirmation_false(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
assert registry.get_requires_confirmation("test_tool") is False
|
||||
|
||||
def test_requires_confirmation_nonexistent(self, registry):
|
||||
assert registry.get_requires_confirmation("nonexistent") is False
|
||||
|
||||
def test_is_streaming_true(self, registry, sample_manifest, sample_executor):
|
||||
sample_manifest.is_streaming = True
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
assert registry.get_is_streaming("test_tool") is True
|
||||
|
||||
def test_is_streaming_false(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
assert registry.get_is_streaming("test_tool") is False
|
||||
|
||||
|
||||
class TestToolRegistryUnregister:
|
||||
"""测试工具注销"""
|
||||
|
||||
def test_unregister_existing(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
assert registry.unregister("test_tool") is True
|
||||
assert len(registry) == 0
|
||||
assert "test_tool" not in registry
|
||||
|
||||
def test_unregister_nonexistent(self, registry):
|
||||
assert registry.unregister("nonexistent") is False
|
||||
|
||||
def test_unregister_clears_hooks(self, registry, sample_manifest, sample_executor):
|
||||
hooks = [HookConfig(name="test", hook_type="pre_tool_use")]
|
||||
registry.register(sample_manifest, sample_executor, hooks=hooks)
|
||||
|
||||
registry.unregister("test_tool")
|
||||
|
||||
assert registry.get_hooks("test_tool") == []
|
||||
|
||||
|
||||
class TestToolRegistryClear:
|
||||
"""测试清空注册表"""
|
||||
|
||||
def test_clear(self, registry, sample_manifest, sample_executor):
|
||||
registry.register(sample_manifest, sample_executor)
|
||||
|
||||
manifest2 = ToolManifest(
|
||||
name="tool2",
|
||||
description="Another tool",
|
||||
category=ToolCategory.READ,
|
||||
parameters={"type": "object"},
|
||||
return_schema={"type": "string"},
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
requires_confirmation=False,
|
||||
)
|
||||
registry.register(manifest2, sample_executor)
|
||||
|
||||
registry.clear()
|
||||
|
||||
assert len(registry) == 0
|
||||
|
||||
|
||||
class TestGlobalRegistry:
|
||||
"""测试全局注册表"""
|
||||
|
||||
def test_get_tool_registry_returns_singleton(self):
|
||||
reset_tool_registry()
|
||||
|
||||
reg1 = get_tool_registry()
|
||||
reg2 = get_tool_registry()
|
||||
|
||||
assert reg1 is reg2
|
||||
|
||||
def test_reset_tool_registry(self):
|
||||
reset_tool_registry()
|
||||
reg1 = get_tool_registry()
|
||||
|
||||
reset_tool_registry()
|
||||
reg2 = get_tool_registry()
|
||||
|
||||
# After reset, should get a new instance
|
||||
# (Note: this tests the reset function works)
|
||||
|
||||
|
||||
class TestToolManifest:
|
||||
"""测试 ToolManifest"""
|
||||
|
||||
def test_to_dict(self, sample_manifest):
|
||||
d = sample_manifest.to_dict()
|
||||
|
||||
assert d["name"] == "test_tool"
|
||||
assert d["description"] == "A test tool"
|
||||
assert d["category"] == "read"
|
||||
assert d["permission_class"] == "read"
|
||||
assert d["side_effect_scope"] == "none"
|
||||
assert d["requires_confirmation"] is False
|
||||
assert d["is_streaming"] is False
|
||||
assert d["tags"] == ["test", "sample"]
|
||||
|
||||
|
||||
class TestHookConfig:
|
||||
"""测试 HookConfig"""
|
||||
|
||||
def test_matches_tool_with_no_filter(self):
|
||||
hook = HookConfig(name="test", hook_type="pre_tool_use")
|
||||
|
||||
assert hook.matches_tool("any_tool") is True
|
||||
|
||||
def test_matches_tool_with_filter(self):
|
||||
hook = HookConfig(name="test", hook_type="pre_tool_use", filter_names=["tool_a", "tool_b"])
|
||||
|
||||
assert hook.matches_tool("tool_a") is True
|
||||
assert hook.matches_tool("tool_b") is True
|
||||
assert hook.matches_tool("tool_c") is False
|
||||
Reference in New Issue
Block a user