Phase 7-10: CustomHookLoader, MCPSkillLoader, SkillTriggerDetector, TeamMember, WebSocketManager

This commit is contained in:
2026-04-05 10:56:21 +08:00
parent d18167826e
commit fca7a7cf3d
11 changed files with 958 additions and 14 deletions

View File

@@ -0,0 +1,12 @@
"""Skills 加载器包"""
from app.agents.skills.loaders.local_loader import LocalSkillLoader
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
from app.agents.skills.loaders.mcp_loader import MCPSkillLoader, get_mcp_skill_loader
__all__ = [
"LocalSkillLoader",
"PluginSkillLoader",
"MCPSkillLoader",
"get_mcp_skill_loader",
]

View File

@@ -0,0 +1,169 @@
"""MCP Skill 加载器 - Phase 9.2
从 MCP (Model Context Protocol) 服务器发现和加载 Skills。
"""
import os
from typing import Any
from app.agents.skills.metadata import SkillMetadata
class MCPSkillLoader:
"""MCP Skill 加载器
从 MCP 服务器发现可用的 Skills。
"""
def __init__(self, mcp_servers: list[dict[str, Any]] | None = None):
"""
Args:
mcp_servers: MCP 服务器列表,每项包含 name, command, env 等
"""
self.mcp_servers = mcp_servers or []
self._discovered_skills: dict[str, SkillMetadata] = {}
def discover_skills(self) -> list[SkillMetadata]:
"""从所有配置的 MCP 服务器发现 Skills
Returns:
发现的 Skill 列表
"""
skills = []
for server in self.mcp_servers:
server_skills = self._discover_from_server(server)
skills.extend(server_skills)
return skills
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
"""从单个 MCP 服务器发现 Skills
Args:
server: 服务器配置
Returns:
Skill 列表
"""
skills = []
server_name = server.get("name", "unknown")
# 模拟从 MCP 服务器获取工具列表
# 实际实现时,这里会调用 MCP 服务器的 list_tools 接口
try:
tools = self._call_mcp_list_tools(server)
for tool in tools:
skill = self._tool_to_skill(tool, server_name)
if skill:
skills.append(skill)
self._discovered_skills[skill.name] = skill
except Exception:
pass
return skills
def _call_mcp_list_tools(self, server: dict[str, Any]) -> list[dict[str, Any]]:
"""调用 MCP 服务器的 list_tools 接口
Args:
server: 服务器配置
Returns:
工具列表
"""
# TODO: 实现实际的 MCP 协议调用
# 目前返回空列表,实际使用时需要实现 MCP 客户端
return []
def _tool_to_skill(self, tool: dict[str, Any], server: str) -> SkillMetadata | None:
"""将 MCP 工具转换为 Skill
Args:
tool: MCP 工具定义
server: 服务器名
Returns:
Skill 元数据或 None
"""
tool_name = tool.get("name")
if not tool_name:
return None
return SkillMetadata(
id=f"mcp_{server}_{tool_name}",
name=f"{server}:{tool_name}",
description=tool.get("description", f"MCP tool: {tool_name}"),
version="1.0.0",
content=self._generate_skill_content(tool),
triggers=[f"@{server}", f"/{tool_name}"],
tools=[tool_name],
tags=["mcp", server],
enabled=True,
)
def _generate_skill_content(self, tool: dict[str, Any]) -> str:
"""生成 Skill 内容
Args:
tool: MCP 工具定义
Returns:
Skill 内容字符串
"""
name = tool.get("name", "unknown")
description = tool.get("description", "No description")
input_schema = tool.get("inputSchema", {})
content = f"""# MCP Tool: {name}
**Description**: {description}
**Server**: {tool.get("server", "unknown")}
**Input Schema**:
```json
{input_schema}
```
**Usage**:
Use the `/{name}` command or `@{tool.get("server", "server")}` to invoke this tool.
**Examples**:
```
/{name} arg1=value1 arg2=value2
@{tool.get("server", "server")} {name} --arg1 value1
```
"""
return content
def get_skill(self, name: str) -> SkillMetadata | None:
"""获取已发现的 Skill
Args:
name: Skill 名称
Returns:
Skill 元数据或 None
"""
return self._discovered_skills.get(name)
def list_skills(self) -> list[SkillMetadata]:
"""列出所有已发现的 Skills
Returns:
Skill 列表
"""
return list(self._discovered_skills.values())
# 全局加载器
_loader: MCPSkillLoader | None = None
def get_mcp_skill_loader() -> MCPSkillLoader:
"""获取全局 MCP Skill 加载器"""
global _loader
if _loader is None:
_loader = MCPSkillLoader()
return _loader

