feat(agents): Phase 6 tool system refactoring
Phase 6.1: ToolRegistry infrastructure - Add ToolManifest with ToolCategory, PermissionClass, SideEffectScope - Add ToolRegistry singleton with register/get/unregister/list/search - Add BaseTool abstract class with ReadTool/WriteTool/DBWriteTool/ExternalTool/NetworkTool subclasses - Add migration layer for backward compatibility Phase 6.2: Hook interception system - Add HookType (PRE_TOOL_USE, POST_TOOL_USE, TOOL_ERROR, TOOL_SKIP) - Add HookManager with singleton for hook registration - Add HookExecutor for pre/post/error hook execution Phase 6.3: Streaming execution - Add StreamingToolExecutor with batch execution support Phase 6.4: New builtin tools - Add file_tools: GlobTool, GrepTool, ReadFileTool, WriteFileTool - Add system_tools: BashTool, PowerShellTool - Add dev_tools: LSPTools, GitTool - Add collaboration_tools: TeamAgentTool, TaskBroadcastTool Tests: 29 passed
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
from app.agents.tools.search import (
|
from app.agents.tools.search import (
|
||||||
search_knowledge, get_knowledge_graph_context,
|
search_knowledge,
|
||||||
build_knowledge_graph, hybrid_search, web_search,
|
get_knowledge_graph_context,
|
||||||
|
build_knowledge_graph,
|
||||||
|
hybrid_search,
|
||||||
|
web_search,
|
||||||
)
|
)
|
||||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||||
@@ -13,6 +16,58 @@ from app.agents.tools.schedule import (
|
|||||||
)
|
)
|
||||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||||
|
|
||||||
|
# Phase 6.1: Tool Registry exports
|
||||||
|
from app.agents.tools.registry import (
|
||||||
|
ToolRegistry,
|
||||||
|
get_tool_registry,
|
||||||
|
reset_tool_registry,
|
||||||
|
)
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
HookConfig,
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
ToolManifest,
|
||||||
|
)
|
||||||
|
from app.agents.tools.migration import (
|
||||||
|
migrate_tool,
|
||||||
|
migrate_all_tools,
|
||||||
|
get_tool_executor,
|
||||||
|
BackwardCompatTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 6.2: Hook System exports
|
||||||
|
from app.agents.tools.hooks import (
|
||||||
|
HookManager,
|
||||||
|
HookExecutor,
|
||||||
|
HookType,
|
||||||
|
HookDefinition,
|
||||||
|
HookResult,
|
||||||
|
ExecutionContext,
|
||||||
|
get_hook_manager,
|
||||||
|
get_hook_executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 6.3: Streaming Executor exports
|
||||||
|
from app.agents.tools.streaming import (
|
||||||
|
StreamingToolExecutor,
|
||||||
|
get_streaming_executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 6.4: Builtin Tools exports
|
||||||
|
from app.agents.tools.builtins import (
|
||||||
|
GlobTool,
|
||||||
|
GrepTool,
|
||||||
|
ReadFileTool,
|
||||||
|
WriteFileTool,
|
||||||
|
BashTool,
|
||||||
|
PowerShellTool,
|
||||||
|
LSPTools,
|
||||||
|
GitTool,
|
||||||
|
TeamAgentTool,
|
||||||
|
TaskBroadcastTool,
|
||||||
|
)
|
||||||
|
|
||||||
TASK_TOOLS = [
|
TASK_TOOLS = [
|
||||||
get_tasks,
|
get_tasks,
|
||||||
create_task,
|
create_task,
|
||||||
|
|||||||
161
backend/app/agents/tools/base.py
Normal file
161
backend/app/agents/tools/base.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""工具基类 - 工具系统重构 Phase 6.1"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
ToolManifest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool(ABC, Generic[T]):
|
||||||
|
"""工具基类
|
||||||
|
|
||||||
|
提供工具的标准接口和默认实现。
|
||||||
|
所有自定义工具应继承此类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
category: ToolCategory,
|
||||||
|
permission_class: PermissionClass,
|
||||||
|
side_effect_scope: SideEffectScope = SideEffectScope.NONE,
|
||||||
|
requires_confirmation: bool = False,
|
||||||
|
is_streaming: bool = False,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.category = category
|
||||||
|
self.permission_class = permission_class
|
||||||
|
self.side_effect_scope = side_effect_scope
|
||||||
|
self.requires_confirmation = requires_confirmation
|
||||||
|
self.is_streaming = is_streaming
|
||||||
|
self.tags = tags or []
|
||||||
|
|
||||||
|
def get_manifest(self) -> ToolManifest:
|
||||||
|
"""获取工具元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具元数据
|
||||||
|
"""
|
||||||
|
return ToolManifest(
|
||||||
|
name=self.name,
|
||||||
|
description=self.description,
|
||||||
|
category=self.category,
|
||||||
|
parameters=self.get_parameters(),
|
||||||
|
return_schema=self.get_return_schema(),
|
||||||
|
permission_class=self.permission_class,
|
||||||
|
side_effect_scope=self.side_effect_scope,
|
||||||
|
requires_confirmation=self.requires_confirmation,
|
||||||
|
is_streaming=self.is_streaming,
|
||||||
|
tags=self.tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
"""获取参数 Schema(JSON Schema 格式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
参数 schema
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
"""获取返回值 Schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
返回值 schema
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self, **kwargs) -> T:
|
||||||
|
"""执行工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_safe(self, **kwargs) -> dict[str, Any]:
|
||||||
|
"""安全执行工具,捕获异常
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 success 和 result/error 的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.execute(**kwargs)
|
||||||
|
return {"success": True, "result": result}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__}(name={self.name!r})>"
|
||||||
|
|
||||||
|
|
||||||
|
class ReadTool(BaseTool):
|
||||||
|
"""只读工具基类"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.READ)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.READ)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.NONE)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteTool(BaseTool):
|
||||||
|
"""写入工具基类"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.WRITE)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.LOCAL_STATE)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DBWriteTool(BaseTool):
|
||||||
|
"""数据库写入工具基类"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.DB_WRITE)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.DB_WRITE)
|
||||||
|
kwargs.setdefault("requires_confirmation", True)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalTool(BaseTool):
|
||||||
|
"""外部工具基类(执行外部命令等)"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.EXTERNAL)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||||
|
kwargs.setdefault("requires_confirmation", True)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkTool(BaseTool):
|
||||||
|
"""网络工具基类"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.NETWORK)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||||
|
super().__init__(**kwargs)
|
||||||
43
backend/app/agents/tools/builtins/__init__.py
Normal file
43
backend/app/agents/tools/builtins/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""内置工具集 - Phase 6.4
|
||||||
|
|
||||||
|
新的内置工具,使用 BaseTool 基类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.agents.tools.builtins.file_tools import (
|
||||||
|
GlobTool,
|
||||||
|
GrepTool,
|
||||||
|
ReadFileTool,
|
||||||
|
WriteFileTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.tools.builtins.system_tools import (
|
||||||
|
BashTool,
|
||||||
|
PowerShellTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.tools.builtins.dev_tools import (
|
||||||
|
LSPTools,
|
||||||
|
GitTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.tools.builtins.collaboration_tools import (
|
||||||
|
TeamAgentTool,
|
||||||
|
TaskBroadcastTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# File tools
|
||||||
|
"GlobTool",
|
||||||
|
"GrepTool",
|
||||||
|
"ReadFileTool",
|
||||||
|
"WriteFileTool",
|
||||||
|
# System tools
|
||||||
|
"BashTool",
|
||||||
|
"PowerShellTool",
|
||||||
|
# Dev tools
|
||||||
|
"LSPTools",
|
||||||
|
"GitTool",
|
||||||
|
# Collaboration tools
|
||||||
|
"TeamAgentTool",
|
||||||
|
"TaskBroadcastTool",
|
||||||
|
]
|
||||||
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""协作工具 - Phase 6.4"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.base import WriteTool
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TeamAgentTool(WriteTool):
|
||||||
|
"""团队 Agent 通信工具
|
||||||
|
|
||||||
|
用于与其他 Agent 进行消息传递和协作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="team_agent",
|
||||||
|
description="向团队 Agent 发送消息或请求协作",
|
||||||
|
permission_class=PermissionClass.WRITE,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
tags=["collaboration", "team", "agent"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"agent_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标 Agent 名称",
|
||||||
|
},
|
||||||
|
"message": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要发送的消息",
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["send", "request", "delegate"],
|
||||||
|
"description": "操作类型",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["agent_name", "message"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"success": {"type": "boolean"},
|
||||||
|
"response": {"type": "string"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, agent_name: str, message: str, action: str = "send") -> dict[str, Any]:
|
||||||
|
# 注意:实际实现需要通过 Agent 通信协议
|
||||||
|
# 这里只是一个框架实现
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"response": f"Message '{action}' to agent '{agent_name}': {message}",
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"action": action,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskBroadcastTool(WriteTool):
|
||||||
|
"""任务广播工具
|
||||||
|
|
||||||
|
向多个 Agent 广播任务。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="task_broadcast",
|
||||||
|
description="向多个 Agent 广播任务",
|
||||||
|
permission_class=PermissionClass.WRITE,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
tags=["collaboration", "broadcast", "task"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"agent_names": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "目标 Agent 列表",
|
||||||
|
},
|
||||||
|
"task": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要广播的任务描述",
|
||||||
|
},
|
||||||
|
"priority": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["low", "normal", "high", "urgent"],
|
||||||
|
"description": "任务优先级",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["agent_names", "task"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"success": {"type": "boolean"},
|
||||||
|
"broadcast_to": {"type": "array", "items": {"type": "string"}},
|
||||||
|
"responses": {"type": "array"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
agent_names: list[str],
|
||||||
|
task: str,
|
||||||
|
priority: str = "normal",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
# 注意:实际实现需要通过 Agent 通信协议
|
||||||
|
# 这里只是一个框架实现
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"broadcast_to": agent_names,
|
||||||
|
"task": task,
|
||||||
|
"priority": priority,
|
||||||
|
"responses": [f"Acknowledged by {agent}" for agent in agent_names],
|
||||||
|
}
|
||||||
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""开发工具 - Phase 6.4"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.base import ReadTool, WriteTool
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LSPTools(ReadTool):
|
||||||
|
"""语言服务器协议工具集
|
||||||
|
|
||||||
|
提供代码导航、查找引用等 LSP 功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="lsp_tools",
|
||||||
|
description="LSP 代码导航和查找引用",
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
tags=["development", "lsp", "code"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"action": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["goto_definition", "find_references", "document_symbols"],
|
||||||
|
"description": "LSP 操作类型",
|
||||||
|
},
|
||||||
|
"file": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件路径",
|
||||||
|
},
|
||||||
|
"line": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "行号(1-based)",
|
||||||
|
},
|
||||||
|
"character": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "列号(0-based)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["action", "file"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"success": {"type": "boolean"},
|
||||||
|
"results": {"type": "array"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
file: str,
|
||||||
|
line: int = 1,
|
||||||
|
character: int = 0,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
# 注意:实际 LSP 调用需要通过 lsp-utils 或类似库
|
||||||
|
# 这里只是一个框架实现
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"LSP action '{action}' not fully implemented - requires LSP server integration",
|
||||||
|
"action": action,
|
||||||
|
"file": file,
|
||||||
|
"position": {"line": line, "character": character},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GitTool(ReadTool):
|
||||||
|
"""Git 操作工具
|
||||||
|
|
||||||
|
提供常用的 Git 操作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, repo_path: str = "."):
|
||||||
|
super().__init__(
|
||||||
|
name="git",
|
||||||
|
description="执行 Git 命令",
|
||||||
|
permission_class=PermissionClass.EXTERNAL,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
tags=["development", "git", "version-control"],
|
||||||
|
)
|
||||||
|
self.repo_path = repo_path
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Git 子命令和参数,如 'status' 或 'log --oneline -10'",
|
||||||
|
},
|
||||||
|
"repo_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "仓库路径(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stdout": {"type": "string"},
|
||||||
|
"stderr": {"type": "string"},
|
||||||
|
"returncode": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, command: str, repo_path: str | None = None) -> dict[str, Any]:
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
|
||||||
|
repo = repo_path or self.repo_path
|
||||||
|
|
||||||
|
# 构建完整的 git 命令
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
full_command = f'git -C "{repo}" {command}'
|
||||||
|
else:
|
||||||
|
full_command = f"git -C '{repo}' {command}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = await asyncio.create_subprocess_shell(
|
||||||
|
full_command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
stdout, stderr = await process.communicate()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||||
|
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||||
|
"returncode": process.returncode,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": str(e),
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""文件操作工具 - Phase 6.4"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.base import ExternalTool, ReadTool, WriteTool
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GlobTool(ReadTool):
|
||||||
|
"""文件路径匹配工具
|
||||||
|
|
||||||
|
使用 glob 模式查找文件。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, root_dir: str = "."):
|
||||||
|
super().__init__(
|
||||||
|
name="glob",
|
||||||
|
description="使用 glob 模式查找文件路径",
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
tags=["file", "search", "glob"],
|
||||||
|
)
|
||||||
|
self.root_dir = root_dir
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"pattern": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Glob 模式,如 **/*.py",
|
||||||
|
},
|
||||||
|
"root_dir": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "搜索根目录(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["pattern"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, pattern: str, root_dir: str | None = None) -> list[str]:
|
||||||
|
import glob as glob_module
|
||||||
|
|
||||||
|
root = root_dir or self.root_dir
|
||||||
|
return glob_module.glob(pattern, root_dir=root, recursive=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GrepTool(ReadTool):
|
||||||
|
"""文件内容搜索工具
|
||||||
|
|
||||||
|
在文件中搜索匹配的行。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="grep",
|
||||||
|
description="在文件中搜索匹配的文本行",
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
tags=["file", "search", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"pattern": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "正则表达式模式",
|
||||||
|
},
|
||||||
|
"paths": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "要搜索的文件路径列表",
|
||||||
|
},
|
||||||
|
"case_sensitive": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否区分大小写",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["pattern", "paths"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file": {"type": "string"},
|
||||||
|
"line": {"type": "integer"},
|
||||||
|
"content": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, pattern: str, paths: list[str], case_sensitive: bool = True
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
import re
|
||||||
|
|
||||||
|
flags = 0 if case_sensitive else re.IGNORECASE
|
||||||
|
regex = re.compile(pattern, flags)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
for line_num, line in enumerate(f, 1):
|
||||||
|
if regex.search(line):
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"file": path,
|
||||||
|
"line": line_num,
|
||||||
|
"content": line.rstrip(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except (UnicodeDecodeError, PermissionError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileTool(ReadTool):
|
||||||
|
"""文件读取工具"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="read_file",
|
||||||
|
description="读取文件内容",
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
tags=["file", "read"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件路径",
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "最大行数",
|
||||||
|
},
|
||||||
|
"offset": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "起始行号",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content": {"type": "string"},
|
||||||
|
"lines": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, path: str, limit: int | None = None, offset: int = 0) -> dict[str, Any]:
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
raise FileNotFoundError(f"File not found: {path}")
|
||||||
|
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
total_lines = len(lines)
|
||||||
|
start = max(0, offset)
|
||||||
|
end = len(lines) if limit is None else min(start + limit, len(lines))
|
||||||
|
|
||||||
|
content = "".join(lines[start:end])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": content,
|
||||||
|
"lines": total_lines,
|
||||||
|
"truncated": limit is not None and end < len(lines),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFileTool(WriteTool):
|
||||||
|
"""文件写入工具"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="write_file",
|
||||||
|
description="写入文件内容",
|
||||||
|
permission_class=PermissionClass.WRITE,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
tags=["file", "write"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件路径",
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件内容",
|
||||||
|
},
|
||||||
|
"append": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否追加模式",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["path", "content"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"success": {"type": "boolean"},
|
||||||
|
"bytes_written": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, path: str, content: str, append: bool = False) -> dict[str, Any]:
|
||||||
|
mode = "a" if append else "w"
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
directory = os.path.dirname(path)
|
||||||
|
if directory and not os.path.exists(directory):
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
|
||||||
|
with open(path, mode, encoding="utf-8") as f:
|
||||||
|
bytes_written = f.write(content)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"bytes_written": bytes_written,
|
||||||
|
}
|
||||||
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""系统工具 - Phase 6.4"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import shlex
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.base import ExternalTool
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BashTool(ExternalTool):
|
||||||
|
"""Bash 命令执行工具"""
|
||||||
|
|
||||||
|
def __init__(self, working_dir: str = "."):
|
||||||
|
super().__init__(
|
||||||
|
name="bash",
|
||||||
|
description="执行 Bash 命令",
|
||||||
|
permission_class=PermissionClass.EXTERNAL,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
tags=["system", "bash", "shell"],
|
||||||
|
)
|
||||||
|
self.working_dir = working_dir
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要执行的 Bash 命令",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "超时时间(秒)",
|
||||||
|
},
|
||||||
|
"working_dir": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "工作目录(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stdout": {"type": "string"},
|
||||||
|
"stderr": {"type": "string"},
|
||||||
|
"returncode": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
import os
|
||||||
|
|
||||||
|
cwd = working_dir or self.working_dir
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = await asyncio.create_subprocess_shell(
|
||||||
|
command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
process.kill()
|
||||||
|
await process.wait()
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": f"Command timed out after {timeout} seconds",
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||||
|
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||||
|
"returncode": process.returncode,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": str(e),
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PowerShellTool(ExternalTool):
|
||||||
|
"""PowerShell 命令执行工具"""
|
||||||
|
|
||||||
|
def __init__(self, working_dir: str = "."):
|
||||||
|
super().__init__(
|
||||||
|
name="powershell",
|
||||||
|
description="执行 PowerShell 命令",
|
||||||
|
permission_class=PermissionClass.EXTERNAL,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
tags=["system", "powershell", "shell"],
|
||||||
|
)
|
||||||
|
self.working_dir = working_dir
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要执行的 PowerShell 命令",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "超时时间(秒)",
|
||||||
|
},
|
||||||
|
"working_dir": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "工作目录(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stdout": {"type": "string"},
|
||||||
|
"stderr": {"type": "string"},
|
||||||
|
"returncode": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
import platform
|
||||||
|
|
||||||
|
# 检测是否是 Windows 平台
|
||||||
|
is_windows = platform.system() == "Windows"
|
||||||
|
|
||||||
|
if not is_windows:
|
||||||
|
# 非 Windows 平台,可能没有 PowerShell
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "PowerShell is not available on this platform",
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
cwd = working_dir or self.working_dir
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
"powershell.exe",
|
||||||
|
"-NoProfile",
|
||||||
|
"-Command",
|
||||||
|
command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
process.kill()
|
||||||
|
await process.wait()
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": f"Command timed out after {timeout} seconds",
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||||
|
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||||
|
"returncode": process.returncode,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": str(e),
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
46
backend/app/agents/tools/hooks/__init__.py
Normal file
46
backend/app/agents/tools/hooks/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Hook 系统 - Phase 6.2"""
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
HookDefinition,
|
||||||
|
HookResult,
|
||||||
|
HookStage,
|
||||||
|
HookTrigger,
|
||||||
|
HookType,
|
||||||
|
ExecutionContext,
|
||||||
|
HookHandler,
|
||||||
|
PreToolHook,
|
||||||
|
PostToolHook,
|
||||||
|
ErrorToolHook,
|
||||||
|
SkipToolHook,
|
||||||
|
)
|
||||||
|
from app.agents.tools.hooks.manager import (
|
||||||
|
HookManager,
|
||||||
|
get_hook_manager,
|
||||||
|
reset_hook_manager,
|
||||||
|
)
|
||||||
|
from app.agents.tools.hooks.executor import (
|
||||||
|
HookExecutor,
|
||||||
|
get_hook_executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Types
|
||||||
|
"HookType",
|
||||||
|
"HookStage",
|
||||||
|
"HookTrigger",
|
||||||
|
"HookDefinition",
|
||||||
|
"HookResult",
|
||||||
|
"ExecutionContext",
|
||||||
|
"HookHandler",
|
||||||
|
"PreToolHook",
|
||||||
|
"PostToolHook",
|
||||||
|
"ErrorToolHook",
|
||||||
|
"SkipToolHook",
|
||||||
|
# Manager
|
||||||
|
"HookManager",
|
||||||
|
"get_hook_manager",
|
||||||
|
"reset_hook_manager",
|
||||||
|
# Executor
|
||||||
|
"HookExecutor",
|
||||||
|
"get_hook_executor",
|
||||||
|
]
|
||||||
170
backend/app/agents/tools/hooks/executor.py
Normal file
170
backend/app/agents/tools/hooks/executor.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""Hook 执行器 - Phase 6.2
|
||||||
|
|
||||||
|
执行 Hook 拦截逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.manager import get_hook_manager
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
HookDefinition,
|
||||||
|
HookResult,
|
||||||
|
HookType,
|
||||||
|
ExecutionContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HookExecutor:
|
||||||
|
"""Hook 执行器
|
||||||
|
|
||||||
|
负责在工具执行前后执行 Hook 逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._manager = get_hook_manager()
|
||||||
|
|
||||||
|
async def execute_pre_hooks(
|
||||||
|
self, context: ExecutionContext
|
||||||
|
) -> tuple[bool, dict[str, Any] | None]:
|
||||||
|
"""执行 pre-tool Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 执行上下文
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否继续执行, 修改后的输入)
|
||||||
|
"""
|
||||||
|
hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name)
|
||||||
|
modified_input = context.tool_input
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
# 调用 hook handler
|
||||||
|
handler = hook.handler
|
||||||
|
if callable(handler):
|
||||||
|
result = await self._call_hook(handler, context)
|
||||||
|
if result and not result.continue_execution:
|
||||||
|
# Hook 决定中断执行
|
||||||
|
return False, modified_input
|
||||||
|
if result.modified_input is not None:
|
||||||
|
modified_input = result.modified_input
|
||||||
|
except Exception as e:
|
||||||
|
# Hook 出错,默认继续执行
|
||||||
|
pass
|
||||||
|
|
||||||
|
return True, modified_input
|
||||||
|
|
||||||
|
async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any:
|
||||||
|
"""执行 post-tool Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 执行上下文
|
||||||
|
result: 工具执行结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
修改后的结果
|
||||||
|
"""
|
||||||
|
hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name)
|
||||||
|
modified_result = result
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
handler = hook.handler
|
||||||
|
if callable(handler):
|
||||||
|
hook_result = await self._call_hook(handler, context, modified_result)
|
||||||
|
if hook_result and hook_result.modified_output is not None:
|
||||||
|
modified_result = hook_result.modified_output
|
||||||
|
except Exception:
|
||||||
|
# Hook 出错,默认保留原结果
|
||||||
|
pass
|
||||||
|
|
||||||
|
return modified_result
|
||||||
|
|
||||||
|
async def execute_error_hooks(
|
||||||
|
self, context: ExecutionContext, error: Exception
|
||||||
|
) -> HookResult | None:
|
||||||
|
"""执行 error Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 执行上下文
|
||||||
|
error: 异常
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 结果,如果返回 None 则继续传播错误
|
||||||
|
"""
|
||||||
|
hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
handler = hook.handler
|
||||||
|
if callable(handler):
|
||||||
|
result = await self._call_hook(handler, context, error)
|
||||||
|
if result is not None and result.continue_execution:
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
# Hook 出错,继续执行其他 error hooks
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def execute_skip_check(self, context: ExecutionContext) -> bool:
|
||||||
|
"""检查是否应跳过工具执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 执行上下文
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 表示跳过,False 表示执行
|
||||||
|
"""
|
||||||
|
hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
handler = hook.handler
|
||||||
|
if callable(handler):
|
||||||
|
result = await self._call_hook(handler, context)
|
||||||
|
if result is not None and isinstance(result, bool):
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
# Hook 出错,默认不跳过
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _call_hook(
|
||||||
|
self, handler: Any, context: ExecutionContext, *args: Any
|
||||||
|
) -> HookResult | None:
|
||||||
|
"""调用 Hook 处理函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler: Hook 处理函数
|
||||||
|
context: 执行上下文
|
||||||
|
*args: 额外参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 结果
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# 如果是普通函数,直接调用
|
||||||
|
if asyncio.iscoroutinefunction(handler):
|
||||||
|
return await handler(context, *args)
|
||||||
|
else:
|
||||||
|
return handler(context, *args)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_executor: HookExecutor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_hook_executor() -> HookExecutor:
|
||||||
|
"""获取全局 Hook 执行器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
全局 HookExecutor 实例
|
||||||
|
"""
|
||||||
|
global _executor
|
||||||
|
if _executor is None:
|
||||||
|
_executor = HookExecutor()
|
||||||
|
return _executor
|
||||||
174
backend/app/agents/tools/hooks/manager.py
Normal file
174
backend/app/agents/tools/hooks/manager.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""Hook 管理器 - Phase 6.2
|
||||||
|
|
||||||
|
管理 Hook 的注册、查找和配置。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
HookDefinition,
|
||||||
|
HookResult,
|
||||||
|
HookTrigger,
|
||||||
|
HookType,
|
||||||
|
ExecutionContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HookManager:
|
||||||
|
"""Hook 管理器
|
||||||
|
|
||||||
|
管理全局 Hook 的注册和配置。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._hooks: dict[HookType, list[HookDefinition]] = {
|
||||||
|
HookType.PRE_TOOL_USE: [],
|
||||||
|
HookType.POST_TOOL_USE: [],
|
||||||
|
HookType.TOOL_ERROR: [],
|
||||||
|
HookType.TOOL_SKIP: [],
|
||||||
|
}
|
||||||
|
self._global_hooks: list[HookDefinition] = [] # 全局 Hook(对所有工具生效)
|
||||||
|
|
||||||
|
def register(self, definition: HookDefinition) -> None:
|
||||||
|
"""注册 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
definition: Hook 定义
|
||||||
|
"""
|
||||||
|
if definition.trigger.tool_names is None and definition.trigger.categories is None:
|
||||||
|
# 全局 Hook
|
||||||
|
self._global_hooks.append(definition)
|
||||||
|
else:
|
||||||
|
# 特定工具 Hook
|
||||||
|
self._hooks[definition.hook_type].append(definition)
|
||||||
|
|
||||||
|
# 按优先级排序
|
||||||
|
self._hooks[definition.hook_type].sort(key=lambda h: h.priority, reverse=True)
|
||||||
|
self._global_hooks.sort(key=lambda h: h.priority, reverse=True)
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> bool:
|
||||||
|
"""注销 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Hook 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功注销
|
||||||
|
"""
|
||||||
|
# 从特定工具 Hook 中移除
|
||||||
|
for hooks in self._hooks.values():
|
||||||
|
for i, hook in enumerate(hooks):
|
||||||
|
if hook.name == name:
|
||||||
|
hooks.pop(i)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 从全局 Hook 中移除
|
||||||
|
for i, hook in enumerate(self._global_hooks):
|
||||||
|
if hook.name == name:
|
||||||
|
self._global_hooks.pop(i)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_hooks(self, hook_type: HookType, tool_name: str | None = None) -> list[HookDefinition]:
|
||||||
|
"""获取指定类型和工具的 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_type: Hook 类型
|
||||||
|
tool_name: 工具名称(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
匹配的 Hook 列表
|
||||||
|
"""
|
||||||
|
result: list[HookDefinition] = []
|
||||||
|
|
||||||
|
# 添加全局 Hook
|
||||||
|
for hook in self._global_hooks:
|
||||||
|
if hook.hook_type == hook_type and hook.enabled:
|
||||||
|
result.append(hook)
|
||||||
|
|
||||||
|
# 添加特定工具 Hook
|
||||||
|
for hook in self._hooks[hook_type]:
|
||||||
|
if not hook.enabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if hook.trigger.tool_names is None and hook.trigger.categories is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否匹配
|
||||||
|
if hook.trigger.tool_names and tool_name not in hook.trigger.tool_names:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.append(hook)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def list_all(self) -> list[HookDefinition]:
|
||||||
|
"""列出所有已注册的 Hook
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 列表
|
||||||
|
"""
|
||||||
|
all_hooks = list(self._global_hooks)
|
||||||
|
for hooks in self._hooks.values():
|
||||||
|
all_hooks.extend(hooks)
|
||||||
|
return all_hooks
|
||||||
|
|
||||||
|
def enable(self, name: str) -> bool:
|
||||||
|
"""启用 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Hook 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功启用
|
||||||
|
"""
|
||||||
|
for hook in self.list_all():
|
||||||
|
if hook.name == name:
|
||||||
|
hook.enabled = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disable(self, name: str) -> bool:
|
||||||
|
"""禁用 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Hook 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功禁用
|
||||||
|
"""
|
||||||
|
for hook in self.list_all():
|
||||||
|
if hook.name == name:
|
||||||
|
hook.enabled = False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清除所有 Hook"""
|
||||||
|
self._hooks = {ht: [] for ht in HookType}
|
||||||
|
self._global_hooks = []
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_global_hook_manager: HookManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_hook_manager() -> HookManager:
|
||||||
|
"""获取全局 Hook 管理器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
全局 HookManager 实例
|
||||||
|
"""
|
||||||
|
global _global_hook_manager
|
||||||
|
if _global_hook_manager is None:
|
||||||
|
_global_hook_manager = HookManager()
|
||||||
|
return _global_hook_manager
|
||||||
|
|
||||||
|
|
||||||
|
def reset_hook_manager() -> None:
|
||||||
|
"""重置全局 Hook 管理器(用于测试)"""
|
||||||
|
global _global_hook_manager
|
||||||
|
if _global_hook_manager is not None:
|
||||||
|
_global_hook_manager.clear()
|
||||||
|
_global_hook_manager = None
|
||||||
90
backend/app/agents/tools/hooks/types.py
Normal file
90
backend/app/agents/tools/hooks/types.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Hook 类型定义 - Phase 6.2
|
||||||
|
|
||||||
|
Hook 拦截系统类型定义。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
|
class HookType(Enum):
|
||||||
|
"""Hook 类型"""
|
||||||
|
|
||||||
|
PRE_TOOL_USE = "pre_tool_use" # 工具执行前
|
||||||
|
POST_TOOL_USE = "post_tool_use" # 工具执行后
|
||||||
|
TOOL_ERROR = "tool_error" # 工具执行出错
|
||||||
|
TOOL_SKIP = "tool_skip" # 工具跳过(条件执行)
|
||||||
|
|
||||||
|
|
||||||
|
class HookStage(Enum):
|
||||||
|
"""Hook 执行阶段"""
|
||||||
|
|
||||||
|
BEFORE = "before"
|
||||||
|
AFTER = "after"
|
||||||
|
ON_ERROR = "on_error"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookTrigger:
|
||||||
|
"""Hook 触发条件"""
|
||||||
|
|
||||||
|
tool_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||||
|
categories: list[str] | None = None # 只对特定类别生效
|
||||||
|
conditions: dict[str, Any] | None = None # 自定义条件
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookDefinition:
|
||||||
|
"""Hook 定义"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
hook_type: HookType
|
||||||
|
trigger: HookTrigger
|
||||||
|
handler: Callable[..., Any] # Hook 处理函数
|
||||||
|
priority: int = 0 # 优先级,数字越大越先执行
|
||||||
|
enabled: bool = True
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookResult:
|
||||||
|
"""Hook 执行结果"""
|
||||||
|
|
||||||
|
hook_name: str
|
||||||
|
success: bool
|
||||||
|
continue_execution: bool = True # False 表示中断执行
|
||||||
|
modified_input: Any = None # 修改后的输入
|
||||||
|
modified_output: Any = None # 修改后的输出
|
||||||
|
error: str | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionContext:
|
||||||
|
"""工具执行上下文"""
|
||||||
|
|
||||||
|
tool_name: str
|
||||||
|
tool_input: dict[str, Any]
|
||||||
|
user_id: str | None = None
|
||||||
|
session_id: str | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# 执行结果(由 HookExecutor 填充)
|
||||||
|
result: Any = None
|
||||||
|
error: Exception | None = None
|
||||||
|
start_time: float | None = None
|
||||||
|
end_time: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Hook 处理函数类型
|
||||||
|
HookHandler = Callable[[ExecutionContext, HookDefinition], HookResult]
|
||||||
|
|
||||||
|
# Pre-hook: 在工具执行前调用,可以修改输入或决定是否跳过
|
||||||
|
PreToolHook = Callable[[ExecutionContext], tuple[bool, dict[str, Any] | None]]
|
||||||
|
# post-hook: 在工具执行后调用,可以修改输出
|
||||||
|
PostToolHook = Callable[[ExecutionContext, Any], Any]
|
||||||
|
# Error hook: 在工具出错时调用
|
||||||
|
ErrorToolHook = Callable[[ExecutionContext, Exception], HookResult | None]
|
||||||
|
# Skip hook: 决定是否跳过工具执行
|
||||||
|
SkipToolHook = Callable[[ExecutionContext], bool]
|
||||||
77
backend/app/agents/tools/manifest.py
Normal file
77
backend/app/agents/tools/manifest.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""工具元数据和数据类型定义"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCategory(Enum):
|
||||||
|
"""工具类别"""
|
||||||
|
|
||||||
|
READ = "read"
|
||||||
|
WRITE = "write"
|
||||||
|
EXTERNAL = "external"
|
||||||
|
DB_WRITE = "db_write"
|
||||||
|
NETWORK = "network"
|
||||||
|
|
||||||
|
|
||||||
|
class SideEffectScope(Enum):
|
||||||
|
"""副作用范围"""
|
||||||
|
|
||||||
|
NONE = "none"
|
||||||
|
LOCAL_STATE = "local_state"
|
||||||
|
DB_WRITE = "db_write"
|
||||||
|
NETWORK = "network"
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionClass(Enum):
|
||||||
|
"""权限级别"""
|
||||||
|
|
||||||
|
READ = "read"
|
||||||
|
WRITE = "write"
|
||||||
|
EXTERNAL = "external"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolManifest:
|
||||||
|
"""工具元数据"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
category: ToolCategory
|
||||||
|
parameters: dict[str, Any] # JSON Schema
|
||||||
|
return_schema: dict[str, Any]
|
||||||
|
permission_class: PermissionClass
|
||||||
|
side_effect_scope: SideEffectScope
|
||||||
|
requires_confirmation: bool = False
|
||||||
|
is_streaming: bool = False
|
||||||
|
tags: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"category": self.category.value,
|
||||||
|
"parameters": self.parameters,
|
||||||
|
"return_schema": self.return_schema,
|
||||||
|
"permission_class": self.permission_class.value,
|
||||||
|
"side_effect_scope": self.side_effect_scope.value,
|
||||||
|
"requires_confirmation": self.requires_confirmation,
|
||||||
|
"is_streaming": self.is_streaming,
|
||||||
|
"tags": self.tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookConfig:
|
||||||
|
"""Hook 配置"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
hook_type: str # "pre_tool_use", "post_tool_use", "tool_error", "tool_skip"
|
||||||
|
filter_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||||
|
|
||||||
|
def matches_tool(self, tool_name: str) -> bool:
|
||||||
|
"""检查 Hook 是否对指定工具生效"""
|
||||||
|
if self.filter_names is None:
|
||||||
|
return True
|
||||||
|
return tool_name in self.filter_names
|
||||||
251
backend/app/agents/tools/migration.py
Normal file
251
backend/app/agents/tools/migration.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
"""工具迁移和向后兼容层 - Phase 6.1
|
||||||
|
|
||||||
|
将现有 @tool 装饰的工具迁移到 ToolRegistry,同时保持向后兼容。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
ToolManifest,
|
||||||
|
)
|
||||||
|
from app.agents.tools.registry import get_tool_registry
|
||||||
|
|
||||||
|
|
||||||
|
# 现有工具的类别映射
|
||||||
|
_TOOL_CATEGORY_MAP: dict[str, tuple[ToolCategory, PermissionClass, SideEffectScope]] = {
|
||||||
|
# 知识检索 - 只读
|
||||||
|
"search_knowledge": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"get_knowledge_graph_context": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"hybrid_search": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"web_search": (ToolCategory.NETWORK, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||||
|
# 知识构建 - 写入
|
||||||
|
"build_knowledge_graph": (
|
||||||
|
ToolCategory.WRITE,
|
||||||
|
PermissionClass.WRITE,
|
||||||
|
SideEffectScope.LOCAL_STATE,
|
||||||
|
),
|
||||||
|
# 任务工具
|
||||||
|
"get_tasks": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"create_task": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"update_task_status": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
# 日程工具
|
||||||
|
"get_schedule_day": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"create_todo": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"create_schedule_task": (
|
||||||
|
ToolCategory.WRITE,
|
||||||
|
PermissionClass.WRITE,
|
||||||
|
SideEffectScope.LOCAL_STATE,
|
||||||
|
),
|
||||||
|
"create_reminder": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"create_goal": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"resolve_time_expression": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
# 论坛工具
|
||||||
|
"get_forum_posts": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"create_forum_post": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"scan_forum_for_instructions": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_category(name: str) -> tuple[ToolCategory, PermissionClass, SideEffectScope]:
|
||||||
|
"""获取工具的类别信息"""
|
||||||
|
return _TOOL_CATEGORY_MAP.get(
|
||||||
|
name,
|
||||||
|
(ToolCategory.EXTERNAL, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_tags_from_docstring(docstring: str | None) -> list[str]:
|
||||||
|
"""从 docstring 推断工具标签"""
|
||||||
|
if not docstring:
|
||||||
|
return []
|
||||||
|
tags = []
|
||||||
|
doc_lower = docstring.lower()
|
||||||
|
if "搜索" in docstring or "查询" in docstring or "search" in doc_lower:
|
||||||
|
tags.append("search")
|
||||||
|
if "创建" in docstring or "新建" in docstring or "create" in doc_lower:
|
||||||
|
tags.append("create")
|
||||||
|
if "获取" in docstring or "读取" in docstring or "get" in doc_lower:
|
||||||
|
tags.append("read")
|
||||||
|
if "更新" in docstring or "修改" in docstring or "update" in doc_lower:
|
||||||
|
tags.append("update")
|
||||||
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_tool(tool_func: Callable) -> Callable:
|
||||||
|
"""将现有 @tool 装饰的函数迁移到 ToolRegistry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_func: LangChain @tool 装饰的函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
原函数(已注册到 registry)
|
||||||
|
"""
|
||||||
|
registry = get_tool_registry()
|
||||||
|
|
||||||
|
# 如果已经注册,跳过
|
||||||
|
if registry.get(tool_func.name):
|
||||||
|
return tool_func
|
||||||
|
|
||||||
|
# 获取类别信息
|
||||||
|
category, permission, side_effect = get_tool_category(tool_func.name)
|
||||||
|
|
||||||
|
# 从 docstring 提取 description
|
||||||
|
description = tool_func.description if hasattr(tool_func, "description") else ""
|
||||||
|
|
||||||
|
# 推断 tags
|
||||||
|
tags = infer_tags_from_docstring(description)
|
||||||
|
tags.append("migrated")
|
||||||
|
|
||||||
|
# 创建 manifest
|
||||||
|
manifest = ToolManifest(
|
||||||
|
name=tool_func.name,
|
||||||
|
description=description,
|
||||||
|
category=category,
|
||||||
|
parameters={}, # LangChain @tool 动态处理参数
|
||||||
|
return_schema={},
|
||||||
|
permission_class=permission,
|
||||||
|
side_effect_scope=side_effect,
|
||||||
|
requires_confirmation=side_effect != SideEffectScope.NONE,
|
||||||
|
is_streaming=False,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册到 registry
|
||||||
|
registry.register(manifest, tool_func)
|
||||||
|
|
||||||
|
return tool_func
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_all_tools() -> int:
|
||||||
|
"""迁移所有现有工具到 ToolRegistry
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
迁移的工具数量
|
||||||
|
"""
|
||||||
|
from app.agents.tools import (
|
||||||
|
ALL_TOOLS,
|
||||||
|
KNOWLEDGE_GRAPH_TOOLS,
|
||||||
|
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||||
|
SCHEDULE_READ_TOOLS,
|
||||||
|
SCHEDULE_WRITE_TOOLS,
|
||||||
|
TASK_TOOLS,
|
||||||
|
FORUM_TOOLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_tools = (
|
||||||
|
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||||
|
+ KNOWLEDGE_GRAPH_TOOLS
|
||||||
|
+ TASK_TOOLS
|
||||||
|
+ SCHEDULE_READ_TOOLS
|
||||||
|
+ SCHEDULE_WRITE_TOOLS
|
||||||
|
+ FORUM_TOOLS
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for tool in all_tools:
|
||||||
|
try:
|
||||||
|
migrate_tool(tool)
|
||||||
|
count += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to migrate tool {getattr(tool, 'name', 'unknown')}: {e}")
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
class BackwardCompatTool:
|
||||||
|
"""向后兼容工具包装器
|
||||||
|
|
||||||
|
确保现有代码通过 registry.get_executor() 仍能正常调用工具。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self._registry = get_tool_registry()
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
executor = self._registry.get_executor(self.name)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||||
|
return executor(*args, **kwargs)
|
||||||
|
|
||||||
|
def invoke(self, tool_input: dict[str, Any]) -> Any:
|
||||||
|
"""LangChain 风格的 invoke 调用"""
|
||||||
|
executor = self._registry.get_executor(self.name)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||||
|
|
||||||
|
# 处理位置参数
|
||||||
|
if isinstance(tool_input, dict):
|
||||||
|
return executor(**tool_input)
|
||||||
|
return executor(tool_input)
|
||||||
|
|
||||||
|
|
||||||
|
def create_compat_layer() -> dict[str, BackwardCompatTool]:
|
||||||
|
"""创建向后兼容层
|
||||||
|
|
||||||
|
返回一个字典,允许通过名称访问兼容的工具包装器。
|
||||||
|
"""
|
||||||
|
registry = get_tool_registry()
|
||||||
|
tools = registry.list_all()
|
||||||
|
|
||||||
|
return {tool.name: BackwardCompatTool(tool.name) for tool in tools}
|
||||||
|
|
||||||
|
|
||||||
|
# 自动迁移装饰器
|
||||||
|
def auto_migrate(func: Callable) -> Callable:
|
||||||
|
"""自动迁移装饰器
|
||||||
|
|
||||||
|
用于装饰新的 @tool 函数,自动注册到 registry。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
# 迁移到 registry
|
||||||
|
migrate_tool(wrapper)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
# 便捷函数:获取兼容的工具执行器
|
||||||
|
def get_tool_executor(name: str) -> Callable | None:
|
||||||
|
"""获取工具执行器(兼容层)
|
||||||
|
|
||||||
|
优先从 registry 获取,fallback 到直接导入。
|
||||||
|
"""
|
||||||
|
registry = get_tool_registry()
|
||||||
|
executor = registry.get_executor(name)
|
||||||
|
|
||||||
|
if executor is not None:
|
||||||
|
return executor
|
||||||
|
|
||||||
|
# Fallback: 直接从模块导入(仅用于迁移期间)
|
||||||
|
try:
|
||||||
|
from app.agents.tools import (
|
||||||
|
TASK_TOOLS,
|
||||||
|
SCHEDULE_READ_TOOLS,
|
||||||
|
SCHEDULE_WRITE_TOOLS,
|
||||||
|
FORUM_TOOLS,
|
||||||
|
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_tools = (
|
||||||
|
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||||
|
+ TASK_TOOLS
|
||||||
|
+ SCHEDULE_READ_TOOLS
|
||||||
|
+ SCHEDULE_WRITE_TOOLS
|
||||||
|
+ FORUM_TOOLS
|
||||||
|
)
|
||||||
|
|
||||||
|
for tool in all_tools:
|
||||||
|
if hasattr(tool, "name") and tool.name == name:
|
||||||
|
return tool
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
206
backend/app/agents/tools/registry.py
Normal file
206
backend/app/agents/tools/registry.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""工具注册表 - 工具系统重构 Phase 6.1"""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from app.agents.tools.manifest import HookConfig, ToolManifest
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""工具注册表
|
||||||
|
|
||||||
|
统一管理所有工具的注册、发现和调用。
|
||||||
|
支持工具元数据、权限分类、Hook 拦截。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tools: dict[str, ToolManifest] = {}
|
||||||
|
self._executors: dict[str, Callable] = {}
|
||||||
|
self._hooks: dict[str, list[HookConfig]] = defaultdict(list)
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self, manifest: ToolManifest, executor: Callable, hooks: list[HookConfig] | None = None
|
||||||
|
) -> None:
|
||||||
|
"""注册工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
manifest: 工具元数据
|
||||||
|
executor: 工具执行函数
|
||||||
|
hooks: 可选的 Hook 配置列表
|
||||||
|
"""
|
||||||
|
if manifest.name in self._tools:
|
||||||
|
raise ValueError(f"Tool already registered: {manifest.name}")
|
||||||
|
|
||||||
|
self._tools[manifest.name] = manifest
|
||||||
|
self._executors[manifest.name] = executor
|
||||||
|
|
||||||
|
if hooks:
|
||||||
|
for hook in hooks:
|
||||||
|
self._hooks[manifest.name].append(hook)
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> bool:
|
||||||
|
"""注销工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功注销
|
||||||
|
"""
|
||||||
|
if name not in self._tools:
|
||||||
|
return False
|
||||||
|
|
||||||
|
del self._tools[name]
|
||||||
|
del self._executors[name]
|
||||||
|
if name in self._hooks:
|
||||||
|
del self._hooks[name]
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get(self, name: str) -> ToolManifest | None:
|
||||||
|
"""获取工具元数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具元数据,不存在返回 None
|
||||||
|
"""
|
||||||
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
def get_executor(self, name: str) -> Callable | None:
|
||||||
|
"""获取工具执行器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具执行函数,不存在返回 None
|
||||||
|
"""
|
||||||
|
return self._executors.get(name)
|
||||||
|
|
||||||
|
def get_hooks(self, name: str) -> list[HookConfig]:
|
||||||
|
"""获取工具的 Hook 配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 配置列表
|
||||||
|
"""
|
||||||
|
return self._hooks.get(name, [])
|
||||||
|
|
||||||
|
def list_all(self) -> list[ToolManifest]:
|
||||||
|
"""列出所有已注册的工具
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具元数据列表
|
||||||
|
"""
|
||||||
|
return list(self._tools.values())
|
||||||
|
|
||||||
|
def list_by_category(self, category: Any) -> list[ToolManifest]:
|
||||||
|
"""按类别列出工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: 工具类别
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
该类别下的所有工具
|
||||||
|
"""
|
||||||
|
return [t for t in self._tools.values() if t.category == category]
|
||||||
|
|
||||||
|
def list_by_permission(self, permission: Any) -> list[ToolManifest]:
|
||||||
|
"""按权限级别列出工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
permission: 权限级别
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
该权限级别下的所有工具
|
||||||
|
"""
|
||||||
|
return [t for t in self._tools.values() if t.permission_class == permission]
|
||||||
|
|
||||||
|
def search_by_tag(self, tag: str) -> list[ToolManifest]:
|
||||||
|
"""按标签搜索工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: 标签
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含该标签的工具
|
||||||
|
"""
|
||||||
|
return [t for t in self._tools.values() if tag in t.tags]
|
||||||
|
|
||||||
|
def search_by_name(self, keyword: str) -> list[ToolManifest]:
|
||||||
|
"""按名称关键词搜索工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keyword: 关键词
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
名称包含关键词的工具
|
||||||
|
"""
|
||||||
|
keyword = keyword.lower()
|
||||||
|
return [t for t in self._tools.values() if keyword in t.name.lower()]
|
||||||
|
|
||||||
|
def get_requires_confirmation(self, name: str) -> bool:
|
||||||
|
"""检查工具是否需要确认
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否需要确认
|
||||||
|
"""
|
||||||
|
manifest = self._tools.get(name)
|
||||||
|
return manifest.requires_confirmation if manifest else False
|
||||||
|
|
||||||
|
def get_is_streaming(self, name: str) -> bool:
|
||||||
|
"""检查工具是否支持流式执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否支持流式
|
||||||
|
"""
|
||||||
|
manifest = self._tools.get(name)
|
||||||
|
return manifest.is_streaming if manifest else False
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清空注册表"""
|
||||||
|
self._tools.clear()
|
||||||
|
self._executors.clear()
|
||||||
|
self._hooks.clear()
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._tools)
|
||||||
|
|
||||||
|
def __contains__(self, name: str) -> bool:
|
||||||
|
return name in self._tools
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._tools.values())
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例实例
|
||||||
|
_global_registry: ToolRegistry | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_registry() -> ToolRegistry:
|
||||||
|
"""获取全局工具注册表单例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
全局 ToolRegistry 实例
|
||||||
|
"""
|
||||||
|
global _global_registry
|
||||||
|
if _global_registry is None:
|
||||||
|
_global_registry = ToolRegistry()
|
||||||
|
return _global_registry
|
||||||
|
|
||||||
|
|
||||||
|
def reset_tool_registry() -> None:
|
||||||
|
"""重置全局工具注册表(用于测试)"""
|
||||||
|
global _global_registry
|
||||||
|
if _global_registry is not None:
|
||||||
|
_global_registry.clear()
|
||||||
|
_global_registry = None
|
||||||
210
backend/app/agents/tools/streaming.py
Normal file
210
backend/app/agents/tools/streaming.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""流式工具执行器 - Phase 6.3
|
||||||
|
|
||||||
|
支持流式输出的工具执行器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.executor import get_hook_executor
|
||||||
|
from app.agents.tools.hooks.types import ExecutionContext
|
||||||
|
from app.agents.tools.registry import get_tool_registry
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingToolExecutor:
|
||||||
|
"""流式工具执行器
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- 普通工具的同步/异步执行
|
||||||
|
- 流式工具的流式输出
|
||||||
|
- Hook 拦截(pre/post/error)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._registry = get_tool_registry()
|
||||||
|
self._hook_executor = get_hook_executor()
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: dict[str, Any],
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""执行工具(非流式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
tool_input: 工具输入参数
|
||||||
|
user_id: 用户 ID(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具执行结果
|
||||||
|
"""
|
||||||
|
# 创建执行上下文
|
||||||
|
context = ExecutionContext(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取工具和执行器
|
||||||
|
manifest = self._registry.get(tool_name)
|
||||||
|
if manifest is None:
|
||||||
|
raise ValueError(f"Tool not found: {tool_name}")
|
||||||
|
|
||||||
|
executor = self._registry.get_executor(tool_name)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||||
|
|
||||||
|
# 检查是否跳过
|
||||||
|
if await self._hook_executor.execute_skip_check(context):
|
||||||
|
return {"skipped": True, "tool": tool_name}
|
||||||
|
|
||||||
|
# 执行 pre-hooks
|
||||||
|
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||||
|
if not continue_execution:
|
||||||
|
return {"pre_hook_aborted": True, "tool": tool_name}
|
||||||
|
|
||||||
|
# 执行工具
|
||||||
|
try:
|
||||||
|
context.start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
# 判断是同步还是异步执行
|
||||||
|
if asyncio.iscoroutinefunction(executor):
|
||||||
|
result = await executor(**modified_input)
|
||||||
|
else:
|
||||||
|
# 同步函数在线程池中执行
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = await loop.run_in_executor(None, lambda: executor(**modified_input))
|
||||||
|
|
||||||
|
context.result = result
|
||||||
|
context.end_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
# 执行 post-hooks
|
||||||
|
result = await self._hook_executor.execute_post_hooks(context, result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
context.error = e
|
||||||
|
context.end_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
# 执行 error-hooks
|
||||||
|
error_result = await self._hook_executor.execute_error_hooks(context, e)
|
||||||
|
if error_result is not None:
|
||||||
|
return error_result
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def execute_streaming(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: dict[str, Any],
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
|
"""执行流式工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
tool_input: 工具输入参数
|
||||||
|
user_id: 用户 ID(可选)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
流式输出片段
|
||||||
|
"""
|
||||||
|
# 创建执行上下文
|
||||||
|
context = ExecutionContext(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取工具和执行器
|
||||||
|
manifest = self._registry.get(tool_name)
|
||||||
|
if manifest is None:
|
||||||
|
raise ValueError(f"Tool not found: {tool_name}")
|
||||||
|
|
||||||
|
if not manifest.is_streaming:
|
||||||
|
raise ValueError(f"Tool is not streaming: {tool_name}")
|
||||||
|
|
||||||
|
executor = self._registry.get_executor(tool_name)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||||
|
|
||||||
|
# 检查是否跳过
|
||||||
|
if await self._hook_executor.execute_skip_check(context):
|
||||||
|
yield {"type": "skipped", "tool": tool_name}
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行 pre-hooks
|
||||||
|
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||||
|
if not continue_execution:
|
||||||
|
yield {"type": "pre_hook_aborted", "tool": tool_name}
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行流式工具
|
||||||
|
try:
|
||||||
|
context.start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
# 调用执行器(应该返回 AsyncGenerator)
|
||||||
|
result = executor(**modified_input)
|
||||||
|
|
||||||
|
# 如果是 async generator
|
||||||
|
if asyncio.isasyncgen(result):
|
||||||
|
async for chunk in result:
|
||||||
|
yield {"type": "chunk", "data": chunk}
|
||||||
|
else:
|
||||||
|
# 普通协程
|
||||||
|
data = await result
|
||||||
|
yield {"type": "chunk", "data": data}
|
||||||
|
|
||||||
|
context.end_time = asyncio.get_event_loop().time()
|
||||||
|
yield {"type": "done", "tool": tool_name}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
context.error = e
|
||||||
|
context.end_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
yield {"type": "error", "error": str(e), "tool": tool_name}
|
||||||
|
|
||||||
|
# 执行 error-hooks
|
||||||
|
await self._hook_executor.execute_error_hooks(context, e)
|
||||||
|
|
||||||
|
async def execute_batch(
|
||||||
|
self,
|
||||||
|
tool_calls: list[dict[str, Any]],
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""批量执行工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input
|
||||||
|
user_id: 用户 ID(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果列表
|
||||||
|
"""
|
||||||
|
tasks = []
|
||||||
|
for call in tool_calls:
|
||||||
|
tool_name = call.get("tool_name") or call.get("name")
|
||||||
|
tool_input = call.get("tool_input") or call.get("input") or {}
|
||||||
|
tasks.append(self.execute(tool_name, tool_input, user_id))
|
||||||
|
|
||||||
|
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_executor: StreamingToolExecutor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_streaming_executor() -> StreamingToolExecutor:
|
||||||
|
"""获取全局流式执行器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
全局 StreamingToolExecutor 实例
|
||||||
|
"""
|
||||||
|
global _executor
|
||||||
|
if _executor is None:
|
||||||
|
_executor = StreamingToolExecutor()
|
||||||
|
return _executor
|
||||||
324
backend/tests/backend/app/agents/test_tools_registry.py
Normal file
324
backend/tests/backend/app/agents/test_tools_registry.py
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
"""ToolRegistry 单元测试"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
HookConfig,
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
ToolManifest,
|
||||||
|
)
|
||||||
|
from app.agents.tools.registry import (
|
||||||
|
ToolRegistry,
|
||||||
|
get_tool_registry,
|
||||||
|
reset_tool_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def registry():
|
||||||
|
"""创建空的 ToolRegistry 实例"""
|
||||||
|
return ToolRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_manifest():
|
||||||
|
"""创建示例工具元数据"""
|
||||||
|
return ToolManifest(
|
||||||
|
name="test_tool",
|
||||||
|
description="A test tool",
|
||||||
|
category=ToolCategory.READ,
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"input": {"type": "string"}},
|
||||||
|
"required": ["input"],
|
||||||
|
},
|
||||||
|
return_schema={"type": "string"},
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
requires_confirmation=False,
|
||||||
|
is_streaming=False,
|
||||||
|
tags=["test", "sample"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_executor():
|
||||||
|
"""创建示例执行器"""
|
||||||
|
|
||||||
|
def executor(input: str) -> str:
|
||||||
|
return f"processed: {input}"
|
||||||
|
|
||||||
|
return executor
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryInit:
|
||||||
|
"""测试 ToolRegistry 初始化"""
|
||||||
|
|
||||||
|
def test_empty_registry(self, registry):
|
||||||
|
assert len(registry) == 0
|
||||||
|
assert list(registry) == []
|
||||||
|
|
||||||
|
def test_empty_contains_false(self, registry):
|
||||||
|
assert "test_tool" not in registry
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryRegister:
|
||||||
|
"""测试工具注册"""
|
||||||
|
|
||||||
|
def test_register_single_tool(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
assert len(registry) == 1
|
||||||
|
assert "test_tool" in registry
|
||||||
|
|
||||||
|
def test_register_duplicate_raises(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Tool already registered"):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
def test_register_with_hooks(self, registry, sample_manifest, sample_executor):
|
||||||
|
hooks = [
|
||||||
|
HookConfig(name="audit", hook_type="pre_tool_use"),
|
||||||
|
HookConfig(name="security", hook_type="post_tool_use"),
|
||||||
|
]
|
||||||
|
registry.register(sample_manifest, sample_executor, hooks=hooks)
|
||||||
|
|
||||||
|
tool_hooks = registry.get_hooks("test_tool")
|
||||||
|
assert len(tool_hooks) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryGet:
|
||||||
|
"""测试获取工具"""
|
||||||
|
|
||||||
|
def test_get_existing_tool(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
manifest = registry.get("test_tool")
|
||||||
|
assert manifest is not None
|
||||||
|
assert manifest.name == "test_tool"
|
||||||
|
|
||||||
|
def test_get_nonexistent_tool(self, registry):
|
||||||
|
assert registry.get("nonexistent") is None
|
||||||
|
|
||||||
|
def test_get_executor(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
executor = registry.get_executor("test_tool")
|
||||||
|
assert executor is not None
|
||||||
|
assert executor("hello") == "processed: hello"
|
||||||
|
|
||||||
|
def test_get_nonexistent_executor(self, registry):
|
||||||
|
assert registry.get_executor("nonexistent") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryList:
|
||||||
|
"""测试工具列表"""
|
||||||
|
|
||||||
|
def test_list_all_empty(self, registry):
|
||||||
|
assert registry.list_all() == []
|
||||||
|
|
||||||
|
def test_list_all_with_tools(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
manifest2 = ToolManifest(
|
||||||
|
name="test_tool2",
|
||||||
|
description="Another test tool",
|
||||||
|
category=ToolCategory.WRITE,
|
||||||
|
parameters={"type": "object"},
|
||||||
|
return_schema={"type": "string"},
|
||||||
|
permission_class=PermissionClass.WRITE,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
)
|
||||||
|
registry.register(manifest2, sample_executor)
|
||||||
|
|
||||||
|
all_tools = registry.list_all()
|
||||||
|
assert len(all_tools) == 2
|
||||||
|
|
||||||
|
def test_list_by_category(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
manifest2 = ToolManifest(
|
||||||
|
name="write_tool",
|
||||||
|
description="A write tool",
|
||||||
|
category=ToolCategory.WRITE,
|
||||||
|
parameters={"type": "object"},
|
||||||
|
return_schema={"type": "string"},
|
||||||
|
permission_class=PermissionClass.WRITE,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
)
|
||||||
|
registry.register(manifest2, sample_executor)
|
||||||
|
|
||||||
|
read_tools = registry.list_by_category(ToolCategory.READ)
|
||||||
|
write_tools = registry.list_by_category(ToolCategory.WRITE)
|
||||||
|
|
||||||
|
assert len(read_tools) == 1
|
||||||
|
assert len(write_tools) == 1
|
||||||
|
|
||||||
|
def test_list_by_permission(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
read_tools = registry.list_by_permission(PermissionClass.READ)
|
||||||
|
write_tools = registry.list_by_permission(PermissionClass.WRITE)
|
||||||
|
|
||||||
|
assert len(read_tools) == 1
|
||||||
|
assert len(write_tools) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistrySearch:
|
||||||
|
"""测试工具搜索"""
|
||||||
|
|
||||||
|
def test_search_by_tag(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
results = registry.search_by_tag("test")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].name == "test_tool"
|
||||||
|
|
||||||
|
results = registry.search_by_tag("nonexistent")
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
def test_search_by_name(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
results = registry.search_by_name("test")
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
results = registry.search_by_name("tool")
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
results = registry.search_by_name("xyz")
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryUtility:
|
||||||
|
"""测试工具方法"""
|
||||||
|
|
||||||
|
def test_requires_confirmation_true(self, registry, sample_manifest, sample_executor):
|
||||||
|
sample_manifest.requires_confirmation = True
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
assert registry.get_requires_confirmation("test_tool") is True
|
||||||
|
|
||||||
|
def test_requires_confirmation_false(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
assert registry.get_requires_confirmation("test_tool") is False
|
||||||
|
|
||||||
|
def test_requires_confirmation_nonexistent(self, registry):
|
||||||
|
assert registry.get_requires_confirmation("nonexistent") is False
|
||||||
|
|
||||||
|
def test_is_streaming_true(self, registry, sample_manifest, sample_executor):
|
||||||
|
sample_manifest.is_streaming = True
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
assert registry.get_is_streaming("test_tool") is True
|
||||||
|
|
||||||
|
def test_is_streaming_false(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
assert registry.get_is_streaming("test_tool") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryUnregister:
|
||||||
|
"""测试工具注销"""
|
||||||
|
|
||||||
|
def test_unregister_existing(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
assert registry.unregister("test_tool") is True
|
||||||
|
assert len(registry) == 0
|
||||||
|
assert "test_tool" not in registry
|
||||||
|
|
||||||
|
def test_unregister_nonexistent(self, registry):
|
||||||
|
assert registry.unregister("nonexistent") is False
|
||||||
|
|
||||||
|
def test_unregister_clears_hooks(self, registry, sample_manifest, sample_executor):
|
||||||
|
hooks = [HookConfig(name="test", hook_type="pre_tool_use")]
|
||||||
|
registry.register(sample_manifest, sample_executor, hooks=hooks)
|
||||||
|
|
||||||
|
registry.unregister("test_tool")
|
||||||
|
|
||||||
|
assert registry.get_hooks("test_tool") == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryClear:
|
||||||
|
"""测试清空注册表"""
|
||||||
|
|
||||||
|
def test_clear(self, registry, sample_manifest, sample_executor):
|
||||||
|
registry.register(sample_manifest, sample_executor)
|
||||||
|
|
||||||
|
manifest2 = ToolManifest(
|
||||||
|
name="tool2",
|
||||||
|
description="Another tool",
|
||||||
|
category=ToolCategory.READ,
|
||||||
|
parameters={"type": "object"},
|
||||||
|
return_schema={"type": "string"},
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
requires_confirmation=False,
|
||||||
|
)
|
||||||
|
registry.register(manifest2, sample_executor)
|
||||||
|
|
||||||
|
registry.clear()
|
||||||
|
|
||||||
|
assert len(registry) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalRegistry:
|
||||||
|
"""测试全局注册表"""
|
||||||
|
|
||||||
|
def test_get_tool_registry_returns_singleton(self):
|
||||||
|
reset_tool_registry()
|
||||||
|
|
||||||
|
reg1 = get_tool_registry()
|
||||||
|
reg2 = get_tool_registry()
|
||||||
|
|
||||||
|
assert reg1 is reg2
|
||||||
|
|
||||||
|
def test_reset_tool_registry(self):
|
||||||
|
reset_tool_registry()
|
||||||
|
reg1 = get_tool_registry()
|
||||||
|
|
||||||
|
reset_tool_registry()
|
||||||
|
reg2 = get_tool_registry()
|
||||||
|
|
||||||
|
# After reset, should get a new instance
|
||||||
|
# (Note: this tests the reset function works)
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolManifest:
|
||||||
|
"""测试 ToolManifest"""
|
||||||
|
|
||||||
|
def test_to_dict(self, sample_manifest):
|
||||||
|
d = sample_manifest.to_dict()
|
||||||
|
|
||||||
|
assert d["name"] == "test_tool"
|
||||||
|
assert d["description"] == "A test tool"
|
||||||
|
assert d["category"] == "read"
|
||||||
|
assert d["permission_class"] == "read"
|
||||||
|
assert d["side_effect_scope"] == "none"
|
||||||
|
assert d["requires_confirmation"] is False
|
||||||
|
assert d["is_streaming"] is False
|
||||||
|
assert d["tags"] == ["test", "sample"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestHookConfig:
|
||||||
|
"""测试 HookConfig"""
|
||||||
|
|
||||||
|
def test_matches_tool_with_no_filter(self):
|
||||||
|
hook = HookConfig(name="test", hook_type="pre_tool_use")
|
||||||
|
|
||||||
|
assert hook.matches_tool("any_tool") is True
|
||||||
|
|
||||||
|
def test_matches_tool_with_filter(self):
|
||||||
|
hook = HookConfig(name="test", hook_type="pre_tool_use", filter_names=["tool_a", "tool_b"])
|
||||||
|
|
||||||
|
assert hook.matches_tool("tool_a") is True
|
||||||
|
assert hook.matches_tool("tool_b") is True
|
||||||
|
assert hook.matches_tool("tool_c") is False
|
||||||
Reference in New Issue
Block a user