feat(agents): Phase 7-10 hook system, plugins, skills, orchestration
Phase 7: Built-in Hooks (audit_log, dangerous_confirmation, security_scan) Phase 8: Plugin system (PluginManager, PluginSandbox, PluginManifest) Phase 9: Skills registry (SkillRegistry, local/plugin/MCP loaders) Phase 10: TeamLeader, RemoteTransport, BackgroundTaskManager
This commit is contained in:
119
backend/app/agents/background/manager.py
Normal file
119
backend/app/agents/background/manager.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""后台任务系统 - Phase 10.4"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BackgroundTaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackgroundTask:
|
||||
"""后台任务"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
status: BackgroundTaskStatus
|
||||
created_at: datetime
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class BackgroundTaskManager:
|
||||
"""后台任务管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._tasks: dict[str, BackgroundTask] = {}
|
||||
self._.coroutines: dict[str, asyncio.Task] = {}
|
||||
|
||||
def submit_task(self, name: str, coro: Any, *args, **kwargs) -> str:
|
||||
"""提交后台任务
|
||||
|
||||
Args:
|
||||
name: 任务名称
|
||||
coro: 协程函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
任务 ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 创建任务记录
|
||||
self._tasks[task_id] = BackgroundTask(
|
||||
id=task_id,
|
||||
name=name,
|
||||
status=BackgroundTaskStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
# 创建 asyncio task
|
||||
async def run_task():
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.RUNNING
|
||||
self._tasks[task_id].started_at = datetime.now()
|
||||
try:
|
||||
result = await coro(*args, **kwargs)
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.COMPLETED
|
||||
self._tasks[task_id].result = result
|
||||
except Exception as e:
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.FAILED
|
||||
self._tasks[task_id].error = str(e)
|
||||
finally:
|
||||
self._tasks[task_id].completed_at = datetime.now()
|
||||
if task_id in self._coroutines:
|
||||
del self._coroutines[task_id]
|
||||
|
||||
self._coroutines[task_id] = asyncio.create_task(run_task())
|
||||
return task_id
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""取消任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
是否成功取消
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
if task_id in self._coroutines:
|
||||
self._coroutines[task_id].cancel()
|
||||
del self._coroutines[task_id]
|
||||
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.CANCELLED
|
||||
self._tasks[task_id].completed_at = datetime.now()
|
||||
return True
|
||||
|
||||
def get_task_status(self, task_id: str) -> BackgroundTask | None:
|
||||
"""获取任务状态"""
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def list_tasks(self) -> list[BackgroundTask]:
|
||||
"""列出所有任务"""
|
||||
return list(self._tasks.values())
|
||||
|
||||
|
||||
# 全局单例
|
||||
_manager: BackgroundTaskManager | None = None
|
||||
|
||||
|
||||
def get_background_task_manager() -> BackgroundTaskManager:
|
||||
"""获取全局后台任务管理器"""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
_manager = BackgroundTaskManager()
|
||||
return _manager
|
||||
20
backend/app/agents/orchestration/__init__.py
Normal file
20
backend/app/agents/orchestration/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""高级编排系统 - Phase 10"""
|
||||
|
||||
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
|
||||
from app.agents.transport.remote import RemoteTransport, StructuredMessage
|
||||
from app.agents.background.manager import (
|
||||
BackgroundTaskManager,
|
||||
BackgroundTask,
|
||||
get_background_task_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TeamLeader",
|
||||
"TeamTask",
|
||||
"TaskStatus",
|
||||
"RemoteTransport",
|
||||
"StructuredMessage",
|
||||
"BackgroundTaskManager",
|
||||
"BackgroundTask",
|
||||
"get_background_task_manager",
|
||||
]
|
||||
12
backend/app/agents/plugins/__init__.py
Normal file
12
backend/app/agents/plugins/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""插件系统 - Phase 8"""
|
||||
|
||||
from app.agents.plugins.manager import PluginManager, get_plugin_manager
|
||||
from app.agents.plugins.manifest import PluginManifest
|
||||
from app.agents.plugins.sandbox import PluginSandbox
|
||||
|
||||
__all__ = [
|
||||
"PluginManager",
|
||||
"PluginManifest",
|
||||
"PluginSandbox",
|
||||
"get_plugin_manager",
|
||||
]
|
||||
207
backend/app/agents/plugins/manager.py
Normal file
207
backend/app/agents/plugins/manager.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""插件管理器 - Phase 8.2"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from app.agents.plugins.manifest import PluginManifest
|
||||
from app.agents.plugins.sandbox import PluginSandbox
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""插件管理器
|
||||
|
||||
负责插件的安装、卸载、启用、禁用和生命周期管理。
|
||||
"""
|
||||
|
||||
def __init__(self, plugins_dir: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
plugins_dir: 插件目录,None 则使用默认目录
|
||||
"""
|
||||
if plugins_dir is None:
|
||||
plugins_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "plugins")
|
||||
self.plugins_dir = plugins_dir
|
||||
self._plugins: dict[str, PluginManifest] = {}
|
||||
self._enabled: dict[str, bool] = {}
|
||||
self._modules: dict[str, Any] = {}
|
||||
self._sandbox = PluginSandbox()
|
||||
|
||||
def install(self, plugin_path: str) -> bool:
|
||||
"""安装插件
|
||||
|
||||
Args:
|
||||
plugin_path: 插件目录路径或 manifest.json 所在目录
|
||||
|
||||
Returns:
|
||||
是否安装成功
|
||||
"""
|
||||
try:
|
||||
manifest_path = os.path.join(plugin_path, "manifest.json")
|
||||
|
||||
if not os.path.exists(manifest_path):
|
||||
return False
|
||||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
import json
|
||||
|
||||
data = json.load(f)
|
||||
|
||||
manifest = PluginManifest.from_dict(data)
|
||||
|
||||
# 验证 manifest
|
||||
if not self._validate_manifest(manifest, plugin_path):
|
||||
return False
|
||||
|
||||
# 复制插件到 plugins_dir
|
||||
target_dir = os.path.join(self.plugins_dir, manifest.id)
|
||||
os.makedirs(os.path.dirname(target_dir), exist_ok=True)
|
||||
|
||||
# 保存 manifest
|
||||
with open(os.path.join(target_dir, "manifest.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(manifest.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 注册插件
|
||||
self._plugins[manifest.id] = manifest
|
||||
self._enabled[manifest.id] = True
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def uninstall(self, plugin_id: str) -> bool:
|
||||
"""卸载插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否卸载成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
# 禁用插件
|
||||
self.disable(plugin_id)
|
||||
|
||||
# 移除模块
|
||||
if plugin_id in self._modules:
|
||||
del self._modules[plugin_id]
|
||||
|
||||
# 移除插件
|
||||
del self._plugins[plugin_id]
|
||||
del self._enabled[plugin_id]
|
||||
|
||||
# 删除目录
|
||||
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
|
||||
if os.path.exists(plugin_dir):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(plugin_dir)
|
||||
|
||||
return True
|
||||
|
||||
def enable(self, plugin_id: str) -> bool:
|
||||
"""启用插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否启用成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
self._enabled[plugin_id] = True
|
||||
return True
|
||||
|
||||
def disable(self, plugin_id: str) -> bool:
|
||||
"""禁用插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否禁用成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
self._enabled[plugin_id] = False
|
||||
return True
|
||||
|
||||
def reload(self, plugin_id: str) -> bool:
|
||||
"""重新加载插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否重新加载成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
# 卸载模块
|
||||
if plugin_id in self._modules:
|
||||
del self._modules[plugin_id]
|
||||
|
||||
# 重新加载
|
||||
return self._load_plugin_module(plugin_id)
|
||||
|
||||
def list_plugins(self) -> list[PluginManifest]:
|
||||
"""列出所有插件"""
|
||||
return list(self._plugins.values())
|
||||
|
||||
def get_plugin(self, plugin_id: str) -> PluginManifest | None:
|
||||
"""获取插件清单"""
|
||||
return self._plugins.get(plugin_id)
|
||||
|
||||
def is_enabled(self, plugin_id: str) -> bool:
|
||||
"""检查插件是否启用"""
|
||||
return self._enabled.get(plugin_id, False)
|
||||
|
||||
def _validate_manifest(self, manifest: PluginManifest, plugin_path: str) -> bool:
|
||||
"""验证 manifest"""
|
||||
# 检查主入口文件是否存在
|
||||
main_path = os.path.join(plugin_path, manifest.main)
|
||||
if not os.path.exists(main_path):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _load_plugin_module(self, plugin_id: str) -> bool:
|
||||
"""加载插件模块"""
|
||||
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
|
||||
manifest = self._plugins.get(plugin_id)
|
||||
if not manifest:
|
||||
return False
|
||||
|
||||
try:
|
||||
main_path = os.path.join(plugin_dir, manifest.main)
|
||||
spec = importlib.util.spec_from_file_location(plugin_id, main_path)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[plugin_id] = module
|
||||
spec.loader.exec_module(module)
|
||||
self._modules[plugin_id] = module
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# 全局单例
|
||||
_manager: PluginManager | None = None
|
||||
|
||||
|
||||
def get_plugin_manager() -> PluginManager:
|
||||
"""获取全局插件管理器"""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
_manager = PluginManager()
|
||||
return _manager
|
||||
73
backend/app/agents/plugins/manifest.py
Normal file
73
backend/app/agents/plugins/manifest.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""插件清单定义 - Phase 8.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginManifest:
|
||||
"""插件清单
|
||||
|
||||
定义插件的元数据和接口。
|
||||
"""
|
||||
|
||||
id: str # 唯一标识
|
||||
name: str # 显示名称
|
||||
version: str # 版本号
|
||||
description: str # 描述
|
||||
author: str = "" # 作者
|
||||
homepage: str = "" # 主页
|
||||
license: str = "MIT" # 许可证
|
||||
|
||||
# 插件类型
|
||||
plugin_type: str = "tool" # tool, hook, skill, all
|
||||
|
||||
# 入口点
|
||||
main: str = "index.py" # 主入口文件
|
||||
hooks: list[str] = field(default_factory=list) # 提供的 Hook 列表
|
||||
tools: list[str] = field(default_factory=list) # 提供的工具列表
|
||||
skills: list[str] = field(default_factory=list) # 提供的 Skills 列表
|
||||
|
||||
# 依赖
|
||||
dependencies: dict[str, str] = field(default_factory=dict) # pip 依赖
|
||||
peer_dependencies: dict[str, str] = field(default_factory=dict) # 对等依赖
|
||||
|
||||
# 权限要求
|
||||
permissions: list[str] = field(default_factory=list) # 需要的权限
|
||||
allowed_paths: list[str] = field(default_factory=list) # 允许访问的路径
|
||||
denied_paths: list[str] = field(default_factory=list) # 禁止访问的路径
|
||||
|
||||
# 网络权限
|
||||
network_allowed: bool = False # 是否允许网络访问
|
||||
allowed_hosts: list[str] = field(default_factory=list) # 允许访问的 host
|
||||
|
||||
# 配置
|
||||
config_schema: dict[str, Any] = field(default_factory=dict) # 配置 schema
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"version": self.version,
|
||||
"description": self.description,
|
||||
"author": self.author,
|
||||
"homepage": self.homepage,
|
||||
"license": self.license,
|
||||
"plugin_type": self.plugin_type,
|
||||
"main": self.main,
|
||||
"hooks": self.hooks,
|
||||
"tools": self.tools,
|
||||
"skills": self.skills,
|
||||
"dependencies": self.dependencies,
|
||||
"peer_dependencies": self.peer_dependencies,
|
||||
"permissions": self.permissions,
|
||||
"allowed_paths": self.allowed_paths,
|
||||
"denied_paths": self.denied_paths,
|
||||
"network_allowed": self.network_allowed,
|
||||
"allowed_hosts": self.allowed_hosts,
|
||||
"config_schema": self.config_schema,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PluginManifest":
|
||||
return cls(**data)
|
||||
111
backend/app/agents/plugins/sandbox.py
Normal file
111
backend/app/agents/plugins/sandbox.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""插件沙箱隔离 - Phase 8.3"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PluginSandbox:
|
||||
"""插件沙箱
|
||||
|
||||
提供插件执行隔离环境。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._allowed_paths: set[str] = set()
|
||||
self._denied_paths: set[str] = set()
|
||||
self._network_allowed: bool = False
|
||||
self._allowed_hosts: set[str] = set()
|
||||
|
||||
def set_file_permissions(
|
||||
self,
|
||||
allowed_paths: list[str] | None = None,
|
||||
denied_paths: list[str] | None = None,
|
||||
) -> None:
|
||||
"""设置文件访问权限
|
||||
|
||||
Args:
|
||||
allowed_paths: 允许访问的路径列表
|
||||
denied_paths: 禁止访问的路径列表
|
||||
"""
|
||||
self._allowed_paths = set(allowed_paths or [])
|
||||
self._denied_paths = set(denied_paths or [])
|
||||
|
||||
def set_network_permissions(
|
||||
self, allowed: bool, allowed_hosts: list[str] | None = None
|
||||
) -> None:
|
||||
"""设置网络访问权限
|
||||
|
||||
Args:
|
||||
allowed: 是否允许网络访问
|
||||
allowed_hosts: 允许访问的 host 列表
|
||||
"""
|
||||
self._network_allowed = allowed
|
||||
self._allowed_hosts = set(allowed_hosts or [])
|
||||
|
||||
def check_file_access(self, path: str) -> bool:
|
||||
"""检查文件访问权限
|
||||
|
||||
Args:
|
||||
path: 文件路径
|
||||
|
||||
Returns:
|
||||
是否允许访问
|
||||
"""
|
||||
# 如果有允许列表,只允许访问列表中的路径
|
||||
if self._allowed_paths:
|
||||
return path in self._allowed_paths or any(
|
||||
path.startswith(allowed) for allowed in self._allowed_paths
|
||||
)
|
||||
|
||||
# 如果有禁止列表,禁止访问列表中的路径
|
||||
if self._denied_paths:
|
||||
return not any(path.startswith(denied) for denied in self._denied_paths)
|
||||
|
||||
# 没有限制
|
||||
return True
|
||||
|
||||
def check_network_access(self, host: str) -> bool:
|
||||
"""检查网络访问权限
|
||||
|
||||
Args:
|
||||
host: 主机地址
|
||||
|
||||
Returns:
|
||||
是否允许访问
|
||||
"""
|
||||
if not self._network_allowed:
|
||||
return False
|
||||
|
||||
if self._allowed_hosts:
|
||||
return host in self._allowed_hosts or any(
|
||||
host.endswith(allowed) for allowed in self._allowed_hosts
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def execute_in_sandbox(self, func: Any, *args, **kwargs) -> Any:
|
||||
"""在沙箱中执行函数
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数返回值
|
||||
"""
|
||||
# 保存当前状态
|
||||
old_allowed_paths = self._allowed_paths.copy()
|
||||
old_denied_paths = self._denied_paths.copy()
|
||||
old_network_allowed = self._network_allowed
|
||||
old_allowed_hosts = self._allowed_hosts.copy()
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
# 恢复状态
|
||||
self._allowed_paths = old_allowed_paths
|
||||
self._denied_paths = old_denied_paths
|
||||
self._network_allowed = old_network_allowed
|
||||
self._allowed_hosts = old_allowed_hosts
|
||||
16
backend/app/agents/skills/__init__.py
Normal file
16
backend/app/agents/skills/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Skills 注册表 - Phase 9"""
|
||||
|
||||
from app.agents.skills.registry import SkillRegistry, get_skill_registry
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
|
||||
from app.agents.skills.mcp_builder import MCPSkillBuilder
|
||||
|
||||
__all__ = [
|
||||
"SkillRegistry",
|
||||
"SkillMetadata",
|
||||
"LocalSkillLoader",
|
||||
"PluginSkillLoader",
|
||||
"MCPSkillBuilder",
|
||||
"get_skill_registry",
|
||||
]
|
||||
100
backend/app/agents/skills/loaders/local_loader.py
Normal file
100
backend/app/agents/skills/loaders/local_loader.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""本地 Skills 加载器 - Phase 9.2"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class LocalSkillLoader:
|
||||
"""本地 Skills 加载器
|
||||
|
||||
从 skills_dir 目录加载 SKILL.md 文件。
|
||||
"""
|
||||
|
||||
def __init__(self, skills_dir: str):
|
||||
self.skills_dir = skills_dir
|
||||
|
||||
def load_all(self) -> list[SkillMetadata]:
|
||||
"""加载所有本地 Skills
|
||||
|
||||
Returns:
|
||||
Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
if not os.path.exists(self.skills_dir):
|
||||
return skills
|
||||
|
||||
for root, dirs, files in os.walk(self.skills_dir):
|
||||
# 跳过隐藏目录
|
||||
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
||||
|
||||
if "SKILL.md" in files:
|
||||
skill = self._load_skill_from_dir(root)
|
||||
if skill:
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
|
||||
def _load_skill_from_dir(self, skill_dir: str) -> SkillMetadata | None:
|
||||
"""从目录加载 Skill
|
||||
|
||||
Args:
|
||||
skill_dir: Skill 目录
|
||||
|
||||
Returns:
|
||||
Skill 元数据
|
||||
"""
|
||||
skill_path = os.path.join(skill_dir, "SKILL.md")
|
||||
|
||||
try:
|
||||
with open(skill_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 解析 frontmatter
|
||||
metadata = self._parse_frontmatter(content)
|
||||
|
||||
# 获取 Skill 名称(目录名)
|
||||
name = os.path.basename(skill_dir)
|
||||
|
||||
return SkillMetadata(
|
||||
name=metadata.get("name", name),
|
||||
description=metadata.get("description", ""),
|
||||
version=metadata.get("version", "1.0.0"),
|
||||
author=metadata.get("author", ""),
|
||||
tags=metadata.get("tags", []),
|
||||
triggers=metadata.get("triggers", []),
|
||||
content=content,
|
||||
source="local",
|
||||
source_id=skill_dir,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _parse_frontmatter(self, content: str) -> dict[str, Any]:
|
||||
"""解析 frontmatter"""
|
||||
metadata = {}
|
||||
|
||||
# 匹配 --- 包裹的 frontmatter
|
||||
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
||||
if match:
|
||||
frontmatter = match.group(1)
|
||||
|
||||
for line in frontmatter.split("\n"):
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# 处理列表
|
||||
if value.startswith("[") and value.endswith("]"):
|
||||
value = [v.strip().strip('"').strip("'") for v in value[1:-1].split(",")]
|
||||
elif value.lower() in ("true", "false"):
|
||||
value = value.lower() == "true"
|
||||
|
||||
metadata[key] = value
|
||||
|
||||
return metadata
|
||||
51
backend/app/agents/skills/loaders/plugin_loader.py
Normal file
51
backend/app/agents/skills/loaders/plugin_loader.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""插件 Skills 加载器 - Phase 9.2"""
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.plugins.manager import get_plugin_manager
|
||||
|
||||
|
||||
class PluginSkillLoader:
|
||||
"""插件 Skills 加载器
|
||||
|
||||
从已安装的插件中加载 Skills。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_manager = get_plugin_manager()
|
||||
|
||||
def load_all(self) -> list[SkillMetadata]:
|
||||
"""从所有已启用的插件加载 Skills
|
||||
|
||||
Returns:
|
||||
Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
for plugin in self.plugin_manager.list_plugins():
|
||||
if not self.plugin_manager.is_enabled(plugin.id):
|
||||
continue
|
||||
|
||||
# 从插件加载 Skills
|
||||
plugin_skills = self._load_from_plugin(plugin)
|
||||
skills.extend(plugin_skills)
|
||||
|
||||
return skills
|
||||
|
||||
def _load_from_plugin(self, plugin: Any) -> list[SkillMetadata]:
|
||||
"""从单个插件加载 Skills"""
|
||||
skills = []
|
||||
|
||||
for skill_name in plugin.skills:
|
||||
skill = SkillMetadata(
|
||||
name=f"{plugin.id}/{skill_name}",
|
||||
description=f"Skill from plugin: {plugin.name}",
|
||||
version=plugin.version,
|
||||
author=plugin.author,
|
||||
tags=["plugin", plugin.id],
|
||||
content=f"# {skill_name}\n\nFrom plugin: {plugin.name}",
|
||||
source="plugin",
|
||||
source_id=plugin.id,
|
||||
)
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
100
backend/app/agents/skills/mcp_builder.py
Normal file
100
backend/app/agents/skills/mcp_builder.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""MCP Skill Builder - Phase 9.3"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class MCPSkillBuilder:
|
||||
"""MCP Skill Builder
|
||||
|
||||
从 MCP 服务器发现和构建 Skills。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
def discover_skills_from_mcp(self, mcp_servers: list[dict[str, Any]]) -> list[SkillMetadata]:
|
||||
"""从 MCP 服务器发现 Skills
|
||||
|
||||
Args:
|
||||
mcp_servers: MCP 服务器配置列表
|
||||
|
||||
Returns:
|
||||
发现的 Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
for server in mcp_servers:
|
||||
server_skills = self._discover_from_server(server)
|
||||
skills.extend(server_skills)
|
||||
|
||||
return skills
|
||||
|
||||
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
|
||||
"""从单个 MCP 服务器发现 Skills"""
|
||||
skills = []
|
||||
server_name = server.get("name", "unknown")
|
||||
tools = server.get("tools", [])
|
||||
|
||||
# 按工具分组
|
||||
tool_groups: dict[str, list[str]] = {}
|
||||
for tool in tools:
|
||||
group = tool.get("group", "default")
|
||||
if group not in tool_groups:
|
||||
tool_groups[group] = []
|
||||
tool_groups[group].append(tool)
|
||||
|
||||
# 为每个组创建一个 Skill
|
||||
for group_name, group_tools in tool_groups.items():
|
||||
skill = self._tool_to_skill(group_name, group_tools, server_name)
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
|
||||
def _tool_to_skill(self, group: str, tools: list[dict[str, Any]], server: str) -> SkillMetadata:
|
||||
"""将 MCP 工具转换为 Skill"""
|
||||
tool_summaries = []
|
||||
for tool in tools:
|
||||
name = tool.get("name", "unknown")
|
||||
description = tool.get("description", "")
|
||||
input_schema = tool.get("inputSchema", {})
|
||||
|
||||
tool_summaries.append(f"### {name}\n{description}\n\nInput: {input_schema}")
|
||||
|
||||
content = f"""# MCP Skill: {group}
|
||||
|
||||
来自 MCP 服务器: {server}
|
||||
|
||||
## 工具列表
|
||||
|
||||
{chr(10).join(tool_summaries)}
|
||||
|
||||
## 使用说明
|
||||
|
||||
使用这些工具前请确保理解每个工具的输入输出格式。
|
||||
"""
|
||||
|
||||
return SkillMetadata(
|
||||
name=f"mcp-{server}-{group}",
|
||||
description=f"MCP skill from {server}: {group}",
|
||||
version="1.0.0",
|
||||
tags=["mcp", server, group],
|
||||
triggers=[group, server],
|
||||
content=content,
|
||||
source="mcp",
|
||||
source_id=f"{server}:{group}",
|
||||
)
|
||||
|
||||
def _group_to_skill(self, group: str, tools: list[str], server: str) -> SkillMetadata:
|
||||
"""将 MCP 工具组转换为 Skill"""
|
||||
return SkillMetadata(
|
||||
name=f"mcp-{server}-{group}",
|
||||
description=f"MCP skill from {server}: {group}",
|
||||
version="1.0.0",
|
||||
tags=["mcp", server, group],
|
||||
triggers=[group, server],
|
||||
content=f"# {group}\n\nTools: {', '.join(tools)}",
|
||||
source="mcp",
|
||||
source_id=f"{server}:{group}",
|
||||
)
|
||||
38
backend/app/agents/skills/metadata.py
Normal file
38
backend/app/agents/skills/metadata.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Skill 元数据定义 - Phase 9.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillMetadata:
|
||||
"""Skill 元数据"""
|
||||
|
||||
name: str # Skill 名称
|
||||
description: str # 描述
|
||||
version: str = "1.0.0" # 版本
|
||||
author: str = "" # 作者
|
||||
tags: list[str] = field(default_factory=list) # 标签
|
||||
triggers: list[str] = field(default_factory=list) # 触发关键词
|
||||
content: str = "" # Skill 内容(markdown)
|
||||
source: str = "local" # 来源:local, plugin, mcp, bundled
|
||||
source_id: str = "" # 来源 ID
|
||||
enabled: bool = True # 是否启用
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"version": self.version,
|
||||
"author": self.author,
|
||||
"tags": self.tags,
|
||||
"triggers": self.triggers,
|
||||
"content": self.content,
|
||||
"source": self.source,
|
||||
"source_id": self.source_id,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SkillMetadata":
|
||||
return cls(**data)
|
||||
133
backend/app/agents/skills/registry.py
Normal file
133
backend/app/agents/skills/registry.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Skills 注册表 - Phase 9.1"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||
|
||||
|
||||
class SkillRegistry:
|
||||
"""Skills 注册表
|
||||
|
||||
管理所有 Skills 的注册、发现和加载。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, SkillMetadata] = {}
|
||||
self._loaders: list[Any] = []
|
||||
|
||||
def load_all(self, skills_dir: str | None = None) -> int:
|
||||
"""加载所有 Skills
|
||||
|
||||
Args:
|
||||
skills_dir: Skills 目录,None 则使用默认目录
|
||||
|
||||
Returns:
|
||||
加载的 Skill 数量
|
||||
"""
|
||||
if skills_dir is None:
|
||||
skills_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", ".claude", "skills"
|
||||
)
|
||||
|
||||
count = 0
|
||||
|
||||
# 本地加载器
|
||||
local_loader = LocalSkillLoader(skills_dir)
|
||||
local_skills = local_loader.load_all()
|
||||
for skill in local_skills:
|
||||
self.register(skill)
|
||||
count += 1
|
||||
|
||||
# 插件加载器
|
||||
for loader in self._loaders:
|
||||
try:
|
||||
external_skills = loader.load_all()
|
||||
for skill in external_skills:
|
||||
self.register(skill)
|
||||
count += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return count
|
||||
|
||||
def register(self, skill: SkillMetadata) -> None:
|
||||
"""注册 Skill"""
|
||||
self._skills[skill.name] = skill
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销 Skill"""
|
||||
if name in self._skills:
|
||||
del self._skills[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_skill(self, name: str) -> SkillMetadata | None:
|
||||
"""获取 Skill"""
|
||||
return self._skills.get(name)
|
||||
|
||||
def search(self, query: str) -> list[SkillMetadata]:
|
||||
"""搜索 Skills
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
|
||||
Returns:
|
||||
匹配的 Skills 列表
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
results = []
|
||||
|
||||
for skill in self._skills.values():
|
||||
if not skill.enabled:
|
||||
continue
|
||||
|
||||
# 匹配名称、描述、标签
|
||||
if (
|
||||
query_lower in skill.name.lower()
|
||||
or query_lower in skill.description.lower()
|
||||
or any(query_lower in tag.lower() for tag in skill.tags)
|
||||
or any(query_lower in trigger.lower() for trigger in skill.triggers)
|
||||
):
|
||||
results.append(skill)
|
||||
|
||||
return results
|
||||
|
||||
def get_skill_context(self, names: list[str]) -> str:
|
||||
"""获取 Skill 上下文
|
||||
|
||||
Args:
|
||||
names: Skill 名称列表
|
||||
|
||||
Returns:
|
||||
拼接的 Skill 内容
|
||||
"""
|
||||
contexts = []
|
||||
|
||||
for name in names:
|
||||
skill = self._skills.get(name)
|
||||
if skill and skill.enabled:
|
||||
contexts.append(f"# {skill.name}\n\n{skill.content}")
|
||||
|
||||
return "\n\n---\n\n".join(contexts)
|
||||
|
||||
def add_loader(self, loader: Any) -> None:
|
||||
"""添加加载器"""
|
||||
self._loaders.append(loader)
|
||||
|
||||
def list_all(self) -> list[SkillMetadata]:
|
||||
"""列出所有 Skills"""
|
||||
return list(self._skills.values())
|
||||
|
||||
|
||||
# 全局单例
|
||||
_registry: SkillRegistry | None = None
|
||||
|
||||
|
||||
def get_skill_registry() -> SkillRegistry:
|
||||
"""获取全局 Skills 注册表"""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = SkillRegistry()
|
||||
return _registry
|
||||
121
backend/app/agents/team/leader.py
Normal file
121
backend/app/agents/team/leader.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Team 多 Agent 协作 - Phase 10.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeamTask:
|
||||
"""团队任务"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
assignee: str | None = None
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class TeamLeader:
|
||||
"""团队领导者
|
||||
|
||||
协调多个 Agent 成员执行任务。
|
||||
"""
|
||||
|
||||
def __init__(self, team_id: str, members: list[str]):
|
||||
"""
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
members: 成员 ID 列表
|
||||
"""
|
||||
self.team_id = team_id
|
||||
self.members = members
|
||||
self._tasks: dict[str, TeamTask] = {}
|
||||
|
||||
def create_task(self, description: str) -> str:
|
||||
"""创建任务
|
||||
|
||||
Args:
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
任务 ID
|
||||
"""
|
||||
import uuid
|
||||
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
self._tasks[task_id] = TeamTask(
|
||||
id=task_id,
|
||||
description=description,
|
||||
)
|
||||
return task_id
|
||||
|
||||
def assign_task(self, task_id: str, member: str) -> bool:
|
||||
"""分配任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
member: 成员 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
if member not in self.members:
|
||||
return False
|
||||
|
||||
self._tasks[task_id].assignee = member
|
||||
self._tasks[task_id].status = TaskStatus.IN_PROGRESS
|
||||
return True
|
||||
|
||||
def broadcast_task(self, description: str) -> list[str]:
|
||||
"""广播任务给所有成员
|
||||
|
||||
Args:
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
创建的任务 ID 列表
|
||||
"""
|
||||
task_ids = []
|
||||
for member in self.members:
|
||||
task_id = self.create_task(description)
|
||||
self.assign_task(task_id, member)
|
||||
task_ids.append(task_id)
|
||||
return task_ids
|
||||
|
||||
def collect_results(self) -> dict[str, Any]:
|
||||
"""收集所有任务结果
|
||||
|
||||
Returns:
|
||||
任务 ID -> 结果的映射
|
||||
"""
|
||||
return {
|
||||
task_id: task.result
|
||||
for task_id, task in self._tasks.items()
|
||||
if task.status == TaskStatus.COMPLETED
|
||||
}
|
||||
|
||||
def get_team_status(self) -> dict[str, Any]:
|
||||
"""获取团队状态
|
||||
|
||||
Returns:
|
||||
团队状态摘要
|
||||
"""
|
||||
return {
|
||||
"team_id": self.team_id,
|
||||
"members": self.members,
|
||||
"task_count": len(self._tasks),
|
||||
"completed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.COMPLETED),
|
||||
"failed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.FAILED),
|
||||
}
|
||||
11
backend/app/agents/tools/hooks/builtins/__init__.py
Normal file
11
backend/app/agents/tools/hooks/builtins/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""内置 Hook 集合 - Phase 7"""
|
||||
|
||||
from app.agents.tools.hooks.builtins.audit_log import AuditLogHook
|
||||
from app.agents.tools.hooks.builtins.dangerous_confirmation import DangerousConfirmationHook
|
||||
from app.agents.tools.hooks.builtins.security_scan import SecurityScanHook
|
||||
|
||||
__all__ = [
|
||||
"AuditLogHook",
|
||||
"DangerousConfirmationHook",
|
||||
"SecurityScanHook",
|
||||
]
|
||||
115
backend/app/agents/tools/hooks/builtins/audit_log.py
Normal file
115
backend/app/agents/tools/hooks/builtins/audit_log.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""审计日志 Hook - Phase 7.2
|
||||
|
||||
记录所有工具调用到审计日志。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
HookType,
|
||||
)
|
||||
from app.agents.tools.manifest import ToolCategory
|
||||
|
||||
|
||||
class AuditLogHook:
|
||||
"""审计日志 Hook
|
||||
|
||||
记录所有工具调用的详细信息,包括:
|
||||
- 调用时间
|
||||
- 工具名称
|
||||
- 输入参数
|
||||
- 执行结果
|
||||
- 执行时长
|
||||
- 用户 ID
|
||||
"""
|
||||
|
||||
def __init__(self, log_path: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
log_path: 日志文件路径,None 则输出到 stdout
|
||||
"""
|
||||
self.log_path = log_path
|
||||
self._logs: list[dict[str, Any]] = []
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""工具执行前记录"""
|
||||
log_entry = {
|
||||
"event": "pre_tool",
|
||||
"tool_name": context.tool_name,
|
||||
"input": context.tool_input,
|
||||
"user_id": context.user_id,
|
||||
"session_id": context.session_id,
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
|
||||
"""工具执行后记录"""
|
||||
log_entry = {
|
||||
"event": "post_tool",
|
||||
"tool_name": context.tool_name,
|
||||
"result": str(result)[:500] if result else None,
|
||||
"duration_ms": (
|
||||
(context.end_time - context.start_time) * 1000
|
||||
if context.start_time and context.end_time
|
||||
else None
|
||||
),
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=result,
|
||||
)
|
||||
|
||||
async def tool_error(self, context: ExecutionContext, error: Exception) -> HookResult:
|
||||
"""工具出错时记录"""
|
||||
log_entry = {
|
||||
"event": "tool_error",
|
||||
"tool_name": context.tool_name,
|
||||
"error": str(error),
|
||||
"error_type": type(error).__name__,
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=False,
|
||||
continue_execution=True,
|
||||
error=str(error),
|
||||
)
|
||||
|
||||
def _write_log(self, entry: dict[str, Any]) -> None:
|
||||
"""写入日志"""
|
||||
import json
|
||||
import datetime
|
||||
|
||||
entry["timestamp"] = datetime.datetime.now().isoformat()
|
||||
|
||||
if self.log_path:
|
||||
try:
|
||||
with open(self.log_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
except Exception:
|
||||
# 日志写入失败不影响主流程
|
||||
pass
|
||||
else:
|
||||
# 输出到 stdout
|
||||
print(f"[AUDIT] {json.dumps(entry, ensure_ascii=False)}")
|
||||
|
||||
def get_logs(self) -> list[dict[str, Any]]:
|
||||
"""获取所有日志"""
|
||||
return self._logs.copy()
|
||||
|
||||
def clear_logs(self) -> None:
|
||||
"""清空日志"""
|
||||
self._logs.clear()
|
||||
@@ -0,0 +1,142 @@
|
||||
"""危险操作确认 Hook - Phase 7.2
|
||||
|
||||
对危险操作要求用户确认。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
)
|
||||
from app.agents.tools.manifest import SideEffectScope
|
||||
|
||||
|
||||
# 危险操作关键词
|
||||
DANGEROUS_PATTERNS = [
|
||||
# 文件操作
|
||||
"delete",
|
||||
"remove",
|
||||
"rm ",
|
||||
"rmdir",
|
||||
"unlink",
|
||||
"format",
|
||||
"truncate",
|
||||
# 系统操作
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"kill",
|
||||
"pkill",
|
||||
"sudo",
|
||||
"chmod",
|
||||
"chown",
|
||||
# 数据操作
|
||||
"drop",
|
||||
"truncate",
|
||||
"delete from",
|
||||
"delete.*where",
|
||||
"insert into.*select",
|
||||
"update.*set",
|
||||
# 网络操作
|
||||
"curl",
|
||||
"wget",
|
||||
"nc ",
|
||||
"netcat",
|
||||
"ssh ",
|
||||
"scp ",
|
||||
"sftp ",
|
||||
# 环境变量
|
||||
"export.*secret",
|
||||
"export.*key",
|
||||
"export.*token",
|
||||
]
|
||||
|
||||
|
||||
class DangerousConfirmationHook:
|
||||
"""危险操作确认 Hook
|
||||
|
||||
检查工具调用是否包含危险操作,如是则要求确认。
|
||||
"""
|
||||
|
||||
def __init__(self, auto_block: bool = False):
|
||||
"""
|
||||
Args:
|
||||
auto_block: True 表示自动拦截危险操作,False 表示仅警告
|
||||
"""
|
||||
self.auto_block = auto_block
|
||||
self._pending_confirmations: dict[str, bool] = {}
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""检查是否为危险操作"""
|
||||
is_dangerous = self._check_dangerous(context.tool_name, context.tool_input)
|
||||
|
||||
if is_dangerous:
|
||||
if self.auto_block:
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=False,
|
||||
continue_execution=False,
|
||||
error=f"危险操作被自动拦截: {context.tool_name}",
|
||||
metadata={"dangerous": True, "auto_blocked": True},
|
||||
)
|
||||
else:
|
||||
# 标记需要确认
|
||||
context.metadata["requires_confirmation"] = True
|
||||
context.metadata["dangerous_operation"] = True
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
metadata={"dangerous": True, "requires_confirmation": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
def _check_dangerous(self, tool_name: str, tool_input: dict[str, Any]) -> bool:
|
||||
"""检查是否为危险操作"""
|
||||
# 检查工具名称
|
||||
dangerous_tools = [
|
||||
"delete",
|
||||
"remove",
|
||||
"drop",
|
||||
"truncate",
|
||||
"kill",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"bash",
|
||||
"powershell",
|
||||
"shell",
|
||||
]
|
||||
|
||||
if tool_name.lower() in dangerous_tools:
|
||||
return True
|
||||
|
||||
# 检查输入参数
|
||||
input_str = str(tool_input).lower()
|
||||
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if pattern.lower() in input_str:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def confirm(self, session_id: str, confirmed: bool) -> None:
|
||||
"""确认危险操作
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
confirmed: True 表示用户确认,False 表示取消
|
||||
"""
|
||||
self._pending_confirmations[session_id] = confirmed
|
||||
|
||||
def is_confirmed(self, session_id: str) -> bool:
|
||||
"""检查是否已确认"""
|
||||
return self._pending_confirmations.get(session_id, False)
|
||||
|
||||
def clear_confirmation(self, session_id: str) -> None:
|
||||
"""清除确认状态"""
|
||||
self._pending_confirmations.pop(session_id, None)
|
||||
183
backend/app/agents/tools/hooks/builtins/security_scan.py
Normal file
183
backend/app/agents/tools/hooks/builtins/security_scan.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""安全扫描 Hook - Phase 7.2
|
||||
|
||||
扫描工具调用和结果中的敏感信息。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
)
|
||||
|
||||
|
||||
# 敏感信息模式
|
||||
SENSITIVE_PATTERNS = {
|
||||
"api_key": [
|
||||
r"api[_-]?key['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
r"apikey['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
],
|
||||
"password": [
|
||||
r"password['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
|
||||
r"passwd['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
|
||||
r"secret['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
],
|
||||
"token": [
|
||||
r"token['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-\.]{20,}",
|
||||
r"bearer\s+[a-zA-Z0-9_\-\.]+",
|
||||
r"ghp_[a-zA-Z0-9]{36}",
|
||||
r"sk-[a-zA-Z0-9]{48}",
|
||||
],
|
||||
"private_key": [
|
||||
r"-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
|
||||
r"-----END (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
|
||||
],
|
||||
"ip_address": [
|
||||
r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
|
||||
],
|
||||
"email": [
|
||||
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class SecurityScanHook:
|
||||
"""安全扫描 Hook
|
||||
|
||||
扫描工具输入和输出中的敏感信息,进行脱敏处理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redact: bool = True,
|
||||
block_on_detect: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
redact: 是否对敏感信息进行脱敏
|
||||
block_on_detect: 检测到敏感信息时是否阻止执行
|
||||
"""
|
||||
self.redact = redact
|
||||
self.block_on_detect = block_on_detect
|
||||
self._compiled_patterns = {
|
||||
name: [re.compile(p, re.IGNORECASE) for p in patterns]
|
||||
for name, patterns in SENSITIVE_PATTERNS.items()
|
||||
}
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""扫描输入参数"""
|
||||
detected = self._scan_dict(context.tool_input)
|
||||
|
||||
if detected:
|
||||
context.metadata["security_detected"] = detected
|
||||
|
||||
if self.block_on_detect:
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=False,
|
||||
continue_execution=False,
|
||||
error=f"检测到敏感信息: {', '.join(detected.keys())}",
|
||||
metadata={"detected": detected, "blocked": True},
|
||||
)
|
||||
|
||||
if self.redact:
|
||||
redacted_input = self._redact_dict(context.tool_input.copy())
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_input=redacted_input,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
|
||||
"""扫描输出结果"""
|
||||
if isinstance(result, dict):
|
||||
detected = self._scan_dict(result)
|
||||
|
||||
if detected:
|
||||
context.metadata["security_detected_output"] = detected
|
||||
|
||||
if self.redact:
|
||||
redacted_result = self._redact_dict(result.copy())
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=redacted_result,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
elif isinstance(result, str):
|
||||
detected = self._scan_string(result)
|
||||
if detected:
|
||||
context.metadata["security_detected_output"] = detected
|
||||
|
||||
if self.redact:
|
||||
redacted_result = self._redact_string(result)
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=redacted_result,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=result,
|
||||
)
|
||||
|
||||
def _scan_dict(self, data: dict[str, Any]) -> dict[str, list[str]]:
|
||||
"""扫描字典中的敏感信息"""
|
||||
result: dict[str, list[str]] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
found = self._scan_string(value)
|
||||
if found:
|
||||
result[key] = found
|
||||
|
||||
return result
|
||||
|
||||
def _scan_string(self, text: str) -> list[str]:
|
||||
"""扫描字符串中的敏感信息"""
|
||||
found_types = []
|
||||
|
||||
for name, patterns in self._compiled_patterns.items():
|
||||
for pattern in patterns:
|
||||
if pattern.search(text):
|
||||
if name not in found_types:
|
||||
found_types.append(name)
|
||||
break
|
||||
|
||||
return found_types
|
||||
|
||||
def _redact_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""脱敏字典中的敏感信息"""
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
data[key] = self._redact_string(value)
|
||||
elif isinstance(value, dict):
|
||||
data[key] = self._redact_dict(value)
|
||||
elif isinstance(value, list):
|
||||
data[key] = [self._redact_string(v) if isinstance(v, str) else v for v in value]
|
||||
|
||||
return data
|
||||
|
||||
def _redact_string(self, text: str) -> str:
|
||||
"""脱敏字符串中的敏感信息"""
|
||||
for name, patterns in self._compiled_patterns.items():
|
||||
for pattern in patterns:
|
||||
text = pattern.sub(f"[REDACTED:{name}]", text)
|
||||
|
||||
return text
|
||||
105
backend/app/agents/tools/hooks/config.py
Normal file
105
backend/app/agents/tools/hooks/config.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Hook 配置持久化 - Phase 7.3"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.manager import get_hook_manager
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookConfigEntry:
|
||||
"""Hook 配置条目"""
|
||||
|
||||
name: str
|
||||
hook_type: str
|
||||
enabled: bool
|
||||
tool_names: list[str] | None = None
|
||||
categories: list[str] | None = None
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class HookConfigPersistence:
|
||||
"""Hook 配置持久化"""
|
||||
|
||||
def __init__(self, config_path: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
config_path: 配置文件路径,None 则使用默认路径
|
||||
"""
|
||||
if config_path is None:
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "..", "config", "hooks.json"
|
||||
)
|
||||
self.config_path = config_path
|
||||
|
||||
def load_config(self) -> list[HookConfigEntry]:
|
||||
"""从文件加载 Hook 配置"""
|
||||
if not os.path.exists(self.config_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return [HookConfigEntry(**entry) for entry in data]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def save_config(self, entries: list[HookConfigEntry]) -> bool:
|
||||
"""保存 Hook 配置到文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump([asdict(e) for e in entries], f, indent=2, ensure_ascii=False)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def apply_config(self) -> int:
|
||||
"""应用配置到 HookManager
|
||||
|
||||
Returns:
|
||||
应用的 Hook 数量
|
||||
"""
|
||||
from app.agents.tools.hooks.types import HookType
|
||||
|
||||
manager = get_hook_manager()
|
||||
entries = self.load_config()
|
||||
count = 0
|
||||
|
||||
for entry in entries:
|
||||
if entry.enabled:
|
||||
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
|
||||
|
||||
trigger = HookTrigger(
|
||||
tool_names=entry.tool_names,
|
||||
categories=entry.categories,
|
||||
)
|
||||
|
||||
# 创建空的 handler,只是注册配置
|
||||
hook_def = HookDefinition(
|
||||
name=entry.name,
|
||||
hook_type=HookType(entry.hook_type),
|
||||
trigger=trigger,
|
||||
handler=lambda ctx, *args: ctx,
|
||||
priority=entry.priority,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
manager.register(hook_def)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
# 全局单例
|
||||
_persistence: HookConfigPersistence | None = None
|
||||
|
||||
|
||||
def get_hook_config_persistence() -> HookConfigPersistence:
|
||||
"""获取全局 Hook 配置持久化实例"""
|
||||
global _persistence
|
||||
if _persistence is None:
|
||||
_persistence = HookConfigPersistence()
|
||||
return _persistence
|
||||
113
backend/app/agents/transport/remote.py
Normal file
113
backend/app/agents/transport/remote.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""远程传输层 - Phase 10.2"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructuredMessage:
|
||||
"""结构化消息"""
|
||||
|
||||
type: str # response, event, tool_call, error
|
||||
data: dict[str, Any]
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
class RemoteTransport:
|
||||
"""远程传输层
|
||||
|
||||
处理与远程 Agent 的通信。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._connections: dict[str, Any] = {}
|
||||
self._handlers: dict[str, Any] = {}
|
||||
|
||||
async def send_response(self, session_id: str, response: dict[str, Any]) -> bool:
|
||||
"""发送响应
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
response: 响应数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
message = StructuredMessage(
|
||||
type="response",
|
||||
data=response,
|
||||
session_id=session_id,
|
||||
)
|
||||
return await self._send(session_id, message)
|
||||
|
||||
async def send_event(self, session_id: str, event: dict[str, Any]) -> bool:
|
||||
"""发送事件
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
event: 事件数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
message = StructuredMessage(
|
||||
type="event",
|
||||
data=event,
|
||||
session_id=session_id,
|
||||
)
|
||||
return await self._send(session_id, message)
|
||||
|
||||
async def send_tool_call(self, session_id: str, tool_call: dict[str, Any]) -> bool:
|
||||
"""发送工具调用
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
tool_call: 工具调用数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
message = StructuredMessage(
|
||||
type="tool_call",
|
||||
data=tool_call,
|
||||
session_id=session_id,
|
||||
)
|
||||
return await self._send(session_id, message)
|
||||
|
||||
async def _send(self, session_id: str, message: StructuredMessage) -> bool:
|
||||
"""内部发送方法"""
|
||||
if session_id not in self._connections:
|
||||
return False
|
||||
|
||||
try:
|
||||
connection = self._connections[session_id]
|
||||
if hasattr(connection, "send"):
|
||||
await connection.send(json.dumps(message.__dict__))
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def register_handler(self, event_type: str, handler: Any) -> None:
|
||||
"""注册消息处理器
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
handler: 处理函数
|
||||
"""
|
||||
self._handlers[event_type] = handler
|
||||
|
||||
async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
"""处理收到的消息
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
message: 消息数据
|
||||
"""
|
||||
msg_type = message.get("type")
|
||||
handler = self._handlers.get(msg_type)
|
||||
if handler:
|
||||
await handler(session_id, message.get("data"))
|
||||
Reference in New Issue
Block a user