View File

@@ -8,8 +8,9 @@ from typing import Any
class SkillMetadata:
"""Skill 元数据"""
name: str # Skill 名称
description: str # 描述
id: str = "" # Skill ID
name: str = "" # Skill 名称
description: str = "" # 描述
version: str = "1.0.0" # 版本
author: str = "" # 作者
tags: list[str] = field(default_factory=list) # 标签
@@ -18,9 +19,11 @@ class SkillMetadata:
source: str = "local" # 来源local, plugin, mcp, bundled
source_id: str = "" # 来源 ID
enabled: bool = True # 是否启用
tools: list[str] = field(default_factory=list) # 关联的工具
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"description": self.description,
"version": self.version,
@@ -31,6 +34,7 @@ class SkillMetadata:
"source": self.source,
"source_id": self.source_id,
"enabled": self.enabled,
"tools": self.tools,
}
@classmethod

View File

@@ -0,0 +1,140 @@
"""Skill 触发检测器 - Phase 9.5
检测消息中的 Skill 触发条件。
"""
import re
from typing import Any
from app.agents.skills.metadata import SkillMetadata
class SkillTriggerDetector:
"""Skill 触发检测器
检测用户消息中是否触发了某个 Skill。
"""
def __init__(self):
self._skills: dict[str, SkillMetadata] = {}
def register_skill(self, skill: SkillMetadata) -> None:
"""注册 Skill
Args:
skill: Skill 元数据
"""
self._skills[skill.name] = skill
def unregister_skill(self, name: str) -> bool:
"""注销 Skill
Args:
name: Skill 名称
Returns:
是否成功
"""
if name in self._skills:
del self._skills[name]
return True
return False
def detect_triggered_skills(self, message: str) -> list[str]:
"""检测触发的 Skills
Args:
message: 用户消息
Returns:
触发的 Skill 名称列表
"""
triggered = []
message_lower = message.lower()
for skill in self._skills.values():
if not skill.enabled:
continue
if self._matches_triggers(message, message_lower, skill):
triggered.append(skill.name)
return triggered
def _matches_triggers(self, message: str, message_lower: str, skill: SkillMetadata) -> bool:
"""检查消息是否匹配 Skill 触发条件
Args:
message: 原始消息
message_lower: 小写消息
skill: Skill 元数据
Returns:
是否匹配
"""
for trigger in skill.triggers:
trigger_lower = trigger.lower()
# 前缀匹配,如 "/code" 或 "@git"
if trigger_lower.startswith("/") or trigger_lower.startswith("@"):
if message_lower.startswith(trigger_lower):
return True
# 命令格式,如 "//analyze"
if trigger_lower.startswith("//"):
pattern = trigger_lower[2:]
if re.search(rf"\b{re.escape(pattern)}\b", message_lower):
return True
# 关键词匹配
if trigger_lower in message_lower:
return True
return False
def get_skill_prompt(self, skill_name: str) -> str | None:
"""获取 Skill 的提示词
Args:
skill_name: Skill 名称
Returns:
Skill 内容或 None
"""
skill = self._skills.get(skill_name)
if skill:
return skill.content
return None
def get_triggered_skill_context(self, message: str) -> str:
"""获取触发的 Skills 上下文
Args:
message: 用户消息
Returns:
拼接的 Skill 上下文
"""
triggered = self.detect_triggered_skills(message)
if not triggered:
return ""
contexts = []
for skill_name in triggered:
skill = self._skills.get(skill_name)
if skill:
contexts.append(f"# {skill.name}\n\n{skill.content}")
return "\n\n---\n\n".join(contexts)
# 全局检测器
_detector: SkillTriggerDetector | None = None
def get_skill_trigger_detector() -> SkillTriggerDetector:
"""获取全局 Skill 触发检测器"""
global _detector
if _detector is None:
_detector = SkillTriggerDetector()
return _detector

