108 lines
3.1 KiB
Python
108 lines
3.1 KiB
Python
"""
|
||
工具注册表 - 管理所有可用工具(白名单机制)
|
||
"""
|
||
from typing import Any, Callable, Optional, Dict
|
||
from dataclasses import dataclass, asdict
|
||
from enum import Enum
|
||
|
||
|
||
class SecurityLevel(Enum):
    """Security classification of a tool.

    Members, in declaration order:
        SAFE   -> "safe":   a safe operation
        REVIEW -> "review": the operation needs review
        DANGER -> "danger": a dangerous operation
    """

    SAFE = "safe"
    REVIEW = "review"
    DANGER = "danger"
|
||
|
||
|
||
@dataclass
class ToolMetadata:
    """Metadata describing a single tool.

    ``__post_init__`` normalizes two fields so every instance is safe to use
    in permission checks and serialization, however it was constructed:

    * ``security_level``: an ``Enum`` member is replaced by its ``.value``.
    * ``allowed_roles``: ``None`` becomes ``[]`` (no role allowed), a single
      string becomes a one-element list, and any other iterable is copied
      into a new list.
    """

    name: str
    description: str
    # Plain string such as "safe" / "review" / "danger"; an Enum member is
    # accepted and stored as its value (see __post_init__).
    security_level: str
    require_approval: bool = False
    # Roles permitted to call the tool. Always a list after __post_init__.
    allowed_roles: Optional[list] = None

    def __post_init__(self) -> None:
        """Normalize ``security_level`` and ``allowed_roles`` in place."""
        if isinstance(self.security_level, Enum):
            # Store the underlying string so dict() stays JSON-serializable.
            self.security_level = self.security_level.value

        roles = self.allowed_roles
        if roles is None:
            # Deny by default instead of leaving None, which would make
            # `role in allowed_roles` raise TypeError.
            roles = []
        elif isinstance(roles, str):
            # A bare string would otherwise be matched by substring.
            roles = [roles]
        # Copy so later mutation of the caller's list cannot change permissions.
        self.allowed_roles = list(roles)

    def dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the metadata.

        Returns:
            A dict with the keys ``name``, ``description``,
            ``security_level`` and ``require_approval``.
            ``allowed_roles`` is not included.
        """
        return {
            "name": self.name,
            "description": self.description,
            "security_level": self.security_level,
            "require_approval": self.require_approval,
        }
|
||
|
||
|
||
class ToolRegistry:
    """Registry of callable tools, acting as a whitelist.

    Only names registered through :meth:`register` can be retrieved with
    :meth:`get_tool`; any other name is rejected. For each tool the registry
    keeps the callable, a ``ToolMetadata`` record, and a definition dict
    (name, description, ``parameters`` schema) meant to be handed to an LLM.
    """

    def __init__(self):
        # tool name -> (callable, metadata)
        self._tools: Dict[str, tuple[Callable, ToolMetadata]] = {}
        # tool name -> LLM-facing tool definition
        self._definitions: Dict[str, dict] = {}

    def register(
        self,
        name: str,
        func: Callable,
        description: str = "",
        security_level: str = "safe",
        require_approval: bool = False,
        allowed_roles: Optional[list] = None,
        parameters: Optional[dict] = None,
    ) -> None:
        """Add a tool to the whitelist.

        Registering an existing name replaces the previous entry.

        Args:
            name: Unique tool name used for lookup.
            func: The callable implementing the tool.
            description: Human/LLM-readable description.
            security_level: Security level string (e.g. "safe").
            require_approval: Whether calls need approval (see need_approval).
            allowed_roles: Roles allowed to call the tool. ``None`` means the
                default ``["user", "admin"]``; an explicit empty list means no
                role may call it. A single string is treated as one role.
            parameters: Parameter schema for the LLM definition. If falsy, an
                empty object schema is used.
        """
        # Only None selects the default: an explicit [] must stay empty,
        # otherwise a "nobody may call this" tool would be opened to everyone.
        if allowed_roles is None:
            roles = ["user", "admin"]
        elif isinstance(allowed_roles, str):
            roles = [allowed_roles]
        else:
            # Copy so the caller cannot change permissions after registration.
            roles = list(allowed_roles)

        metadata = ToolMetadata(
            name=name,
            description=description,
            security_level=security_level,
            require_approval=require_approval,
            allowed_roles=roles,
        )
        self._tools[name] = (func, metadata)

        # Tool definition exposed to the LLM.
        self._definitions[name] = {
            "name": name,
            "description": description,
            "parameters": parameters or {
                "type": "object",
                "properties": {},
                "required": [],
            },
        }

    def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
        """Return ``(callable, metadata)`` for a registered tool.

        Raises:
            ValueError: if ``name`` is not registered.
        """
        try:
            return self._tools[name]
        except KeyError:
            raise ValueError(f"Tool '{name}' not found in whitelist") from None

    def get_tool_definition(self, name: str) -> Optional[dict]:
        """Return the LLM-facing definition of ``name``, or ``None`` if unknown."""
        return self._definitions.get(name)

    def list_tools(self) -> list[ToolMetadata]:
        """Return the metadata of every registered tool, in registration order."""
        return [meta for _, meta in self._tools.values()]

    def list_definitions(self) -> list[dict]:
        """Return every LLM-facing tool definition, in registration order."""
        return list(self._definitions.values())

    def check_permission(self, tool_name: str, user_role: str) -> bool:
        """Return True if ``user_role`` may call ``tool_name``.

        Unknown tools are never permitted.
        """
        entry = self._tools.get(tool_name)
        if entry is None:
            return False
        _, metadata = entry
        # Guard against metadata whose allowed_roles was left as None.
        return user_role in (metadata.allowed_roles or ())

    def need_approval(self, tool_name: str) -> bool:
        """Return the tool's ``require_approval`` flag (False for unknown tools)."""
        entry = self._tools.get(tool_name)
        if entry is None:
            return False
        _, metadata = entry
        return metadata.require_approval
|
||
|
||
|
||
# Module-level default tool registry instance.
global_registry: ToolRegistry = ToolRegistry()
|