feat(agents): Phase 7-10 API endpoints for hooks, plugins, skills, sessions
This commit is contained in:
17
backend/app/agents/session/__init__.py
Normal file
17
backend/app/agents/session/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Agent Session Management - Phase 10.3"""
|
||||
|
||||
from app.agents.session.manager import (
|
||||
AgentSession,
|
||||
SessionContext,
|
||||
SessionPersistence,
|
||||
create_agent_session,
|
||||
get_agent_session,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentSession",
|
||||
"SessionContext",
|
||||
"SessionPersistence",
|
||||
"create_agent_session",
|
||||
"get_agent_session",
|
||||
]
|
||||
238
backend/app/agents/session/manager.py
Normal file
238
backend/app/agents/session/manager.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Agent Session 管理 - Phase 10.3
|
||||
|
||||
支持会话层级管理和子会话创建。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""会话上下文"""
|
||||
|
||||
session_id: str
|
||||
parent_session_id: str | None = None
|
||||
root_session_id: str | None = None
|
||||
depth: int = 0
|
||||
user_id: str | None = None
|
||||
created_at: str | None = None
|
||||
last_active: str | None = None
|
||||
message_count: int = 0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.now().isoformat()
|
||||
if self.last_active is None:
|
||||
self.last_active = self.created_at
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionPersistence:
|
||||
"""会话持久化"""
|
||||
|
||||
def __init__(self, persistence_dir: str | None = None):
|
||||
if persistence_dir is None:
|
||||
persistence_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
|
||||
)
|
||||
self.persistence_dir = persistence_dir
|
||||
|
||||
def _get_session_path(self, session_id: str) -> str:
|
||||
return os.path.join(self.persistence_dir, f"{session_id}.json")
|
||||
|
||||
def save(self, session: "AgentSession") -> bool:
|
||||
"""保存会话"""
|
||||
try:
|
||||
os.makedirs(self.persistence_dir, exist_ok=True)
|
||||
path = self._get_session_path(session.session_id)
|
||||
data = {
|
||||
"session_id": session.session_id,
|
||||
"parent_session_id": session.context.parent_session_id,
|
||||
"root_session_id": session.context.root_session_id,
|
||||
"depth": session.context.depth,
|
||||
"user_id": session.context.user_id,
|
||||
"created_at": session.context.created_at,
|
||||
"last_active": session.context.last_active,
|
||||
"message_count": session.context.message_count,
|
||||
"metadata": session.context.metadata,
|
||||
"history": session._history,
|
||||
}
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def load(self, session_id: str) -> dict[str, Any] | None:
|
||||
"""加载会话"""
|
||||
try:
|
||||
path = self._get_session_path(session_id)
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete(self, session_id: str) -> bool:
|
||||
"""删除会话"""
|
||||
try:
|
||||
path = self._get_session_path(session_id)
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
||||
"""列出所有会话"""
|
||||
sessions = []
|
||||
try:
|
||||
os.makedirs(self.persistence_dir, exist_ok=True)
|
||||
for filename in os.listdir(self.persistence_dir):
|
||||
if filename.endswith(".json"):
|
||||
path = os.path.join(self.persistence_dir, filename)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if user_id is None or data.get("user_id") == user_id:
|
||||
sessions.append(data)
|
||||
except Exception:
|
||||
pass
|
||||
return sessions
|
||||
|
||||
|
||||
class AgentSession:
|
||||
"""Agent 会话管理器
|
||||
|
||||
支持:
|
||||
- 会话层级(parent/root/depth)
|
||||
- 子会话创建
|
||||
- 会话摘要
|
||||
- 持久化
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
):
|
||||
self.session_id = session_id or str(uuid.uuid4())[:8]
|
||||
self.context = SessionContext(
|
||||
session_id=self.session_id,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
depth=0 if parent_session_id is None else 1,
|
||||
)
|
||||
self._history: list[dict[str, Any]] = []
|
||||
self._persistence = SessionPersistence()
|
||||
|
||||
# 如果有父会话,设置 root_session_id
|
||||
if parent_session_id:
|
||||
parent_data = self._persistence.load(parent_session_id)
|
||||
if parent_data:
|
||||
self.context.root_session_id = (
|
||||
parent_data.get("root_session_id") or parent_session_id
|
||||
)
|
||||
self.context.depth = parent_data.get("depth", 0) + 1
|
||||
|
||||
async def initialize(self) -> dict[str, Any]:
|
||||
"""初始化会话"""
|
||||
self.context.last_active = datetime.now().isoformat()
|
||||
self._persistence.save(self)
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"depth": self.context.depth,
|
||||
"parent_session_id": self.context.parent_session_id,
|
||||
"root_session_id": self.context.root_session_id,
|
||||
}
|
||||
|
||||
async def process_message(self, message: str, response: str) -> None:
|
||||
"""处理消息并记录到历史"""
|
||||
self.context.message_count += 1
|
||||
self.context.last_active = datetime.now().isoformat()
|
||||
self._history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": message,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
self._history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
self._persistence.save(self)
|
||||
|
||||
async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
|
||||
"""创建子会话"""
|
||||
child = AgentSession(
|
||||
user_id=user_id or self.context.user_id,
|
||||
parent_session_id=self.session_id,
|
||||
)
|
||||
child.context.root_session_id = self.context.root_session_id or self.session_id
|
||||
await child.initialize()
|
||||
return child
|
||||
|
||||
async def get_session_summary(self) -> dict[str, Any]:
|
||||
"""获取会话摘要"""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"parent_session_id": self.context.parent_session_id,
|
||||
"root_session_id": self.context.root_session_id,
|
||||
"depth": self.context.depth,
|
||||
"user_id": self.context.user_id,
|
||||
"created_at": self.context.created_at,
|
||||
"last_active": self.context.last_active,
|
||||
"message_count": self.context.message_count,
|
||||
"history_length": len(self._history),
|
||||
}
|
||||
|
||||
async def persist(self) -> bool:
|
||||
"""持久化会话"""
|
||||
return self._persistence.save(self)
|
||||
|
||||
def get_history(self) -> list[dict[str, Any]]:
|
||||
"""获取会话历史"""
|
||||
return self._history.copy()
|
||||
|
||||
def add_metadata(self, key: str, value: Any) -> None:
|
||||
"""添加会话元数据"""
|
||||
self.context.metadata[key] = value
|
||||
|
||||
def get_metadata(self, key: str) -> Any:
|
||||
"""获取会话元数据"""
|
||||
return self.context.metadata.get(key)
|
||||
|
||||
|
||||
# 全局会话存储(内存中)
|
||||
_sessions: dict[str, AgentSession] = {}
|
||||
|
||||
|
||||
def get_agent_session(session_id: str) -> AgentSession | None:
|
||||
"""获取会话"""
|
||||
return _sessions.get(session_id)
|
||||
|
||||
|
||||
def create_agent_session(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> AgentSession:
|
||||
"""创建新会话"""
|
||||
session = AgentSession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
_sessions[session.session_id] = session
|
||||
return session
|
||||
@@ -1,5 +1,7 @@
|
||||
"""插件 Skills 加载器 - Phase 9.2"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.plugins.manager import get_plugin_manager
|
||||
|
||||
|
||||
@@ -23,6 +23,11 @@ from app.routers import (
|
||||
log_router,
|
||||
system_router,
|
||||
brain_router,
|
||||
hooks_router,
|
||||
plugins_router,
|
||||
marketplace_router,
|
||||
agent_skills_router,
|
||||
agent_sessions_router,
|
||||
)
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||
@@ -40,15 +45,15 @@ import os
|
||||
|
||||
|
||||
INSECURE_SECRET_KEYS = {
|
||||
'change-me-in-production',
|
||||
'change-me-to-a-random-secret-key',
|
||||
'jarvis-secret-key-change-in-production',
|
||||
"change-me-in-production",
|
||||
"change-me-to-a-random-secret-key",
|
||||
"jarvis-secret-key-change-in-production",
|
||||
}
|
||||
|
||||
|
||||
def validate_startup_security() -> None:
|
||||
if not settings.DEBUG and settings.SECRET_KEY in INSECURE_SECRET_KEYS:
|
||||
raise RuntimeError('SECRET_KEY must be changed before running with DEBUG disabled')
|
||||
raise RuntimeError("SECRET_KEY must be changed before running with DEBUG disabled")
|
||||
|
||||
|
||||
async def run_startup() -> None:
|
||||
@@ -117,6 +122,11 @@ app.include_router(log_router)
|
||||
app.include_router(system_router)
|
||||
app.include_router(brain_router)
|
||||
app.include_router(scheduler_router)
|
||||
app.include_router(hooks_router)
|
||||
app.include_router(plugins_router)
|
||||
app.include_router(marketplace_router)
|
||||
app.include_router(agent_skills_router)
|
||||
app.include_router(agent_sessions_router)
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
|
||||
@@ -15,3 +15,8 @@ from app.routers.skill import router as skill_router
|
||||
from app.routers.log import router as log_router
|
||||
from app.routers.system import router as system_router
|
||||
from app.routers.brain import router as brain_router
|
||||
from app.routers.hooks import router as hooks_router
|
||||
from app.routers.plugins import router as plugins_router
|
||||
from app.routers.plugins import _marketplace_router as marketplace_router
|
||||
from app.routers.agent_skills import router as agent_skills_router
|
||||
from app.routers.agent_sessions import router as agent_sessions_router
|
||||
|
||||
113
backend/app/routers/agent_sessions.py
Normal file
113
backend/app/routers/agent_sessions.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Agent Session API 路由 - Phase 10.3"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
|
||||
|
||||
router = APIRouter(prefix="/api/agent/sessions", tags=["Agent Sessions"])
|
||||
|
||||
|
||||
@router.post("", response_model=dict[str, Any])
|
||||
async def create_session(
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""创建新会话"""
|
||||
session = create_agent_session(
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
return await session.initialize()
|
||||
|
||||
|
||||
@router.get("/{session_id}", response_model=dict[str, Any])
|
||||
async def get_session(session_id: str) -> dict[str, Any]:
|
||||
"""获取会话信息"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
return await session.get_session_summary()
|
||||
|
||||
|
||||
@router.post("/{session_id}/message", response_model=dict[str, str])
|
||||
async def process_message(
|
||||
session_id: str,
|
||||
message: str,
|
||||
response: str,
|
||||
) -> dict[str, str]:
|
||||
"""处理消息"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
await session.process_message(message, response)
|
||||
return {"status": "recorded", "session_id": session_id}
|
||||
|
||||
|
||||
@router.post("/{session_id}/spawn", response_model=dict[str, Any])
|
||||
async def spawn_child_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""创建子会话"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
child = await session.spawn_child_session(user_id=user_id)
|
||||
return await child.get_session_summary()
|
||||
|
||||
|
||||
@router.get("/{session_id}/history", response_model=dict[str, Any])
|
||||
async def get_session_history(session_id: str) -> dict[str, Any]:
|
||||
"""获取会话历史"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"history": session.get_history(),
|
||||
"count": len(session._history),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{session_id}/persist", response_model=dict[str, str])
|
||||
async def persist_session(session_id: str) -> dict[str, str]:
|
||||
"""持久化会话"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
success = await session.persist()
|
||||
if success:
|
||||
return {"status": "persisted", "session_id": session_id}
|
||||
raise HTTPException(status_code=500, detail="Failed to persist session")
|
||||
|
||||
|
||||
@router.post("/{session_id}/metadata", response_model=dict[str, Any])
|
||||
async def set_session_metadata(
|
||||
session_id: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""设置会话元数据"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
session.add_metadata(key, value)
|
||||
await session.persist()
|
||||
return {"key": key, "value": value}
|
||||
|
||||
|
||||
@router.get("/{session_id}/metadata/{key}", response_model=dict[str, Any])
|
||||
async def get_session_metadata(
|
||||
session_id: str,
|
||||
key: str,
|
||||
) -> dict[str, Any]:
|
||||
"""获取会话元数据"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
value = session.get_metadata(key)
|
||||
if value is None:
|
||||
raise HTTPException(status_code=404, detail=f"Metadata key '{key}' not found")
|
||||
return {"key": key, "value": value}
|
||||
126
backend/app/routers/agent_skills.py
Normal file
126
backend/app/routers/agent_skills.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Agent Skills API 路由 - Phase 9.6
|
||||
|
||||
使用新的 SkillRegistry (file-based) 而不是 DB-based skill 系统。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.agents.skills.registry import get_skill_registry, SkillRegistry
|
||||
|
||||
router = APIRouter(prefix="/api/agent/skills", tags=["Agent Skills"])
|
||||
|
||||
|
||||
def _skill_to_dict(skill) -> dict[str, Any]:
|
||||
"""将 SkillMetadata 转换为字典"""
|
||||
return {
|
||||
"name": skill.name,
|
||||
"description": skill.description,
|
||||
"tags": skill.tags,
|
||||
"triggers": skill.triggers,
|
||||
"enabled": skill.enabled,
|
||||
"content_preview": skill.content[:200] + "..."
|
||||
if len(skill.content) > 200
|
||||
else skill.content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=dict[str, Any])
|
||||
async def list_agent_skills() -> dict[str, Any]:
|
||||
"""列出所有已加载的 Agent Skills"""
|
||||
registry = get_skill_registry()
|
||||
skills = registry.list_all()
|
||||
return {
|
||||
"skills": [_skill_to_dict(s) for s in skills],
|
||||
"count": len(skills),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/search", response_model=dict[str, Any])
|
||||
async def search_agent_skills(
|
||||
query: str,
|
||||
) -> dict[str, Any]:
|
||||
"""搜索 Skills"""
|
||||
registry = get_skill_registry()
|
||||
results = registry.search(query)
|
||||
return {
|
||||
"skills": [_skill_to_dict(s) for s in results],
|
||||
"count": len(results),
|
||||
"query": query,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{skill_name}", response_model=dict[str, Any])
|
||||
async def get_agent_skill(skill_name: str) -> dict[str, Any]:
|
||||
"""获取指定 Skill 详情"""
|
||||
registry = get_skill_registry()
|
||||
skill = registry.get_skill(skill_name)
|
||||
if not skill:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
return {
|
||||
"name": skill.name,
|
||||
"description": skill.description,
|
||||
"tags": skill.tags,
|
||||
"triggers": skill.triggers,
|
||||
"enabled": skill.enabled,
|
||||
"content": skill.content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{skill_name}/context", response_model=dict[str, str])
|
||||
async def get_skill_context(skill_name: str) -> dict[str, str]:
|
||||
"""获取 Skill 上下文字符串"""
|
||||
registry = get_skill_registry()
|
||||
context = registry.get_skill_context([skill_name])
|
||||
if not context:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Skill '{skill_name}' not found or not enabled"
|
||||
)
|
||||
return {"skill_name": skill_name, "context": context}
|
||||
|
||||
|
||||
@router.post("/context/batch", response_model=dict[str, str])
|
||||
async def get_batch_skill_context(
|
||||
skill_names: list[str],
|
||||
) -> dict[str, str]:
|
||||
"""批量获取多个 Skill 的上下文"""
|
||||
registry = get_skill_registry()
|
||||
context = registry.get_skill_context(skill_names)
|
||||
return {"skills": skill_names, "context": context}
|
||||
|
||||
|
||||
@router.post("/reload", response_model=dict[str, Any])
|
||||
async def reload_skills(
|
||||
skills_dir: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""重新加载所有 Skills"""
|
||||
registry = get_skill_registry()
|
||||
# 清除旧 skills
|
||||
for name in list(registry._skills.keys()):
|
||||
registry.unregister(name)
|
||||
# 重新加载
|
||||
count = registry.load_all(skills_dir)
|
||||
return {"loaded": count, "message": f"Loaded {count} skills"}
|
||||
|
||||
|
||||
@router.post("/{skill_name}/enable", response_model=dict[str, str])
|
||||
async def enable_skill(skill_name: str) -> dict[str, str]:
|
||||
"""启用 Skill"""
|
||||
registry = get_skill_registry()
|
||||
skill = registry.get_skill(skill_name)
|
||||
if not skill:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
skill.enabled = True
|
||||
return {"status": "enabled", "skill_name": skill_name}
|
||||
|
||||
|
||||
@router.post("/{skill_name}/disable", response_model=dict[str, str])
|
||||
async def disable_skill(skill_name: str) -> dict[str, str]:
|
||||
"""禁用 Skill"""
|
||||
registry = get_skill_registry()
|
||||
skill = registry.get_skill(skill_name)
|
||||
if not skill:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
skill.enabled = False
|
||||
return {"status": "disabled", "skill_name": skill_name}
|
||||
241
backend/app/routers/hooks.py
Normal file
241
backend/app/routers/hooks.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Hook API 路由 - Phase 7.5"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agents.tools.hooks import HookType
|
||||
from app.agents.tools.hooks.builtins import (
|
||||
AuditLogHook,
|
||||
DangerousConfirmationHook,
|
||||
SecurityScanHook,
|
||||
)
|
||||
from app.agents.tools.hooks.config import (
|
||||
HookConfigEntry,
|
||||
get_hook_config_persistence,
|
||||
)
|
||||
from app.agents.tools.hooks.manager import get_hook_manager
|
||||
|
||||
router = APIRouter(prefix="/api/hooks", tags=["Hooks"])
|
||||
|
||||
|
||||
class HookInfo(BaseModel):
|
||||
"""Hook 信息"""
|
||||
|
||||
name: str
|
||||
hook_type: str
|
||||
description: str
|
||||
builtin: bool
|
||||
|
||||
|
||||
class HookConfigUpdate(BaseModel):
|
||||
"""更新 Hook 配置"""
|
||||
|
||||
entries: list[HookConfigEntry]
|
||||
|
||||
|
||||
class HookConfigResponse(BaseModel):
|
||||
"""Hook 配置响应"""
|
||||
|
||||
entries: list[dict[str, Any]]
|
||||
count: int
|
||||
|
||||
|
||||
class HookStatusResponse(BaseModel):
|
||||
"""Hook 状态响应"""
|
||||
|
||||
name: str
|
||||
enabled: bool
|
||||
hook_type: str
|
||||
registered: bool
|
||||
|
||||
|
||||
# 内置 Hook 注册表
|
||||
BUILTIN_HOOKS: dict[str, dict[str, str]] = {
|
||||
"audit_log": {
|
||||
"name": "audit_log",
|
||||
"hook_type": "pre_tool_use,post_tool_use,tool_error",
|
||||
"description": "审计日志 Hook - 记录所有工具调用",
|
||||
"class": "AuditLogHook",
|
||||
},
|
||||
"dangerous_confirmation": {
|
||||
"name": "dangerous_confirmation",
|
||||
"hook_type": "pre_tool_use",
|
||||
"description": "危险操作确认 Hook - 拦截危险工具调用",
|
||||
"class": "DangerousConfirmationHook",
|
||||
},
|
||||
"security_scan": {
|
||||
"name": "security_scan",
|
||||
"hook_type": "post_tool_use",
|
||||
"description": "安全扫描 Hook - 检测敏感信息泄露",
|
||||
"class": "SecurityScanHook",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/available", response_model=list[HookInfo])
|
||||
async def list_available_hooks() -> list[HookInfo]:
|
||||
"""列出所有可用的内置 Hook"""
|
||||
return [
|
||||
HookInfo(
|
||||
name=info["name"],
|
||||
hook_type=info["hook_type"],
|
||||
description=info["description"],
|
||||
builtin=True,
|
||||
)
|
||||
for info in BUILTIN_HOOKS.values()
|
||||
]
|
||||
|
||||
|
||||
@router.get("/config", response_model=HookConfigResponse)
|
||||
async def get_hook_config() -> HookConfigResponse:
|
||||
"""获取当前 Hook 配置"""
|
||||
persistence = get_hook_config_persistence()
|
||||
entries = persistence.load_config()
|
||||
return HookConfigResponse(
|
||||
entries=[vars(e) if isinstance(e, HookConfigEntry) else e for e in entries],
|
||||
count=len(entries),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config", response_model=HookConfigResponse)
|
||||
async def update_hook_config(
|
||||
entries: list[HookConfigEntry],
|
||||
) -> HookConfigResponse:
|
||||
"""更新 Hook 配置"""
|
||||
persistence = get_hook_config_persistence()
|
||||
success = persistence.save_config(entries)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to save hook config")
|
||||
|
||||
# 应用配置到 HookManager
|
||||
manager = get_hook_manager()
|
||||
manager.clear() # 清除旧配置
|
||||
persistence.apply_config() # 应用新配置
|
||||
|
||||
return HookConfigResponse(
|
||||
entries=[vars(e) if isinstance(e, HookConfigEntry) else e for e in entries],
|
||||
count=len(entries),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/apply-config", response_model=dict[str, Any])
|
||||
async def apply_hook_config() -> dict[str, Any]:
|
||||
"""应用配置文件到 HookManager"""
|
||||
persistence = get_hook_config_persistence()
|
||||
manager = get_hook_manager()
|
||||
manager.clear()
|
||||
count = persistence.apply_config()
|
||||
return {"applied": count, "message": f"Applied {count} hook configurations"}
|
||||
|
||||
|
||||
@router.get("/status", response_model=list[HookStatusResponse])
|
||||
async def get_hook_status() -> list[HookStatusResponse]:
|
||||
"""获取所有已注册 Hook 的状态"""
|
||||
manager = get_hook_manager()
|
||||
all_hooks = manager.list_all()
|
||||
|
||||
# 按名称索引已注册的 hooks
|
||||
registered: dict[str, dict[str, Any]] = {}
|
||||
for hook in all_hooks:
|
||||
registered[hook.name] = {
|
||||
"name": hook.name,
|
||||
"enabled": hook.enabled,
|
||||
"hook_type": hook.hook_type.value,
|
||||
"registered": True,
|
||||
}
|
||||
|
||||
# 合并内置 Hook 信息
|
||||
result: list[HookStatusResponse] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
# 先添加已注册的
|
||||
for hook in all_hooks:
|
||||
result.append(
|
||||
HookStatusResponse(
|
||||
name=hook.name,
|
||||
enabled=hook.enabled,
|
||||
hook_type=hook.hook_type.value,
|
||||
registered=True,
|
||||
)
|
||||
)
|
||||
seen.add(hook.name)
|
||||
|
||||
# 再添加内置但未注册的
|
||||
for name, info in BUILTIN_HOOKS.items():
|
||||
if name not in seen:
|
||||
result.append(
|
||||
HookStatusResponse(
|
||||
name=name,
|
||||
enabled=False,
|
||||
hook_type=info["hook_type"],
|
||||
registered=False,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/{name}/enable", response_model=dict[str, str])
|
||||
async def enable_hook(name: str) -> dict[str, str]:
|
||||
"""启用指定 Hook"""
|
||||
manager = get_hook_manager()
|
||||
if manager.enable(name):
|
||||
return {"status": "enabled", "name": name}
|
||||
raise HTTPException(status_code=404, detail=f"Hook '{name}' not found")
|
||||
|
||||
|
||||
@router.post("/{name}/disable", response_model=dict[str, str])
|
||||
async def disable_hook(name: str) -> dict[str, str]:
|
||||
"""禁用指定 Hook"""
|
||||
manager = get_hook_manager()
|
||||
if manager.disable(name):
|
||||
return {"status": "disabled", "name": name}
|
||||
raise HTTPException(status_code=404, detail=f"Hook '{name}' not found")
|
||||
|
||||
|
||||
@router.post("/register-builtin", response_model=dict[str, str])
|
||||
async def register_builtin_hook(
|
||||
name: str,
|
||||
hook_type: str = "pre_tool_use",
|
||||
) -> dict[str, str]:
|
||||
"""注册内置 Hook 到 HookManager"""
|
||||
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
|
||||
|
||||
manager = get_hook_manager()
|
||||
|
||||
if name == "audit_log":
|
||||
hook_instance = AuditLogHook()
|
||||
handler = hook_instance.pre_tool_use
|
||||
hook_types = [HookType.PRE_TOOL_USE, HookType.POST_TOOL_USE, HookType.TOOL_ERROR]
|
||||
elif name == "dangerous_confirmation":
|
||||
hook_instance = DangerousConfirmationHook()
|
||||
handler = hook_instance.pre_tool_use
|
||||
hook_types = [HookType.PRE_TOOL_USE]
|
||||
elif name == "security_scan":
|
||||
hook_instance = SecurityScanHook()
|
||||
handler = hook_instance.post_tool_use
|
||||
hook_types = [HookType.POST_TOOL_USE]
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown builtin hook: {name}")
|
||||
|
||||
registered = []
|
||||
for ht in hook_types:
|
||||
hook_def = HookDefinition(
|
||||
name=f"{name}_{ht.value}",
|
||||
hook_type=ht,
|
||||
trigger=HookTrigger(),
|
||||
handler=handler,
|
||||
priority=0,
|
||||
enabled=True,
|
||||
description=f"Built-in {name} hook",
|
||||
)
|
||||
manager.register(hook_def)
|
||||
registered.append(ht.value)
|
||||
|
||||
return {
|
||||
"status": "registered",
|
||||
"name": name,
|
||||
"hook_types": ", ".join(registered),
|
||||
}
|
||||
169
backend/app/routers/plugins.py
Normal file
169
backend/app/routers/plugins.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Plugin API 路由 - Phase 8.6"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agents.plugins import get_plugin_manager, PluginManifest
|
||||
|
||||
router = APIRouter(prefix="/api/plugins", tags=["Plugins"])
|
||||
|
||||
|
||||
class PluginInfo(BaseModel):
|
||||
"""插件信息"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
version: str
|
||||
description: str
|
||||
author: str
|
||||
enabled: bool
|
||||
main: str
|
||||
|
||||
|
||||
class PluginInstallRequest(BaseModel):
|
||||
"""插件安装请求"""
|
||||
|
||||
plugin_path: str
|
||||
|
||||
|
||||
class PluginListResponse(BaseModel):
|
||||
"""插件列表响应"""
|
||||
|
||||
plugins: list[dict[str, Any]]
|
||||
count: int
|
||||
|
||||
|
||||
# 全局插件市场(简单内存实现)
|
||||
_plugin_marketplace: list[dict[str, str]] = []
|
||||
|
||||
|
||||
def _manifest_to_dict(manifest: PluginManifest, enabled: bool) -> dict[str, Any]:
|
||||
"""将 PluginManifest 转换为字典"""
|
||||
return {
|
||||
"id": manifest.id,
|
||||
"name": manifest.name,
|
||||
"version": manifest.version,
|
||||
"description": manifest.description,
|
||||
"author": manifest.author,
|
||||
"enabled": enabled,
|
||||
"main": manifest.main,
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=PluginListResponse)
|
||||
async def list_plugins() -> PluginListResponse:
|
||||
"""列出所有已安装的插件"""
|
||||
manager = get_plugin_manager()
|
||||
plugins = manager.list_plugins()
|
||||
result = []
|
||||
for p in plugins:
|
||||
enabled = manager.is_enabled(p.id)
|
||||
result.append(_manifest_to_dict(p, enabled))
|
||||
return PluginListResponse(plugins=result, count=len(result))
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", response_model=dict[str, Any])
|
||||
async def get_plugin(plugin_id: str) -> dict[str, Any]:
|
||||
"""获取指定插件信息"""
|
||||
manager = get_plugin_manager()
|
||||
manifest = manager.get_plugin(plugin_id)
|
||||
if not manifest:
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
enabled = manager.is_enabled(plugin_id)
|
||||
return _manifest_to_dict(manifest, enabled)
|
||||
|
||||
|
||||
@router.post("/install", response_model=dict[str, str])
|
||||
async def install_plugin(request: PluginInstallRequest) -> dict[str, str]:
|
||||
"""安装插件"""
|
||||
manager = get_plugin_manager()
|
||||
if not os.path.exists(request.plugin_path):
|
||||
raise HTTPException(status_code=400, detail="Plugin path does not exist")
|
||||
|
||||
if manager.install(request.plugin_path):
|
||||
return {"status": "installed", "path": request.plugin_path}
|
||||
raise HTTPException(status_code=500, detail="Failed to install plugin")
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/enable", response_model=dict[str, str])
|
||||
async def enable_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""启用插件"""
|
||||
manager = get_plugin_manager()
|
||||
if manager.enable(plugin_id):
|
||||
return {"status": "enabled", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/disable", response_model=dict[str, str])
|
||||
async def disable_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""禁用插件"""
|
||||
manager = get_plugin_manager()
|
||||
if manager.disable(plugin_id):
|
||||
return {"status": "disabled", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}", response_model=dict[str, str])
|
||||
async def uninstall_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""卸载插件"""
|
||||
manager = get_plugin_manager()
|
||||
if manager.uninstall(plugin_id):
|
||||
return {"status": "uninstalled", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/reload", response_model=dict[str, str])
|
||||
async def reload_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""重新加载插件"""
|
||||
manager = get_plugin_manager()
|
||||
if manager.reload(plugin_id):
|
||||
return {"status": "reloaded", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
|
||||
# === Plugin Marketplace ===
|
||||
|
||||
_marketplace_router = APIRouter(prefix="/api/marketplace", tags=["Plugin Marketplace"])
|
||||
|
||||
|
||||
@_marketplace_router.get("/plugins", response_model=dict[str, Any])
|
||||
async def search_marketplace_plugins(
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""搜索插件市场"""
|
||||
results = _plugin_marketplace
|
||||
if query:
|
||||
results = [
|
||||
p
|
||||
for p in results
|
||||
if query.lower() in p.get("name", "").lower()
|
||||
or query.lower() in p.get("description", "").lower()
|
||||
]
|
||||
if category:
|
||||
results = [p for p in results if p.get("category") == category]
|
||||
return {"plugins": results, "count": len(results)}
|
||||
|
||||
|
||||
@_marketplace_router.get("/plugins/{plugin_id}", response_model=dict[str, Any])
|
||||
async def get_marketplace_plugin(plugin_id: str) -> dict[str, Any]:
|
||||
"""获取市场中的插件详情"""
|
||||
for plugin in _plugin_marketplace:
|
||||
if plugin.get("id") == plugin_id:
|
||||
return plugin
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found in marketplace")
|
||||
|
||||
|
||||
@_marketplace_router.post("/plugins", response_model=dict[str, str])
|
||||
async def add_to_marketplace(plugin: dict[str, str]) -> dict[str, str]:
|
||||
"""添加插件到市场(仅供测试/开发)"""
|
||||
if "id" not in plugin or "name" not in plugin:
|
||||
raise HTTPException(status_code=400, detail="Plugin must have id and name")
|
||||
# 移除已存在的同 ID 插件
|
||||
global _plugin_marketplace
|
||||
_plugin_marketplace = [p for p in _plugin_marketplace if p.get("id") != plugin["id"]]
|
||||
_plugin_marketplace.append(plugin)
|
||||
return {"status": "added", "id": plugin["id"]}
|
||||
Reference in New Issue
Block a user