View File

@@ -0,0 +1,13 @@
"""Team 多 Agent 协作"""
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
from app.agents.team.member import TeamMember, MemberStatus, MemberTask
__all__ = [
"TeamLeader",
"TeamTask",
"TaskStatus",
"TeamMember",
"MemberStatus",
"MemberTask",
]

View File

@@ -0,0 +1,166 @@
"""TeamMember 实现 - Phase 10.1
团队成员实现,负责执行分配的任务。
"""
from dataclasses import dataclass, field
from typing import Any
from enum import Enum
class MemberStatus(Enum):
"""成员状态"""
IDLE = "idle"
BUSY = "busy"
OFFLINE = "offline"
@dataclass
class MemberTask:
"""成员任务"""
task_id: str
description: str
status: str = "pending" # pending, in_progress, completed, failed
result: Any = None
error: str | None = None
class TeamMember:
"""团队成员
代表团队中的一个 Agent 成员,负责执行分配的任务。
"""
def __init__(self, member_id: str, name: str, capabilities: list[str] | None = None):
"""
Args:
member_id: 成员 ID
name: 成员名称
capabilities: 成员能力列表
"""
self.member_id = member_id
self.name = name
self.capabilities = capabilities or []
self.status = MemberStatus.IDLE
self._tasks: dict[str, MemberTask] = {}
self._metadata: dict[str, Any] = {}
def assign_task(self, task_id: str, description: str) -> MemberTask:
"""接收任务分配
Args:
task_id: 任务 ID
description: 任务描述
Returns:
创建的任务对象
"""
task = MemberTask(task_id=task_id, description=description)
self._tasks[task_id] = task
self.status = MemberStatus.BUSY
return task
def update_task_status(
self, task_id: str, status: str, result: Any = None, error: str | None = None
) -> bool:
"""更新任务状态
Args:
task_id: 任务 ID
status: 新状态
result: 任务结果
error: 错误信息
Returns:
是否更新成功
"""
if task_id not in self._tasks:
return False
task = self._tasks[task_id]
task.status = status
if result is not None:
task.result = result
if error is not None:
task.error = error
if status in ("completed", "failed"):
self.status = MemberStatus.IDLE
return True
def get_task(self, task_id: str) -> MemberTask | None:
"""获取任务
Args:
task_id: 任务 ID
Returns:
任务对象或 None
"""
return self._tasks.get(task_id)
def get_pending_tasks(self) -> list[MemberTask]:
"""获取待处理任务
Returns:
待处理任务列表
"""
return [t for t in self._tasks.values() if t.status == "pending"]
def get_active_task(self) -> MemberTask | None:
"""获取当前执行中的任务
Returns:
当前任务或 None
"""
for task in self._tasks.values():
if task.status == "in_progress":
return task
return None
def get_completed_tasks(self) -> list[MemberTask]:
"""获取已完成任务
Returns:
已完成任务列表
"""
return [t for t in self._tasks.values() if t.status == "completed"]
def set_metadata(self, key: str, value: Any) -> None:
"""设置元数据
Args:
key: 元数据键
value: 元数据值
"""
self._metadata[key] = value
def get_metadata(self, key: str) -> Any:
"""获取元数据
Args:
key: 元数据键
Returns:
元数据值或 None
"""
return self._metadata.get(key)
def get_status(self) -> dict[str, Any]:
"""获取成员状态
Returns:
状态字典
"""
return {
"member_id": self.member_id,
"name": self.name,
"status": self.status.value,
"capabilities": self.capabilities,
"task_count": len(self._tasks),
"pending_count": len(self.get_pending_tasks()),
"active_task": self.get_active_task().__dict__ if self.get_active_task() else None,
}

