108 lines
3.1 KiB
Python
108 lines
3.1 KiB
Python
|
|
"""
|
|||
|
|
工具注册表 - 管理所有可用工具(白名单机制)
|
|||
|
|
"""
|
|||
|
|
from typing import Any, Callable, Optional, Dict
|
|||
|
|
from dataclasses import dataclass, asdict
|
|||
|
|
from enum import Enum
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SecurityLevel(str, Enum):
    """Security level of a tool.

    Members mix in ``str`` so they compare equal to their plain string
    values (``SecurityLevel.SAFE == "safe"``) and serialize with ``json`` as
    those strings. This matches the rest of the module, which stores and
    passes security levels as plain strings.
    """

    SAFE = "safe"  # safe operation
    REVIEW = "review"  # operation that needs review
    DANGER = "danger"  # dangerous operation
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ToolMetadata:
    """Metadata describing one registered tool.

    Fields:
        name: Tool name.
        description: Human/LLM-readable description of the tool.
        security_level: Security level as a plain string (the values defined
            by ``SecurityLevel`` are "safe", "review" and "danger").
        require_approval: Whether calls to this tool need approval.
        allowed_roles: Roles permitted to use the tool; ``None`` by default.
    """

    name: str
    description: str
    security_level: str
    require_approval: bool = False
    allowed_roles: Optional[list] = None

    def dict(self):
        """Return a plain-dict view of the metadata.

        Contains ``name``, ``description``, ``security_level`` and
        ``require_approval`` (in that order); ``allowed_roles`` is not
        included.
        """
        exported_fields = ("name", "description", "security_level", "require_approval")
        return {field_name: getattr(self, field_name) for field_name in exported_fields}
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ToolRegistry:
    """Whitelist of callable tools together with their security metadata.

    Every registered tool is kept in two maps keyed by tool name:

    * ``_tools``: ``(func, ToolMetadata)`` pairs used for execution lookup,
      role checks and approval checks;
    * ``_definitions``: ``{"name", "description", "parameters"}`` dicts meant
      to be handed to an LLM as tool definitions.

    A name that was never registered is treated as not whitelisted:
    ``get_tool`` raises ``ValueError``, ``get_tool_definition`` returns
    ``None``, and ``check_permission`` / ``need_approval`` return ``False``.
    """

    # Roles granted when ``register`` is called without ``allowed_roles``.
    _DEFAULT_ROLES = ("user", "admin")

    def __init__(self):
        # name -> (callable, metadata): the whitelist itself.
        self._tools: Dict[str, tuple[Callable, ToolMetadata]] = {}
        # name -> LLM-facing tool definition.
        self._definitions: Dict[str, dict] = {}

    def register(
        self,
        name: str,
        func: Callable,
        description: str = "",
        security_level: str = "safe",
        require_approval: bool = False,
        allowed_roles: Optional[list] = None,
        parameters: Optional[dict] = None,
    ) -> None:
        """Add ``func`` to the whitelist under ``name``.

        Registering a name that already exists replaces the previous tool and
        its definition.

        Args:
            name: Tool name; the key used by every lookup method.
            func: Callable associated with the tool.
            description: Description stored in the metadata and the LLM
                definition.
            security_level: Security level string (e.g. "safe", "review",
                "danger"). A ``SecurityLevel`` member is also accepted and is
                stored as its string value. The value is not validated.
            require_approval: Value reported by ``need_approval``.
            allowed_roles: Roles allowed to use the tool. ``None`` means the
                default roles ("user", "admin"). An explicit empty collection
                means no role is allowed. A single string is treated as one
                role.
            parameters: JSON-schema object describing the tool's arguments.
                When omitted or empty, an empty object schema is used.
        """
        if isinstance(security_level, SecurityLevel):
            # Store the plain string so the metadata stays JSON-friendly.
            security_level = security_level.value

        if allowed_roles is None:
            roles = list(self._DEFAULT_ROLES)
        elif isinstance(allowed_roles, str):
            # A bare string would turn `role in allowed_roles` into a
            # substring test ("ad" in "admin" is True).
            roles = [allowed_roles]
        else:
            # Honour an explicit (possibly empty) whitelist as given. The
            # previous `allowed_roles or [...]` silently widened an empty list
            # to user+admin. Copy it so later mutation by the caller cannot
            # change this tool's permissions.
            roles = list(allowed_roles)

        metadata = ToolMetadata(
            name=name,
            description=description,
            security_level=security_level,
            require_approval=require_approval,
            allowed_roles=roles,
        )
        self._tools[name] = (func, metadata)

        # Tool definition used for LLM tool calling.
        self._definitions[name] = {
            "name": name,
            "description": description,
            "parameters": parameters or {
                "type": "object",
                "properties": {},
                "required": [],
            },
        }

    def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
        """Return the ``(func, metadata)`` pair registered under ``name``.

        Raises:
            ValueError: if ``name`` is not in the whitelist.
        """
        try:
            return self._tools[name]
        except KeyError:
            raise ValueError(f"Tool '{name}' not found in whitelist") from None

    def get_tool_definition(self, name: str) -> Optional[dict]:
        """Return the LLM-facing definition for ``name``, or ``None`` if unknown."""
        return self._definitions.get(name)

    def list_tools(self) -> list[ToolMetadata]:
        """Return the metadata of every registered tool, in registration order."""
        return [metadata for _, metadata in self._tools.values()]

    def list_definitions(self) -> list[dict]:
        """Return every LLM-facing tool definition, in registration order."""
        return list(self._definitions.values())

    def check_permission(self, tool_name: str, user_role: str) -> bool:
        """Return True if ``user_role`` may use ``tool_name``.

        Unknown tools are denied (False).
        """
        entry = self._tools.get(tool_name)
        if entry is None:
            return False
        _, metadata = entry
        return user_role in metadata.allowed_roles

    def need_approval(self, tool_name: str) -> bool:
        """Return the tool's ``require_approval`` flag (False for unknown tools).

        NOTE(review): ``security_level`` does not influence this result; a
        "danger" tool registered without ``require_approval=True`` is
        reported as not needing approval.
        """
        entry = self._tools.get(tool_name)
        if entry is None:
            return False
        _, metadata = entry
        return metadata.require_approval
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Module-level tool registry instance, created once when this module is imported.
global_registry: ToolRegistry = ToolRegistry()
|