View File

@@ -0,0 +1,5 @@
"""自定义 Hook 加载器包"""
from app.agents.tools.hooks.custom.loader import CustomHookLoader, get_custom_hook_loader
__all__ = ["CustomHookLoader", "get_custom_hook_loader"]

View File

@@ -0,0 +1,153 @@
"""自定义 Hook 加载器 - Phase 7.4
支持动态加载用户自定义的 Hook。
"""
import importlib.util
import os
from typing import Any
from app.agents.tools.hooks.types import HookDefinition, HookType, HookTrigger, HookResult
class CustomHookLoader:
"""自定义 Hook 加载器
从指定目录动态加载自定义 Hook 模块。
"""
def __init__(self, hooks_dir: str | None = None):
"""
Args:
hooks_dir: Hook 目录None 则使用默认目录
"""
if hooks_dir is None:
hooks_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "data", "custom_hooks"
)
self.hooks_dir = hooks_dir
self._loaded_hooks: dict[str, HookDefinition] = {}
def load_all(self) -> list[HookDefinition]:
"""加载所有自定义 Hook
Returns:
Hook 定义列表
"""
hooks = []
if not os.path.exists(self.hooks_dir):
return hooks
for filename in os.listdir(self.hooks_dir):
if filename.endswith(".py") and not filename.startswith("_"):
hook_path = os.path.join(self.hooks_dir, filename)
hook_def = self._load_hook_from_file(hook_path, filename[:-3])
if hook_def:
hooks.append(hook_def)
self._loaded_hooks[hook_def.name] = hook_def
return hooks
def _load_hook_from_file(self, hook_path: str, module_name: str) -> HookDefinition | None:
"""从文件加载 Hook
Args:
hook_path: Hook 文件路径
module_name: 模块名
Returns:
Hook 定义或 None
"""
try:
spec = importlib.util.spec_from_file_location(module_name, hook_path)
if not spec or not spec.loader:
return None
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# 查找 HOOK_DEFINITION 或 hook_definition
hook_def = getattr(module, "HOOK_DEFINITION", None) or getattr(
module, "hook_definition", None
)
if hook_def and isinstance(hook_def, HookDefinition):
return hook_def
# 如果没有定义,尝试从函数自动推断
if hasattr(module, "pre_tool_hook") or hasattr(module, "post_tool_hook"):
return self._infer_hook_definition(module, module_name)
except Exception:
pass
return None
def _infer_hook_definition(self, module: Any, module_name: str) -> HookDefinition | None:
"""从模块函数推断 Hook 定义
Args:
module: 模块对象
module_name: 模块名
Returns:
Hook 定义或 None
"""
hook_type = None
handler = None
if hasattr(module, "pre_tool_hook"):
handler = module.pre_tool_hook
hook_type = HookType.PRE_TOOL_USE
elif hasattr(module, "post_tool_hook"):
handler = module.post_tool_hook
hook_type = HookType.POST_TOOL_USE
elif hasattr(module, "error_tool_hook"):
handler = module.error_tool_hook
hook_type = HookType.TOOL_ERROR
if not handler or not hook_type:
return None
return HookDefinition(
name=module_name,
hook_type=hook_type,
trigger=HookTrigger(),
handler=handler,
priority=0,
enabled=True,
description=f"Auto-loaded hook from {module_name}",
)
def get_hook(self, name: str) -> HookDefinition | None:
"""获取已加载的 Hook
Args:
name: Hook 名称
Returns:
Hook 定义或 None
"""
return self._loaded_hooks.get(name)
def reload(self) -> list[HookDefinition]:
"""重新加载所有 Hook
Returns:
重新加载的 Hook 列表
"""
self._loaded_hooks.clear()
return self.load_all()
# 全局加载器
_loader: CustomHookLoader | None = None
def get_custom_hook_loader() -> CustomHookLoader:
"""获取全局自定义 Hook 加载器"""
global _loader
if _loader is None:
_loader = CustomHookLoader()
return _loader

View File

@@ -0,0 +1,207 @@
"""WebSocket 连接管理 - Phase 10.2
管理 WebSocket 连接的生命周期。
"""
import asyncio
import json
from typing import Any, Callable
from dataclasses import dataclass
@dataclass
class WSConnection:
"""WebSocket 连接"""
session_id: str
websocket: Any # WebSocket 连接
user_id: str | None = None
created_at: float | None = None
last_ping: float | None = None
class WebSocketManager:
"""WebSocket 连接管理器
管理所有 WebSocket 连接的生命周期。
"""
def __init__(self, ping_interval: float = 30.0):
"""
Args:
ping_interval: 心跳间隔(秒)
"""
self._connections: dict[str, WSConnection] = {}
self._handlers: dict[str, Callable] = {}
self._ping_interval = ping_interval
self._ping_tasks: dict[str, asyncio.Task] = {}
async def connect(self, session_id: str, websocket: Any, user_id: str | None = None) -> bool:
"""建立连接
Args:
session_id: 会话 ID
websocket: WebSocket 连接
user_id: 用户 ID
Returns:
是否连接成功
"""
import time
if session_id in self._connections:
return False
conn = WSConnection(
session_id=session_id,
websocket=websocket,
user_id=user_id,
created_at=time.time(),
last_ping=time.time(),
)
self._connections[session_id] = conn
# 启动心跳
self._ping_tasks[session_id] = asyncio.create_task(self._ping_loop(session_id))
return True
async def disconnect(self, session_id: str) -> bool:
"""断开连接
Args:
session_id: 会话 ID
Returns:
是否断开成功
"""
if session_id not in self._connections:
return False
# 停止心跳
if session_id in self._ping_tasks:
self._ping_tasks[session_id].cancel()
del self._ping_tasks[session_id]
del self._connections[session_id]
return True
async def send(self, session_id: str, message: dict[str, Any]) -> bool:
"""发送消息
Args:
session_id: 会话 ID
message: 消息内容
Returns:
是否发送成功
"""
if session_id not in self._connections:
return False
try:
conn = self._connections[session_id]
await conn.websocket.send_json(message)
return True
except Exception:
return False
async def broadcast(self, message: dict[str, Any]) -> int:
"""广播消息
Args:
message: 消息内容
Returns:
发送成功的数量
"""
count = 0
for session_id in list(self._connections.keys()):
if await self.send(session_id, message):
count += 1
return count
async def _ping_loop(self, session_id: str) -> None:
"""心跳循环
Args:
session_id: 会话 ID
"""
import time
while session_id in self._connections:
await asyncio.sleep(self._ping_interval)
if session_id not in self._connections:
break
try:
conn = self._connections[session_id]
await conn.websocket.send_json({"type": "ping", "timestamp": time.time()})
conn.last_ping = time.time()
except Exception:
await self.disconnect(session_id)
break
def register_handler(self, event_type: str, handler: Callable) -> None:
"""注册消息处理器
Args:
event_type: 事件类型
handler: 处理函数
"""
self._handlers[event_type] = handler
async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
"""处理消息
Args:
session_id: 会话 ID
message: 消息内容
"""
msg_type = message.get("type")
handler = self._handlers.get(msg_type)
if handler:
await handler(session_id, message.get("data"))
def get_connection(self, session_id: str) -> WSConnection | None:
"""获取连接
Args:
session_id: 会话 ID
Returns:
连接信息或 None
"""
return self._connections.get(session_id)
def list_connections(self) -> list[WSConnection]:
"""列出所有连接
Returns:
连接列表
"""
return list(self._connections.values())
def is_connected(self, session_id: str) -> bool:
"""检查是否连接
Args:
session_id: 会话 ID
Returns:
是否已连接
"""
return session_id in self._connections
# 全局单例
_ws_manager: WebSocketManager | None = None
def get_websocket_manager() -> WebSocketManager:
"""获取全局 WebSocket 管理器"""
global _ws_manager
if _ws_manager is None:
_ws_manager = WebSocketManager()
return _ws_manager