Compare commits
8 Commits
phase1-reg
...
fca7a7cf3d
| Author | SHA1 | Date | |
|---|---|---|---|
| fca7a7cf3d | |||
| d18167826e | |||
| 88955ed550 | |||
| a3fe4d24fc | |||
| e5bd492d74 | |||
| a7b6b5eb90 | |||
| aa0ef0fbea | |||
| 4972b4e6b1 |
220
backend/app/agents/background/executor.py
Normal file
220
backend/app/agents/background/executor.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Background task executor - Phase 10.4"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from .manager import (
|
||||
BackgroundTask,
|
||||
BackgroundTaskManager,
|
||||
BackgroundTaskStatus,
|
||||
get_background_task_manager,
|
||||
)
|
||||
|
||||
|
||||
class BackgroundExecutor:
|
||||
"""Executes background tasks with error handling and result storage.
|
||||
|
||||
Provides methods to execute tasks synchronously or asynchronously,
|
||||
with full integration into BackgroundTaskManager for tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, task_manager: BackgroundTaskManager | None = None):
|
||||
"""Initialize the executor.
|
||||
|
||||
Args:
|
||||
task_manager: Optional BackgroundTaskManager instance.
|
||||
If not provided, uses the global singleton.
|
||||
"""
|
||||
self._task_manager = task_manager or get_background_task_manager()
|
||||
self._executors: dict[str, asyncio.Task] = {}
|
||||
|
||||
async def execute_task(
|
||||
self,
|
||||
task_id: str,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> BackgroundTask:
|
||||
"""Execute a specific task by ID.
|
||||
|
||||
Args:
|
||||
task_id: Unique task identifier
|
||||
func: Async function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
The BackgroundTask with result or error
|
||||
"""
|
||||
# Get or create task record
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task is None:
|
||||
# Create a new task record if one doesn't exist
|
||||
task = BackgroundTask(
|
||||
id=task_id,
|
||||
name=f"executor_task_{task_id}",
|
||||
status=BackgroundTaskStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
self._task_manager._tasks[task_id] = task
|
||||
|
||||
# Update status to running
|
||||
task.status = BackgroundTaskStatus.RUNNING
|
||||
task.started_at = datetime.now()
|
||||
|
||||
try:
|
||||
# Execute the async function
|
||||
result = await func(*args, **kwargs)
|
||||
task.status = BackgroundTaskStatus.COMPLETED
|
||||
task.result = result
|
||||
except Exception as e:
|
||||
task.status = BackgroundTaskStatus.FAILED
|
||||
task.error = f"{type(e).__name__}: {str(e)}"
|
||||
task.result = None
|
||||
finally:
|
||||
task.completed_at = datetime.now()
|
||||
# Clean up executor reference
|
||||
if task_id in self._executors:
|
||||
del self._executors[task_id]
|
||||
|
||||
return task
|
||||
|
||||
async def execute_async(
|
||||
self,
|
||||
task_id: str,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Execute a task asynchronously in the background.
|
||||
|
||||
Args:
|
||||
task_id: Unique task identifier
|
||||
func: Async function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
The task ID
|
||||
"""
|
||||
# Create task record if it doesn't exist
|
||||
if self._task_manager.get_task_status(task_id) is None:
|
||||
self._task_manager._tasks[task_id] = BackgroundTask(
|
||||
id=task_id,
|
||||
name=f"async_task_{task_id}",
|
||||
status=BackgroundTaskStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
# Create and store the asyncio task
|
||||
async_task = asyncio.create_task(self.execute_task(task_id, func, *args, **kwargs))
|
||||
self._executors[task_id] = async_task
|
||||
|
||||
return task_id
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""Cancel a running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to cancel
|
||||
|
||||
Returns:
|
||||
True if cancelled, False if not found or not running
|
||||
"""
|
||||
if task_id not in self._executors:
|
||||
return False
|
||||
|
||||
self._executors[task_id].cancel()
|
||||
del self._executors[task_id]
|
||||
|
||||
# Update task status
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task:
|
||||
task.status = BackgroundTaskStatus.CANCELLED
|
||||
task.completed_at = datetime.now()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_task_result(self, task_id: str) -> Any:
|
||||
"""Get the result of a completed task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID
|
||||
|
||||
Returns:
|
||||
The task result or None if not found/not completed
|
||||
"""
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task and task.status == BackgroundTaskStatus.COMPLETED:
|
||||
return task.result
|
||||
return None
|
||||
|
||||
def get_task_error(self, task_id: str) -> str | None:
|
||||
"""Get the error of a failed task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID
|
||||
|
||||
Returns:
|
||||
The error message or None if not found/not failed
|
||||
"""
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task and task.status == BackgroundTaskStatus.FAILED:
|
||||
return task.error
|
||||
return None
|
||||
|
||||
def is_task_running(self, task_id: str) -> bool:
|
||||
"""Check if a task is currently running.
|
||||
|
||||
Args:
|
||||
task_id: The task ID
|
||||
|
||||
Returns:
|
||||
True if running, False otherwise
|
||||
"""
|
||||
return task_id in self._executors
|
||||
|
||||
def wait_for_task(self, task_id: str, timeout: float | None = None) -> BackgroundTask:
|
||||
"""Wait for a task to complete.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to wait for
|
||||
timeout: Optional timeout in seconds
|
||||
|
||||
Returns:
|
||||
The completed BackgroundTask
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If task doesn't complete within timeout
|
||||
asyncio.CancelledError: If task is cancelled
|
||||
"""
|
||||
if task_id not in self._executors:
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task:
|
||||
return task
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
async def wait_task() -> BackgroundTask:
|
||||
await self._executors[task_id]
|
||||
return self._task_manager.get_task_status(task_id)
|
||||
|
||||
return asyncio.run_until_complete(asyncio.wait_for(wait_task(), timeout=timeout))
|
||||
|
||||
@property
|
||||
def task_manager(self) -> BackgroundTaskManager:
|
||||
"""Get the underlying task manager."""
|
||||
return self._task_manager
|
||||
|
||||
|
||||
# Global executor instance
|
||||
_executor: BackgroundExecutor | None = None
|
||||
|
||||
|
||||
def get_background_executor() -> BackgroundExecutor:
|
||||
"""Get the global BackgroundExecutor instance."""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = BackgroundExecutor()
|
||||
return _executor
|
||||
119
backend/app/agents/background/manager.py
Normal file
119
backend/app/agents/background/manager.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""后台任务系统 - Phase 10.4"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BackgroundTaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackgroundTask:
|
||||
"""后台任务"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
status: BackgroundTaskStatus
|
||||
created_at: datetime
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class BackgroundTaskManager:
|
||||
"""后台任务管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._tasks: dict[str, BackgroundTask] = {}
|
||||
self._.coroutines: dict[str, asyncio.Task] = {}
|
||||
|
||||
def submit_task(self, name: str, coro: Any, *args, **kwargs) -> str:
|
||||
"""提交后台任务
|
||||
|
||||
Args:
|
||||
name: 任务名称
|
||||
coro: 协程函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
任务 ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 创建任务记录
|
||||
self._tasks[task_id] = BackgroundTask(
|
||||
id=task_id,
|
||||
name=name,
|
||||
status=BackgroundTaskStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
# 创建 asyncio task
|
||||
async def run_task():
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.RUNNING
|
||||
self._tasks[task_id].started_at = datetime.now()
|
||||
try:
|
||||
result = await coro(*args, **kwargs)
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.COMPLETED
|
||||
self._tasks[task_id].result = result
|
||||
except Exception as e:
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.FAILED
|
||||
self._tasks[task_id].error = str(e)
|
||||
finally:
|
||||
self._tasks[task_id].completed_at = datetime.now()
|
||||
if task_id in self._coroutines:
|
||||
del self._coroutines[task_id]
|
||||
|
||||
self._coroutines[task_id] = asyncio.create_task(run_task())
|
||||
return task_id
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""取消任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
是否成功取消
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
if task_id in self._coroutines:
|
||||
self._coroutines[task_id].cancel()
|
||||
del self._coroutines[task_id]
|
||||
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.CANCELLED
|
||||
self._tasks[task_id].completed_at = datetime.now()
|
||||
return True
|
||||
|
||||
def get_task_status(self, task_id: str) -> BackgroundTask | None:
|
||||
"""获取任务状态"""
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def list_tasks(self) -> list[BackgroundTask]:
|
||||
"""列出所有任务"""
|
||||
return list(self._tasks.values())
|
||||
|
||||
|
||||
# 全局单例
|
||||
_manager: BackgroundTaskManager | None = None
|
||||
|
||||
|
||||
def get_background_task_manager() -> BackgroundTaskManager:
|
||||
"""获取全局后台任务管理器"""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
_manager = BackgroundTaskManager()
|
||||
return _manager
|
||||
146
backend/app/agents/background/scheduler.py
Normal file
146
backend/app/agents/background/scheduler.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Background task scheduler - Phase 10.4"""
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.base import BaseTrigger
|
||||
|
||||
from .manager import BackgroundTaskManager, get_background_task_manager
|
||||
|
||||
|
||||
class BackgroundScheduler:
|
||||
"""Background task scheduler using APScheduler.
|
||||
|
||||
Integrates with BackgroundTaskManager for task tracking and execution.
|
||||
"""
|
||||
|
||||
def __init__(self, task_manager: BackgroundTaskManager | None = None):
|
||||
"""Initialize the scheduler.
|
||||
|
||||
Args:
|
||||
task_manager: Optional BackgroundTaskManager instance.
|
||||
If not provided, uses the global singleton.
|
||||
"""
|
||||
self._scheduler = AsyncIOScheduler()
|
||||
self._task_manager = task_manager or get_background_task_manager()
|
||||
self._job_tasks: dict[str, str] = {} # Maps APScheduler job_id to task_id
|
||||
|
||||
def add_job(
|
||||
self,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
trigger: BaseTrigger,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
**apscheduler_kwargs: Any,
|
||||
) -> str:
|
||||
"""Add a job to the scheduler.
|
||||
|
||||
Args:
|
||||
func: Async function to execute
|
||||
trigger: APScheduler trigger (date, interval, cron, etc.)
|
||||
args: Positional arguments for the function
|
||||
kwargs: Keyword arguments for the function
|
||||
id: Unique job ID (auto-generated if not provided)
|
||||
name: Job name for display purposes
|
||||
**apscheduler_kwargs: Additional APScheduler options
|
||||
|
||||
Returns:
|
||||
The job ID
|
||||
"""
|
||||
job_id = id or f"job_{len(self._job_tasks)}"
|
||||
task_name = name or f"scheduled_task_{job_id}"
|
||||
|
||||
# Wrap the async function to integrate with BackgroundTaskManager
|
||||
async def wrapped_func() -> None:
|
||||
coro = func(*(args or ()), **(kwargs or {}))
|
||||
task_id = self._task_manager.submit_task(task_name, coro)
|
||||
self._job_tasks[job_id] = task_id
|
||||
|
||||
self._scheduler.add_job(
|
||||
wrapped_func,
|
||||
trigger=trigger,
|
||||
id=job_id,
|
||||
name=task_name,
|
||||
**apscheduler_kwargs,
|
||||
)
|
||||
|
||||
return job_id
|
||||
|
||||
def remove_job(self, job_id: str) -> bool:
|
||||
"""Remove a job from the scheduler.
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job to remove
|
||||
|
||||
Returns:
|
||||
True if job was removed, False if job didn't exist
|
||||
"""
|
||||
try:
|
||||
self._scheduler.remove_job(job_id)
|
||||
# Clean up task mapping if exists
|
||||
if job_id in self._job_tasks:
|
||||
task_id = self._job_tasks.pop(job_id)
|
||||
# Cancel the background task if still running
|
||||
self._task_manager.cancel_task(task_id)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def list_jobs(self) -> list[dict[str, Any]]:
|
||||
"""List all scheduled jobs.
|
||||
|
||||
Returns:
|
||||
List of job information dictionaries
|
||||
"""
|
||||
jobs = self._scheduler.get_jobs()
|
||||
return [
|
||||
{
|
||||
"id": job.id,
|
||||
"name": job.name,
|
||||
"next_run_time": job.next_run_time,
|
||||
"trigger": str(job.trigger),
|
||||
}
|
||||
for job in jobs
|
||||
]
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
if not self._scheduler.running:
|
||||
self._scheduler.start()
|
||||
|
||||
def shutdown(self, wait: bool = True) -> None:
|
||||
"""Shutdown the scheduler.
|
||||
|
||||
Args:
|
||||
wait: Whether to wait for running jobs to complete
|
||||
"""
|
||||
if self._scheduler.running:
|
||||
self._scheduler.shutdown(wait=wait)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the scheduler."""
|
||||
self._scheduler.pause()
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume the scheduler."""
|
||||
self._scheduler.resume()
|
||||
|
||||
@property
|
||||
def task_manager(self) -> BackgroundTaskManager:
|
||||
"""Get the underlying task manager."""
|
||||
return self._task_manager
|
||||
|
||||
|
||||
# Global scheduler instance
|
||||
_scheduler: BackgroundScheduler | None = None
|
||||
|
||||
|
||||
def get_background_scheduler() -> BackgroundScheduler:
|
||||
"""Get the global BackgroundScheduler instance."""
|
||||
global _scheduler
|
||||
if _scheduler is None:
|
||||
_scheduler = BackgroundScheduler()
|
||||
return _scheduler
|
||||
508
backend/app/agents/coordinator.py
Normal file
508
backend/app/agents/coordinator.py
Normal file
@@ -0,0 +1,508 @@
|
||||
"""Agent 协调整器 - Phase 10.5
|
||||
|
||||
统一协调所有 Agent 组件:TeamLeader, RemoteTransport, BackgroundTaskManager, SessionManager
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.background.manager import BackgroundTaskManager, get_background_task_manager
|
||||
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
|
||||
from app.agents.team.leader import TeamLeader
|
||||
from app.agents.transport.remote import RemoteTransport
|
||||
|
||||
|
||||
class AgentCoordinator:
|
||||
"""Agent 协调整器
|
||||
|
||||
统一协调所有 Agent 组件,提供单一入口处理各类 Agent 操作。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
background_manager: BackgroundTaskManager | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
background_manager: 后台任务管理器,None 则使用全局单例
|
||||
"""
|
||||
self._team_leaders: dict[str, TeamLeader] = {}
|
||||
self._remote_transport = RemoteTransport()
|
||||
self._background_manager = background_manager or get_background_task_manager()
|
||||
self._sessions: dict[str, AgentSession] = {}
|
||||
|
||||
# === Team 协作方法 ===
|
||||
|
||||
def create_team(self, team_id: str, members: list[str]) -> dict[str, Any]:
|
||||
"""创建团队
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
members: 成员 ID 列表
|
||||
|
||||
Returns:
|
||||
团队创建结果
|
||||
"""
|
||||
if team_id in self._team_leaders:
|
||||
return {"status": "error", "message": f"Team '{team_id}' already exists"}
|
||||
|
||||
leader = TeamLeader(team_id=team_id, members=members)
|
||||
self._team_leaders[team_id] = leader
|
||||
return {
|
||||
"status": "created",
|
||||
"team_id": team_id,
|
||||
"members": members,
|
||||
}
|
||||
|
||||
def get_team(self, team_id: str) -> TeamLeader | None:
|
||||
"""获取团队
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
|
||||
Returns:
|
||||
TeamLeader 或 None
|
||||
"""
|
||||
return self._team_leaders.get(team_id)
|
||||
|
||||
def assign_task(self, team_id: str, description: str, member: str) -> dict[str, Any]:
|
||||
"""创建并分配任务
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
description: 任务描述
|
||||
member: 成员 ID
|
||||
|
||||
Returns:
|
||||
分配结果
|
||||
"""
|
||||
leader = self._team_leaders.get(team_id)
|
||||
if not leader:
|
||||
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||
|
||||
task_id = leader.create_task(description)
|
||||
success = leader.assign_task(task_id, member)
|
||||
return {
|
||||
"status": "assigned" if success else "error",
|
||||
"task_id": task_id,
|
||||
"assignee": member,
|
||||
}
|
||||
|
||||
def broadcast_task(self, team_id: str, description: str) -> dict[str, Any]:
|
||||
"""广播任务给所有成员
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
广播结果
|
||||
"""
|
||||
leader = self._team_leaders.get(team_id)
|
||||
if not leader:
|
||||
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||
|
||||
task_ids = leader.broadcast_task(description)
|
||||
return {
|
||||
"status": "broadcast",
|
||||
"team_id": team_id,
|
||||
"task_ids": task_ids,
|
||||
"member_count": len(leader.members),
|
||||
}
|
||||
|
||||
def collect_team_results(self, team_id: str) -> dict[str, Any]:
|
||||
"""收集团队任务结果
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
|
||||
Returns:
|
||||
收集结果
|
||||
"""
|
||||
leader = self._team_leaders.get(team_id)
|
||||
if not leader:
|
||||
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||
|
||||
results = leader.collect_results()
|
||||
status = leader.get_team_status()
|
||||
return {
|
||||
"status": "collected",
|
||||
"team_id": team_id,
|
||||
"results": results,
|
||||
"completed": status["completed"],
|
||||
"failed": status["failed"],
|
||||
}
|
||||
|
||||
def get_team_status(self, team_id: str) -> dict[str, Any]:
|
||||
"""获取团队状态
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
|
||||
Returns:
|
||||
团队状态
|
||||
"""
|
||||
leader = self._team_leaders.get(team_id)
|
||||
if not leader:
|
||||
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||
|
||||
return leader.get_team_status()
|
||||
|
||||
# === 后台任务方法 ===
|
||||
|
||||
def submit_background_task(
|
||||
self,
|
||||
name: str,
|
||||
coro: Any,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""提交后台任务
|
||||
|
||||
Args:
|
||||
name: 任务名称
|
||||
coro: 协程函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
提交结果
|
||||
"""
|
||||
task_id = self._background_manager.submit_task(name, coro, *args, **kwargs)
|
||||
return {
|
||||
"status": "submitted",
|
||||
"task_id": task_id,
|
||||
"name": name,
|
||||
}
|
||||
|
||||
def cancel_background_task(self, task_id: str) -> dict[str, Any]:
|
||||
"""取消后台任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
取消结果
|
||||
"""
|
||||
success = self._background_manager.cancel_task(task_id)
|
||||
return {
|
||||
"status": "cancelled" if success else "error",
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
def get_background_task_status(self, task_id: str) -> dict[str, Any]:
|
||||
"""获取后台任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
任务状态
|
||||
"""
|
||||
task = self._background_manager.get_task_status(task_id)
|
||||
if not task:
|
||||
return {"status": "error", "message": f"Task '{task_id}' not found"}
|
||||
|
||||
return {
|
||||
"status": "found",
|
||||
"task_id": task.id,
|
||||
"name": task.name,
|
||||
"task_status": task.status.value,
|
||||
"result": task.result,
|
||||
"error": task.error,
|
||||
}
|
||||
|
||||
def list_background_tasks(self) -> dict[str, Any]:
|
||||
"""列出所有后台任务
|
||||
|
||||
Returns:
|
||||
任务列表
|
||||
"""
|
||||
tasks = self._background_manager.list_tasks()
|
||||
return {
|
||||
"status": "list",
|
||||
"count": len(tasks),
|
||||
"tasks": [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"status": t.status.value,
|
||||
}
|
||||
for t in tasks
|
||||
],
|
||||
}
|
||||
|
||||
# === 会话方法 ===
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""创建会话
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
parent_session_id: 父会话 ID
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
session = create_agent_session(
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
self._sessions[session.session_id] = session
|
||||
return {
|
||||
"status": "created",
|
||||
"session_id": session.session_id,
|
||||
"user_id": user_id,
|
||||
"parent_session_id": parent_session_id,
|
||||
}
|
||||
|
||||
def get_session(self, session_id: str) -> AgentSession | None:
|
||||
"""获取会话
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
AgentSession 或 None
|
||||
"""
|
||||
return self._sessions.get(session_id) or get_agent_session(session_id)
|
||||
|
||||
async def process_session_message(
|
||||
self,
|
||||
session_id: str,
|
||||
message: str,
|
||||
response: str,
|
||||
) -> dict[str, Any]:
|
||||
"""处理会话消息
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
message: 用户消息
|
||||
response: 助手响应
|
||||
|
||||
Returns:
|
||||
处理结果
|
||||
"""
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||
|
||||
await session.process_message(message, response)
|
||||
return {
|
||||
"status": "processed",
|
||||
"session_id": session_id,
|
||||
"message_count": session.context.message_count,
|
||||
}
|
||||
|
||||
async def spawn_child_session(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""创建子会话
|
||||
|
||||
Args:
|
||||
session_id: 父会话 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||
|
||||
child = await session.spawn_child_session(user_id=user_id)
|
||||
self._sessions[child.session_id] = child
|
||||
return {
|
||||
"status": "spawned",
|
||||
"parent_session_id": session_id,
|
||||
"child_session_id": child.session_id,
|
||||
"depth": child.context.depth,
|
||||
}
|
||||
|
||||
def get_session_summary(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话摘要
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
会话摘要
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||
|
||||
# get_session_summary is async, so we need to run it
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# Create a future
|
||||
future = asyncio.ensure_future(session.get_session_summary())
|
||||
return {"status": "found", "summary": future}
|
||||
else:
|
||||
return {
|
||||
"status": "found",
|
||||
"summary": loop.run_until_complete(session.get_session_summary()),
|
||||
}
|
||||
except RuntimeError:
|
||||
# No event loop, create one
|
||||
return {"status": "found", "summary": asyncio.run(session.get_session_summary())}
|
||||
|
||||
# === 远程传输方法 ===
|
||||
|
||||
def register_remote_handler(self, event_type: str, handler: Any) -> None:
|
||||
"""注册远程消息处理器
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
handler: 处理函数
|
||||
"""
|
||||
self._remote_transport.register_handler(event_type, handler)
|
||||
|
||||
async def send_remote_response(
|
||||
self,
|
||||
session_id: str,
|
||||
response: dict[str, Any],
|
||||
) -> bool:
|
||||
"""发送远程响应
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
response: 响应数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
return await self._remote_transport.send_response(session_id, response)
|
||||
|
||||
async def send_remote_event(
|
||||
self,
|
||||
session_id: str,
|
||||
event: dict[str, Any],
|
||||
) -> bool:
|
||||
"""发送远程事件
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
event: 事件数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
return await self._remote_transport.send_event(session_id, event)
|
||||
|
||||
async def send_remote_tool_call(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_call: dict[str, Any],
|
||||
) -> bool:
|
||||
"""发送远程工具调用
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
tool_call: 工具调用数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
return await self._remote_transport.send_tool_call(session_id, tool_call)
|
||||
|
||||
# === 统一协调入口 ===
|
||||
|
||||
async def coordinate(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""统一协调入口
|
||||
|
||||
根据请求类型协调各类 Agent 操作。
|
||||
|
||||
Args:
|
||||
request: 请求数据,包含:
|
||||
- action: 操作类型 (team_create, team_assign, task_submit, session_create, etc.)
|
||||
- 其他参数根据 action 不同而不同
|
||||
|
||||
Returns:
|
||||
协调结果
|
||||
"""
|
||||
action = request.get("action")
|
||||
|
||||
if action == "team_create":
|
||||
return self.create_team(
|
||||
team_id=request["team_id"],
|
||||
members=request["members"],
|
||||
)
|
||||
|
||||
elif action == "team_assign":
|
||||
return self.assign_task(
|
||||
team_id=request["team_id"],
|
||||
description=request["description"],
|
||||
member=request["member"],
|
||||
)
|
||||
|
||||
elif action == "team_broadcast":
|
||||
return self.broadcast_task(
|
||||
team_id=request["team_id"],
|
||||
description=request["description"],
|
||||
)
|
||||
|
||||
elif action == "team_collect":
|
||||
return self.collect_team_results(team_id=request["team_id"])
|
||||
|
||||
elif action == "team_status":
|
||||
return self.get_team_status(team_id=request["team_id"])
|
||||
|
||||
elif action == "task_submit":
|
||||
return self.submit_background_task(
|
||||
name=request["name"],
|
||||
coro=request["coro"],
|
||||
*request.get("args", []),
|
||||
**request.get("kwargs", {}),
|
||||
)
|
||||
|
||||
elif action == "task_cancel":
|
||||
return self.cancel_background_task(task_id=request["task_id"])
|
||||
|
||||
elif action == "task_status":
|
||||
return self.get_background_task_status(task_id=request["task_id"])
|
||||
|
||||
elif action == "session_create":
|
||||
return self.create_session(
|
||||
user_id=request.get("user_id"),
|
||||
parent_session_id=request.get("parent_session_id"),
|
||||
)
|
||||
|
||||
elif action == "session_message":
|
||||
return await self.process_session_message(
|
||||
session_id=request["session_id"],
|
||||
message=request["message"],
|
||||
response=request["response"],
|
||||
)
|
||||
|
||||
elif action == "session_spawn":
|
||||
return await self.spawn_child_session(
|
||||
session_id=request["session_id"],
|
||||
user_id=request.get("user_id"),
|
||||
)
|
||||
|
||||
elif action == "session_summary":
|
||||
return self.get_session_summary(session_id=request["session_id"])
|
||||
|
||||
else:
|
||||
return {"status": "error", "message": f"Unknown action: {action}"}
|
||||
|
||||
|
||||
# 全局单例
|
||||
_coordinator: AgentCoordinator | None = None
|
||||
|
||||
|
||||
def get_agent_coordinator() -> AgentCoordinator:
|
||||
"""获取全局 Agent 协调整器"""
|
||||
global _coordinator
|
||||
if _coordinator is None:
|
||||
_coordinator = AgentCoordinator()
|
||||
return _coordinator
|
||||
File diff suppressed because it is too large
Load Diff
14
backend/app/agents/isolation/__init__.py
Normal file
14
backend/app/agents/isolation/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from app.agents.isolation.session_isolation import prepare_session_isolation
|
||||
from app.agents.isolation.strategy_selector import IsolationDecision, select_isolation_strategy
|
||||
from app.agents.isolation.worktree_isolation import (
|
||||
WorktreeIsolationError,
|
||||
prepare_worktree_isolation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IsolationDecision",
|
||||
"WorktreeIsolationError",
|
||||
"prepare_session_isolation",
|
||||
"prepare_worktree_isolation",
|
||||
"select_isolation_strategy",
|
||||
]
|
||||
31
backend/app/agents/isolation/session_isolation.py
Normal file
31
backend/app/agents/isolation/session_isolation.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from app.agents.isolation.strategy_selector import IsolationDecision
|
||||
|
||||
|
||||
def prepare_session_isolation(
|
||||
*,
|
||||
state: dict[str, Any],
|
||||
decision: IsolationDecision,
|
||||
role_value: str,
|
||||
sub_commander: str,
|
||||
) -> dict[str, Any]:
|
||||
isolation_id = f"session-{uuid4().hex[:8]}"
|
||||
return {
|
||||
"mode": "session",
|
||||
"isolation_id": isolation_id,
|
||||
"workspace_path": None,
|
||||
"parent_conversation_id": str(state.get("conversation_id") or "") or None,
|
||||
"metadata": {
|
||||
**dict(decision.metadata or {}),
|
||||
"reason": decision.reason,
|
||||
"role": role_value,
|
||||
"sub_commander": sub_commander,
|
||||
"tool_names": list(decision.tool_names),
|
||||
"capability_ids": list(decision.capability_ids),
|
||||
"status": "active",
|
||||
},
|
||||
}
|
||||
147
backend/app/agents/isolation/strategy_selector.py
Normal file
147
backend/app/agents/isolation/strategy_selector.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.agents.registry import load_builtin_registry_indexes
|
||||
from app.agents.registry.models import CapabilityManifest, PermissionClass, SideEffectScope
|
||||
|
||||
|
||||
IsolationMode = Literal["none", "session", "worktree"]
|
||||
|
||||
_WORKTREE_QUERY_MARKERS = (
|
||||
"code",
|
||||
"repo",
|
||||
"repository",
|
||||
"git",
|
||||
"worktree",
|
||||
"branch",
|
||||
"patch",
|
||||
"diff",
|
||||
"refactor",
|
||||
"build",
|
||||
"test",
|
||||
"fix",
|
||||
"file",
|
||||
"files",
|
||||
"python",
|
||||
"typescript",
|
||||
"javascript",
|
||||
"代码",
|
||||
"仓库",
|
||||
"分支",
|
||||
"补丁",
|
||||
"重构",
|
||||
"构建",
|
||||
"测试",
|
||||
"修复",
|
||||
"文件",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IsolationDecision:
|
||||
mode: IsolationMode
|
||||
reason: str
|
||||
tool_names: tuple[str, ...] = ()
|
||||
capability_ids: tuple[str, ...] = ()
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _capability_metadata(capability: CapabilityManifest | None) -> dict[str, Any]:
|
||||
if capability is None:
|
||||
return {}
|
||||
return {
|
||||
"capability_id": capability.capability_id,
|
||||
"tool_name": capability.tool_name,
|
||||
"permission_class": capability.permission_class.value,
|
||||
"side_effect_scope": capability.side_effect_scope.value,
|
||||
"supports_retry": capability.supports_retry,
|
||||
"idempotent": capability.idempotent,
|
||||
"safe_for_parallel_use": capability.safe_for_parallel_use,
|
||||
"requires_confirmation": capability.requires_confirmation,
|
||||
}
|
||||
|
||||
|
||||
def select_isolation_strategy(
|
||||
*,
|
||||
user_query: str,
|
||||
tool_names: list[str] | tuple[str, ...],
|
||||
role_value: str,
|
||||
execution_mode: str | None,
|
||||
) -> IsolationDecision:
|
||||
indexes = load_builtin_registry_indexes()
|
||||
capabilities: list[CapabilityManifest] = []
|
||||
capability_ids: list[str] = []
|
||||
|
||||
for tool_name in tool_names:
|
||||
capability_id = indexes.capability_id_by_tool_name.get(tool_name)
|
||||
capability = indexes.capability_by_id.get(capability_id) if capability_id else None
|
||||
if capability is not None:
|
||||
capabilities.append(capability)
|
||||
capability_ids.append(capability.capability_id)
|
||||
|
||||
normalized_query = (user_query or "").strip().lower()
|
||||
has_worktree_query_signal = any(marker in normalized_query for marker in _WORKTREE_QUERY_MARKERS)
|
||||
has_write_capability = any(cap.permission_class == PermissionClass.WRITE for cap in capabilities)
|
||||
has_external_capability = any(cap.permission_class == PermissionClass.EXTERNAL for cap in capabilities)
|
||||
has_non_parallel_capability = any(not cap.safe_for_parallel_use for cap in capabilities)
|
||||
has_stateful_side_effect = any(
|
||||
cap.side_effect_scope in {SideEffectScope.LOCAL_STATE, SideEffectScope.DB_WRITE}
|
||||
for cap in capabilities
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"role": role_value,
|
||||
"execution_mode": execution_mode,
|
||||
"capabilities": [_capability_metadata(capability) for capability in capabilities],
|
||||
"workspace_strategy": "inline",
|
||||
"risk_level": "low",
|
||||
}
|
||||
|
||||
if has_worktree_query_signal:
|
||||
return IsolationDecision(
|
||||
mode="worktree",
|
||||
reason="workspace_mutation_signals_detected",
|
||||
tool_names=tuple(tool_names),
|
||||
capability_ids=tuple(capability_ids),
|
||||
metadata={
|
||||
**metadata,
|
||||
"workspace_strategy": "ephemeral_worktree",
|
||||
"risk_level": "high",
|
||||
},
|
||||
)
|
||||
|
||||
if has_write_capability or has_stateful_side_effect or has_non_parallel_capability:
|
||||
return IsolationDecision(
|
||||
mode="session",
|
||||
reason="stateful_or_non_parallel_tooling",
|
||||
tool_names=tuple(tool_names),
|
||||
capability_ids=tuple(capability_ids),
|
||||
metadata={
|
||||
**metadata,
|
||||
"workspace_strategy": "isolated_session",
|
||||
"risk_level": "medium",
|
||||
},
|
||||
)
|
||||
|
||||
if execution_mode == "collaboration" or role_value in {"analyst", "librarian"} or has_external_capability:
|
||||
return IsolationDecision(
|
||||
mode="session",
|
||||
reason="context_heavy_or_external_retrieval",
|
||||
tool_names=tuple(tool_names),
|
||||
capability_ids=tuple(capability_ids),
|
||||
metadata={
|
||||
**metadata,
|
||||
"workspace_strategy": "isolated_session",
|
||||
"risk_level": "medium",
|
||||
},
|
||||
)
|
||||
|
||||
return IsolationDecision(
|
||||
mode="none",
|
||||
reason="inline_execution_is_sufficient",
|
||||
tool_names=tuple(tool_names),
|
||||
capability_ids=tuple(capability_ids),
|
||||
metadata=metadata,
|
||||
)
|
||||
83
backend/app/agents/isolation/worktree_isolation.py
Normal file
83
backend/app/agents/isolation/worktree_isolation.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from app.agents.isolation.strategy_selector import IsolationDecision
|
||||
|
||||
|
||||
class WorktreeIsolationError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _slugify(value: str, *, fallback: str) -> str:
|
||||
slug = re.sub(r"[^a-zA-Z0-9._-]+", "-", (value or "").strip()).strip("-").lower()
|
||||
return slug or fallback
|
||||
|
||||
|
||||
def _resolve_git_root() -> Path:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--show-toplevel"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise WorktreeIsolationError(exc.stderr.strip() or exc.stdout.strip() or "git_root_unavailable") from exc
|
||||
git_root = Path(result.stdout.strip())
|
||||
if not git_root.exists():
|
||||
raise WorktreeIsolationError("git_root_not_found")
|
||||
return git_root
|
||||
|
||||
|
||||
def prepare_worktree_isolation(
|
||||
*,
|
||||
state: dict[str, Any],
|
||||
decision: IsolationDecision,
|
||||
role_value: str,
|
||||
sub_commander: str,
|
||||
create_workspace: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
isolation_id = f"worktree-{uuid4().hex[:8]}"
|
||||
conversation_slug = _slugify(str(state.get("conversation_id") or "conversation"), fallback="conversation")
|
||||
role_slug = _slugify(role_value, fallback="agent")
|
||||
git_root = _resolve_git_root()
|
||||
workspace_root = git_root / ".worktrees" / "jarvis" / conversation_slug
|
||||
workspace_path = workspace_root / f"{role_slug}-{isolation_id}"
|
||||
branch = f"jarvis/{conversation_slug}/{role_slug}-{isolation_id}"
|
||||
|
||||
if create_workspace and not workspace_path.exists():
|
||||
workspace_root.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "-C", str(git_root), "worktree", "add", "-b", branch, str(workspace_path), "HEAD"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise WorktreeIsolationError(exc.stderr.strip() or exc.stdout.strip() or "worktree_add_failed") from exc
|
||||
|
||||
return {
|
||||
"mode": "worktree",
|
||||
"isolation_id": isolation_id,
|
||||
"workspace_path": str(workspace_path),
|
||||
"parent_conversation_id": str(state.get("conversation_id") or "") or None,
|
||||
"metadata": {
|
||||
**dict(decision.metadata or {}),
|
||||
"reason": decision.reason,
|
||||
"role": role_value,
|
||||
"sub_commander": sub_commander,
|
||||
"tool_names": list(decision.tool_names),
|
||||
"capability_ids": list(decision.capability_ids),
|
||||
"repo_root": str(git_root),
|
||||
"branch": branch,
|
||||
"workspace_strategy": "ephemeral_worktree",
|
||||
"cleanup_status": "pending",
|
||||
"materialized": workspace_path.exists(),
|
||||
},
|
||||
}
|
||||
20
backend/app/agents/orchestration/__init__.py
Normal file
20
backend/app/agents/orchestration/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""高级编排系统 - Phase 10"""
|
||||
|
||||
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
|
||||
from app.agents.transport.remote import RemoteTransport, StructuredMessage
|
||||
from app.agents.background.manager import (
|
||||
BackgroundTaskManager,
|
||||
BackgroundTask,
|
||||
get_background_task_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TeamLeader",
|
||||
"TeamTask",
|
||||
"TaskStatus",
|
||||
"RemoteTransport",
|
||||
"StructuredMessage",
|
||||
"BackgroundTaskManager",
|
||||
"BackgroundTask",
|
||||
"get_background_task_manager",
|
||||
]
|
||||
12
backend/app/agents/plugins/__init__.py
Normal file
12
backend/app/agents/plugins/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""插件系统 - Phase 8"""
|
||||
|
||||
from app.agents.plugins.manager import PluginManager, get_plugin_manager
|
||||
from app.agents.plugins.manifest import PluginManifest
|
||||
from app.agents.plugins.sandbox import PluginSandbox
|
||||
|
||||
__all__ = [
|
||||
"PluginManager",
|
||||
"PluginManifest",
|
||||
"PluginSandbox",
|
||||
"get_plugin_manager",
|
||||
]
|
||||
19
backend/app/agents/plugins/builtins/code_helper/__init__.py
Normal file
19
backend/app/agents/plugins/builtins/code_helper/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Code Helper Plugin - Linting, formatting, and code explanation tools"""
|
||||
|
||||
|
||||
def lint_file(file_path: str) -> dict:
|
||||
"""Lint a source file and return issues found."""
|
||||
return {"status": "ok", "tool": "lint_file", "result": f"Linting {file_path}"}
|
||||
|
||||
|
||||
def format_file(file_path: str) -> dict:
|
||||
"""Format a source file and return the result."""
|
||||
return {"status": "ok", "tool": "format_file", "result": f"Formatting {file_path}"}
|
||||
|
||||
|
||||
def explain_code(code_snippet: str) -> dict:
|
||||
"""Explain a code snippet and return the explanation."""
|
||||
return {"status": "ok", "tool": "explain_code", "result": f"Explaining code snippet"}
|
||||
|
||||
|
||||
tools = [lint_file, format_file, explain_code]
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"id": "code_helper",
|
||||
"name": "Code Helper",
|
||||
"version": "1.0.0",
|
||||
"description": "Code linting, formatting, and explanation tools",
|
||||
"author": "",
|
||||
"homepage": "",
|
||||
"license": "MIT",
|
||||
"plugin_type": "tool",
|
||||
"main": "__init__.py",
|
||||
"hooks": [],
|
||||
"tools": ["lint_file", "format_file", "explain_code"],
|
||||
"skills": [],
|
||||
"dependencies": {},
|
||||
"peer_dependencies": {},
|
||||
"permissions": [],
|
||||
"allowed_paths": [],
|
||||
"denied_paths": [],
|
||||
"network_allowed": false,
|
||||
"allowed_hosts": [],
|
||||
"config_schema": {}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
"""File Organizer Plugin - File organization and duplicate detection tools"""
|
||||
|
||||
|
||||
def organize_by_type(directory: str) -> dict:
|
||||
"""Organize files in a directory by file type."""
|
||||
return {"status": "ok", "tool": "organize_by_type", "result": f"Organizing {directory} by type"}
|
||||
|
||||
|
||||
def find_duplicates(directory: str) -> dict:
|
||||
"""Find duplicate files in a directory."""
|
||||
return {
|
||||
"status": "ok",
|
||||
"tool": "find_duplicates",
|
||||
"result": f"Finding duplicates in {directory}",
|
||||
}
|
||||
|
||||
|
||||
tools = [organize_by_type, find_duplicates]
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"id": "file_organizer",
|
||||
"name": "File Organizer",
|
||||
"version": "1.0.0",
|
||||
"description": "File organization and duplicate detection tools",
|
||||
"author": "",
|
||||
"homepage": "",
|
||||
"license": "MIT",
|
||||
"plugin_type": "tool",
|
||||
"main": "__init__.py",
|
||||
"hooks": [],
|
||||
"tools": ["organize_by_type", "find_duplicates"],
|
||||
"skills": [],
|
||||
"dependencies": {},
|
||||
"peer_dependencies": {},
|
||||
"permissions": [],
|
||||
"allowed_paths": [],
|
||||
"denied_paths": [],
|
||||
"network_allowed": false,
|
||||
"allowed_hosts": [],
|
||||
"config_schema": {}
|
||||
}
|
||||
23
backend/app/agents/plugins/builtins/git_helper/__init__.py
Normal file
23
backend/app/agents/plugins/builtins/git_helper/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Git Helper Plugin - Git status, log, and diff summary tools"""
|
||||
|
||||
|
||||
def git_status_summary() -> dict:
|
||||
"""Get a summary of git status."""
|
||||
return {"status": "ok", "tool": "git_status_summary", "result": "Git status summary"}
|
||||
|
||||
|
||||
def git_log_summary(limit: int = 10) -> dict:
|
||||
"""Get a summary of recent git commits."""
|
||||
return {"status": "ok", "tool": "git_log_summary", "result": f"Git log summary (limit={limit})"}
|
||||
|
||||
|
||||
def git_diff_summary(ref1: str = "HEAD", ref2: str = "HEAD~1") -> dict:
|
||||
"""Get a summary of changes between two refs."""
|
||||
return {
|
||||
"status": "ok",
|
||||
"tool": "git_diff_summary",
|
||||
"result": f"Git diff summary ({ref1}..{ref2})",
|
||||
}
|
||||
|
||||
|
||||
tools = [git_status_summary, git_log_summary, git_diff_summary]
|
||||
22
backend/app/agents/plugins/builtins/git_helper/manifest.json
Normal file
22
backend/app/agents/plugins/builtins/git_helper/manifest.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"id": "git_helper",
|
||||
"name": "Git Helper",
|
||||
"version": "1.0.0",
|
||||
"description": "Git status, log, and diff summary tools",
|
||||
"author": "",
|
||||
"homepage": "",
|
||||
"license": "MIT",
|
||||
"plugin_type": "tool",
|
||||
"main": "__init__.py",
|
||||
"hooks": [],
|
||||
"tools": ["git_status_summary", "git_log_summary", "git_diff_summary"],
|
||||
"skills": [],
|
||||
"dependencies": {},
|
||||
"peer_dependencies": {},
|
||||
"permissions": [],
|
||||
"allowed_paths": [],
|
||||
"denied_paths": [],
|
||||
"network_allowed": false,
|
||||
"allowed_hosts": [],
|
||||
"config_schema": {}
|
||||
}
|
||||
14
backend/app/agents/plugins/builtins/web_helper/__init__.py
Normal file
14
backend/app/agents/plugins/builtins/web_helper/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Web Helper Plugin - Web fetching and HTML parsing tools"""
|
||||
|
||||
|
||||
def fetch_url_content(url: str) -> dict:
|
||||
"""Fetch content from a URL."""
|
||||
return {"status": "ok", "tool": "fetch_url_content", "result": f"Fetching {url}"}
|
||||
|
||||
|
||||
def parse_html_links(html_content: str) -> dict:
|
||||
"""Parse HTML content and extract links."""
|
||||
return {"status": "ok", "tool": "parse_html_links", "result": "Extracted links from HTML"}
|
||||
|
||||
|
||||
tools = [fetch_url_content, parse_html_links]
|
||||
22
backend/app/agents/plugins/builtins/web_helper/manifest.json
Normal file
22
backend/app/agents/plugins/builtins/web_helper/manifest.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"id": "web_helper",
|
||||
"name": "Web Helper",
|
||||
"version": "1.0.0",
|
||||
"description": "Web fetching and HTML parsing tools",
|
||||
"author": "",
|
||||
"homepage": "",
|
||||
"license": "MIT",
|
||||
"plugin_type": "tool",
|
||||
"main": "__init__.py",
|
||||
"hooks": [],
|
||||
"tools": ["fetch_url_content", "parse_html_links"],
|
||||
"skills": [],
|
||||
"dependencies": {},
|
||||
"peer_dependencies": {},
|
||||
"permissions": [],
|
||||
"allowed_paths": [],
|
||||
"denied_paths": [],
|
||||
"network_allowed": true,
|
||||
"allowed_hosts": [],
|
||||
"config_schema": {}
|
||||
}
|
||||
207
backend/app/agents/plugins/manager.py
Normal file
207
backend/app/agents/plugins/manager.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""插件管理器 - Phase 8.2"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from app.agents.plugins.manifest import PluginManifest
|
||||
from app.agents.plugins.sandbox import PluginSandbox
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""插件管理器
|
||||
|
||||
负责插件的安装、卸载、启用、禁用和生命周期管理。
|
||||
"""
|
||||
|
||||
def __init__(self, plugins_dir: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
plugins_dir: 插件目录,None 则使用默认目录
|
||||
"""
|
||||
if plugins_dir is None:
|
||||
plugins_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "plugins")
|
||||
self.plugins_dir = plugins_dir
|
||||
self._plugins: dict[str, PluginManifest] = {}
|
||||
self._enabled: dict[str, bool] = {}
|
||||
self._modules: dict[str, Any] = {}
|
||||
self._sandbox = PluginSandbox()
|
||||
|
||||
def install(self, plugin_path: str) -> bool:
|
||||
"""安装插件
|
||||
|
||||
Args:
|
||||
plugin_path: 插件目录路径或 manifest.json 所在目录
|
||||
|
||||
Returns:
|
||||
是否安装成功
|
||||
"""
|
||||
try:
|
||||
manifest_path = os.path.join(plugin_path, "manifest.json")
|
||||
|
||||
if not os.path.exists(manifest_path):
|
||||
return False
|
||||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
import json
|
||||
|
||||
data = json.load(f)
|
||||
|
||||
manifest = PluginManifest.from_dict(data)
|
||||
|
||||
# 验证 manifest
|
||||
if not self._validate_manifest(manifest, plugin_path):
|
||||
return False
|
||||
|
||||
# 复制插件到 plugins_dir
|
||||
target_dir = os.path.join(self.plugins_dir, manifest.id)
|
||||
os.makedirs(os.path.dirname(target_dir), exist_ok=True)
|
||||
|
||||
# 保存 manifest
|
||||
with open(os.path.join(target_dir, "manifest.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(manifest.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 注册插件
|
||||
self._plugins[manifest.id] = manifest
|
||||
self._enabled[manifest.id] = True
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def uninstall(self, plugin_id: str) -> bool:
|
||||
"""卸载插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否卸载成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
# 禁用插件
|
||||
self.disable(plugin_id)
|
||||
|
||||
# 移除模块
|
||||
if plugin_id in self._modules:
|
||||
del self._modules[plugin_id]
|
||||
|
||||
# 移除插件
|
||||
del self._plugins[plugin_id]
|
||||
del self._enabled[plugin_id]
|
||||
|
||||
# 删除目录
|
||||
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
|
||||
if os.path.exists(plugin_dir):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(plugin_dir)
|
||||
|
||||
return True
|
||||
|
||||
def enable(self, plugin_id: str) -> bool:
|
||||
"""启用插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否启用成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
self._enabled[plugin_id] = True
|
||||
return True
|
||||
|
||||
def disable(self, plugin_id: str) -> bool:
|
||||
"""禁用插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否禁用成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
self._enabled[plugin_id] = False
|
||||
return True
|
||||
|
||||
def reload(self, plugin_id: str) -> bool:
|
||||
"""重新加载插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否重新加载成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
# 卸载模块
|
||||
if plugin_id in self._modules:
|
||||
del self._modules[plugin_id]
|
||||
|
||||
# 重新加载
|
||||
return self._load_plugin_module(plugin_id)
|
||||
|
||||
def list_plugins(self) -> list[PluginManifest]:
|
||||
"""列出所有插件"""
|
||||
return list(self._plugins.values())
|
||||
|
||||
def get_plugin(self, plugin_id: str) -> PluginManifest | None:
|
||||
"""获取插件清单"""
|
||||
return self._plugins.get(plugin_id)
|
||||
|
||||
def is_enabled(self, plugin_id: str) -> bool:
|
||||
"""检查插件是否启用"""
|
||||
return self._enabled.get(plugin_id, False)
|
||||
|
||||
def _validate_manifest(self, manifest: PluginManifest, plugin_path: str) -> bool:
|
||||
"""验证 manifest"""
|
||||
# 检查主入口文件是否存在
|
||||
main_path = os.path.join(plugin_path, manifest.main)
|
||||
if not os.path.exists(main_path):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _load_plugin_module(self, plugin_id: str) -> bool:
|
||||
"""加载插件模块"""
|
||||
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
|
||||
manifest = self._plugins.get(plugin_id)
|
||||
if not manifest:
|
||||
return False
|
||||
|
||||
try:
|
||||
main_path = os.path.join(plugin_dir, manifest.main)
|
||||
spec = importlib.util.spec_from_file_location(plugin_id, main_path)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[plugin_id] = module
|
||||
spec.loader.exec_module(module)
|
||||
self._modules[plugin_id] = module
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# 全局单例
|
||||
_manager: PluginManager | None = None
|
||||
|
||||
|
||||
def get_plugin_manager() -> PluginManager:
|
||||
"""获取全局插件管理器"""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
_manager = PluginManager()
|
||||
return _manager
|
||||
73
backend/app/agents/plugins/manifest.py
Normal file
73
backend/app/agents/plugins/manifest.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""插件清单定义 - Phase 8.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginManifest:
|
||||
"""插件清单
|
||||
|
||||
定义插件的元数据和接口。
|
||||
"""
|
||||
|
||||
id: str # 唯一标识
|
||||
name: str # 显示名称
|
||||
version: str # 版本号
|
||||
description: str # 描述
|
||||
author: str = "" # 作者
|
||||
homepage: str = "" # 主页
|
||||
license: str = "MIT" # 许可证
|
||||
|
||||
# 插件类型
|
||||
plugin_type: str = "tool" # tool, hook, skill, all
|
||||
|
||||
# 入口点
|
||||
main: str = "index.py" # 主入口文件
|
||||
hooks: list[str] = field(default_factory=list) # 提供的 Hook 列表
|
||||
tools: list[str] = field(default_factory=list) # 提供的工具列表
|
||||
skills: list[str] = field(default_factory=list) # 提供的 Skills 列表
|
||||
|
||||
# 依赖
|
||||
dependencies: dict[str, str] = field(default_factory=dict) # pip 依赖
|
||||
peer_dependencies: dict[str, str] = field(default_factory=dict) # 对等依赖
|
||||
|
||||
# 权限要求
|
||||
permissions: list[str] = field(default_factory=list) # 需要的权限
|
||||
allowed_paths: list[str] = field(default_factory=list) # 允许访问的路径
|
||||
denied_paths: list[str] = field(default_factory=list) # 禁止访问的路径
|
||||
|
||||
# 网络权限
|
||||
network_allowed: bool = False # 是否允许网络访问
|
||||
allowed_hosts: list[str] = field(default_factory=list) # 允许访问的 host
|
||||
|
||||
# 配置
|
||||
config_schema: dict[str, Any] = field(default_factory=dict) # 配置 schema
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"version": self.version,
|
||||
"description": self.description,
|
||||
"author": self.author,
|
||||
"homepage": self.homepage,
|
||||
"license": self.license,
|
||||
"plugin_type": self.plugin_type,
|
||||
"main": self.main,
|
||||
"hooks": self.hooks,
|
||||
"tools": self.tools,
|
||||
"skills": self.skills,
|
||||
"dependencies": self.dependencies,
|
||||
"peer_dependencies": self.peer_dependencies,
|
||||
"permissions": self.permissions,
|
||||
"allowed_paths": self.allowed_paths,
|
||||
"denied_paths": self.denied_paths,
|
||||
"network_allowed": self.network_allowed,
|
||||
"allowed_hosts": self.allowed_hosts,
|
||||
"config_schema": self.config_schema,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PluginManifest":
|
||||
return cls(**data)
|
||||
111
backend/app/agents/plugins/sandbox.py
Normal file
111
backend/app/agents/plugins/sandbox.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""插件沙箱隔离 - Phase 8.3"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PluginSandbox:
|
||||
"""插件沙箱
|
||||
|
||||
提供插件执行隔离环境。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._allowed_paths: set[str] = set()
|
||||
self._denied_paths: set[str] = set()
|
||||
self._network_allowed: bool = False
|
||||
self._allowed_hosts: set[str] = set()
|
||||
|
||||
def set_file_permissions(
|
||||
self,
|
||||
allowed_paths: list[str] | None = None,
|
||||
denied_paths: list[str] | None = None,
|
||||
) -> None:
|
||||
"""设置文件访问权限
|
||||
|
||||
Args:
|
||||
allowed_paths: 允许访问的路径列表
|
||||
denied_paths: 禁止访问的路径列表
|
||||
"""
|
||||
self._allowed_paths = set(allowed_paths or [])
|
||||
self._denied_paths = set(denied_paths or [])
|
||||
|
||||
def set_network_permissions(
|
||||
self, allowed: bool, allowed_hosts: list[str] | None = None
|
||||
) -> None:
|
||||
"""设置网络访问权限
|
||||
|
||||
Args:
|
||||
allowed: 是否允许网络访问
|
||||
allowed_hosts: 允许访问的 host 列表
|
||||
"""
|
||||
self._network_allowed = allowed
|
||||
self._allowed_hosts = set(allowed_hosts or [])
|
||||
|
||||
def check_file_access(self, path: str) -> bool:
|
||||
"""检查文件访问权限
|
||||
|
||||
Args:
|
||||
path: 文件路径
|
||||
|
||||
Returns:
|
||||
是否允许访问
|
||||
"""
|
||||
# 如果有允许列表,只允许访问列表中的路径
|
||||
if self._allowed_paths:
|
||||
return path in self._allowed_paths or any(
|
||||
path.startswith(allowed) for allowed in self._allowed_paths
|
||||
)
|
||||
|
||||
# 如果有禁止列表,禁止访问列表中的路径
|
||||
if self._denied_paths:
|
||||
return not any(path.startswith(denied) for denied in self._denied_paths)
|
||||
|
||||
# 没有限制
|
||||
return True
|
||||
|
||||
def check_network_access(self, host: str) -> bool:
|
||||
"""检查网络访问权限
|
||||
|
||||
Args:
|
||||
host: 主机地址
|
||||
|
||||
Returns:
|
||||
是否允许访问
|
||||
"""
|
||||
if not self._network_allowed:
|
||||
return False
|
||||
|
||||
if self._allowed_hosts:
|
||||
return host in self._allowed_hosts or any(
|
||||
host.endswith(allowed) for allowed in self._allowed_hosts
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def execute_in_sandbox(self, func: Any, *args, **kwargs) -> Any:
|
||||
"""在沙箱中执行函数
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数返回值
|
||||
"""
|
||||
# 保存当前状态
|
||||
old_allowed_paths = self._allowed_paths.copy()
|
||||
old_denied_paths = self._denied_paths.copy()
|
||||
old_network_allowed = self._network_allowed
|
||||
old_allowed_hosts = self._allowed_hosts.copy()
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
# 恢复状态
|
||||
self._allowed_paths = old_allowed_paths
|
||||
self._denied_paths = old_denied_paths
|
||||
self._network_allowed = old_network_allowed
|
||||
self._allowed_hosts = old_allowed_hosts
|
||||
@@ -324,6 +324,38 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
"""
|
||||
|
||||
|
||||
COORDINATOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 Jarvis 的协作协调官,负责把复杂请求收束成最小受控协作,而不是放任系统进入自由 swarm。
|
||||
|
||||
## 你的职责:
|
||||
- 先判断当前请求是否真的需要拆解;不需要时应明确建议继续走 direct
|
||||
- 只有在明显多步骤、跨领域、需要多角色配合时,才拆成 2~4 个子任务
|
||||
- 每个子任务必须清晰写出 `title`、`role`、`goal`、`expected_evidence`
|
||||
- 角色建议只能来自现有 top-level agent:`schedule_planner`、`librarian`、`analyst`、`executor`
|
||||
- 汇总时基于子任务结果回收,不依赖单点硬编码拼接
|
||||
|
||||
## 边界:
|
||||
- 禁止无限递归拆分
|
||||
- 禁止创建新的 runtime agent / worker
|
||||
- 禁止把一个简单请求硬拆成多个空泛步骤
|
||||
- 如果证据不足、子任务未闭环,必须把风险明确暴露出来
|
||||
"""
|
||||
|
||||
|
||||
VERIFIER_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 Jarvis 的验证官,负责对执行结果做最小但明确的核验。
|
||||
|
||||
## 你的职责:
|
||||
- 只输出 passed、failed、skipped 三种验证结论之一
|
||||
- 用一句话总结验证判断
|
||||
- 如有证据,保留关键证据点
|
||||
- 当信息不足以证明成功或失败时,优先判定为 skipped
|
||||
- 不重写执行方案,不扩展无关建议
|
||||
"""
|
||||
|
||||
|
||||
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
|
||||
|
||||
你的输出必须满足以下规则:
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
"""Registry manifest models and validation helpers."""
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from app.agents.registry.indexes import RegistryIndexes, build_registry_indexes
|
||||
from app.agents.registry.loader import RegistryBundle, load_builtin_registry_bundle
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def load_builtin_registry_indexes() -> RegistryIndexes:
|
||||
return build_registry_indexes(load_builtin_registry_bundle())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RegistryBundle",
|
||||
"RegistryIndexes",
|
||||
"build_registry_indexes",
|
||||
"load_builtin_registry_bundle",
|
||||
"load_builtin_registry_indexes",
|
||||
]
|
||||
|
||||
@@ -2,6 +2,8 @@ from app.agents.prompts import SUB_COMMANDER_PROMPTS_BY_KEY
|
||||
from app.agents.registry.models import (
|
||||
AgentManifest,
|
||||
CapabilityManifest,
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
SpecialistTemplateManifest,
|
||||
SubCommanderManifest,
|
||||
)
|
||||
@@ -55,6 +57,19 @@ TOP_LEVEL_AGENT_ROUTING_HINTS: dict[str, tuple[str, ...]] = {
|
||||
),
|
||||
}
|
||||
|
||||
TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES: dict[str, tuple[str, ...]] = {
|
||||
AgentRole.MASTER.value: (
|
||||
AgentRole.SCHEDULE_PLANNER.value,
|
||||
AgentRole.EXECUTOR.value,
|
||||
AgentRole.LIBRARIAN.value,
|
||||
AgentRole.ANALYST.value,
|
||||
),
|
||||
AgentRole.SCHEDULE_PLANNER.value: (AgentRole.SCHEDULE_PLANNER.value,),
|
||||
AgentRole.EXECUTOR.value: (AgentRole.EXECUTOR.value,),
|
||||
AgentRole.LIBRARIAN.value: (AgentRole.LIBRARIAN.value,),
|
||||
AgentRole.ANALYST.value: (AgentRole.ANALYST.value,),
|
||||
}
|
||||
|
||||
SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = {
|
||||
"schedule_analysis": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"schedule_planning": AgentRole.SCHEDULE_PLANNER.value,
|
||||
@@ -75,6 +90,8 @@ BUILTIN_AGENT_MANIFESTS: tuple[AgentManifest, ...] = tuple(
|
||||
system_prompt_key=role.value,
|
||||
routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]),
|
||||
default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]),
|
||||
can_spawn_children=bool(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
|
||||
allowed_spawn_role_values=list(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
|
||||
skill_context_key=role.value.replace("agent_", ""),
|
||||
)
|
||||
for role in AgentRole
|
||||
@@ -89,10 +106,150 @@ _capability_tool_names = tuple(
|
||||
)
|
||||
)
|
||||
|
||||
_CAPABILITY_METADATA_BY_TOOL_NAME: dict[str, dict[str, object]] = {
|
||||
"get_tasks": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"get_schedule_day": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"resolve_time_expression": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"search_knowledge": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"hybrid_search": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"get_knowledge_graph_context": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"get_forum_posts": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"scan_forum_for_instructions": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"web_search": {
|
||||
"permission_class": PermissionClass.EXTERNAL,
|
||||
"side_effect_scope": SideEffectScope.NETWORK,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"create_task": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"update_task_status": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_todo": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_schedule_task": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_reminder": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_goal": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_forum_post": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"build_knowledge_graph": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
}
|
||||
|
||||
BUILTIN_CAPABILITY_MANIFESTS: tuple[CapabilityManifest, ...] = tuple(
|
||||
CapabilityManifest(
|
||||
capability_id=tool_name,
|
||||
tool_name=tool_name,
|
||||
**dict(_CAPABILITY_METADATA_BY_TOOL_NAME.get(tool_name, {})),
|
||||
)
|
||||
for tool_name in _capability_tool_names
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.agents.registry.models import (
|
||||
@dataclass(frozen=True)
|
||||
class RegistryIndexes:
|
||||
agent_by_id: Mapping[str, AgentManifest]
|
||||
agent_by_role_value: Mapping[str, AgentManifest]
|
||||
sub_commander_by_id: Mapping[str, SubCommanderManifest]
|
||||
capability_by_id: Mapping[str, CapabilityManifest]
|
||||
specialist_template_by_id: Mapping[str, SpecialistTemplateManifest]
|
||||
@@ -24,6 +25,7 @@ class RegistryIndexes:
|
||||
skill_context_key_by_agent_id: Mapping[str, str]
|
||||
capability_id_by_tool_name: Mapping[str, str]
|
||||
capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]]
|
||||
spawnable_role_values_by_agent_id: Mapping[str, tuple[str, ...]]
|
||||
|
||||
|
||||
def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]:
|
||||
@@ -50,6 +52,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
||||
|
||||
return RegistryIndexes(
|
||||
agent_by_id=MappingProxyType(agent_by_id),
|
||||
agent_by_role_value=MappingProxyType({
|
||||
agent.role_value: agent for agent in bundle.agents
|
||||
}),
|
||||
sub_commander_by_id=MappingProxyType(sub_commander_by_id),
|
||||
capability_by_id=MappingProxyType(capability_by_id),
|
||||
specialist_template_by_id=MappingProxyType(specialist_template_by_id),
|
||||
@@ -73,4 +78,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||
for sub_commander in bundle.sub_commanders
|
||||
}),
|
||||
spawnable_role_values_by_agent_id=MappingProxyType({
|
||||
agent.agent_id: tuple(agent.allowed_spawn_role_values)
|
||||
for agent in bundle.agents
|
||||
if agent.can_spawn_children and agent.allowed_spawn_role_values
|
||||
}),
|
||||
)
|
||||
|
||||
@@ -1,4 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PermissionClass(str, Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXTERNAL = "external"
|
||||
|
||||
|
||||
class SideEffectScope(str, Enum):
|
||||
NONE = "none"
|
||||
LOCAL_STATE = "local_state"
|
||||
DB_WRITE = "db_write"
|
||||
NETWORK = "network"
|
||||
|
||||
|
||||
class AgentManifest(BaseModel):
|
||||
@@ -8,6 +23,8 @@ class AgentManifest(BaseModel):
|
||||
system_prompt_key: str
|
||||
routing_hints: list[str]
|
||||
default_sub_commanders: list[str]
|
||||
can_spawn_children: bool = False
|
||||
allowed_spawn_role_values: list[str] = Field(default_factory=list)
|
||||
skill_context_key: str | None = None
|
||||
continuity_policy: str | None = None
|
||||
clarification_policy: str | None = None
|
||||
@@ -23,6 +40,12 @@ class SubCommanderManifest(BaseModel):
|
||||
class CapabilityManifest(BaseModel):
|
||||
capability_id: str
|
||||
tool_name: str
|
||||
permission_class: PermissionClass = PermissionClass.READ
|
||||
side_effect_scope: SideEffectScope = SideEffectScope.NONE
|
||||
supports_retry: bool = False
|
||||
idempotent: bool = False
|
||||
safe_for_parallel_use: bool = False
|
||||
requires_confirmation: bool = False
|
||||
|
||||
|
||||
class SpecialistTemplateManifest(BaseModel):
|
||||
|
||||
86
backend/app/agents/runtime_metrics.py
Normal file
86
backend/app/agents/runtime_metrics.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
INPUT_TOKEN_USD_RATE = 0.000003
|
||||
OUTPUT_TOKEN_USD_RATE = 0.000015
|
||||
DEFAULT_COST_THRESHOLDS = {
|
||||
"total_tokens": 4000,
|
||||
"estimated_cost": 0.02,
|
||||
}
|
||||
|
||||
|
||||
def estimate_token_cost(input_tokens: int, output_tokens: int) -> float | None:
|
||||
total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
|
||||
if total_tokens <= 0:
|
||||
return None
|
||||
return round(
|
||||
(max(input_tokens, 0) * INPUT_TOKEN_USD_RATE)
|
||||
+ (max(output_tokens, 0) * OUTPUT_TOKEN_USD_RATE),
|
||||
6,
|
||||
)
|
||||
|
||||
|
||||
def extract_token_usage(response: Any) -> tuple[int, int]:
|
||||
usage_metadata = getattr(response, "usage_metadata", None) or {}
|
||||
if isinstance(usage_metadata, dict):
|
||||
input_tokens = int(
|
||||
usage_metadata.get("input_tokens")
|
||||
or usage_metadata.get("prompt_tokens")
|
||||
or 0
|
||||
)
|
||||
output_tokens = int(
|
||||
usage_metadata.get("output_tokens")
|
||||
or usage_metadata.get("completion_tokens")
|
||||
or 0
|
||||
)
|
||||
if input_tokens or output_tokens:
|
||||
return input_tokens, output_tokens
|
||||
|
||||
response_metadata = getattr(response, "response_metadata", None) or {}
|
||||
token_usage = {}
|
||||
if isinstance(response_metadata, dict):
|
||||
token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") or {}
|
||||
if isinstance(token_usage, dict):
|
||||
input_tokens = int(
|
||||
token_usage.get("prompt_tokens")
|
||||
or token_usage.get("input_tokens")
|
||||
or 0
|
||||
)
|
||||
output_tokens = int(
|
||||
token_usage.get("completion_tokens")
|
||||
or token_usage.get("output_tokens")
|
||||
or 0
|
||||
)
|
||||
if input_tokens or output_tokens:
|
||||
return input_tokens, output_tokens
|
||||
|
||||
return 0, 0
|
||||
|
||||
|
||||
def coerce_cost_thresholds(raw_thresholds: Any) -> dict[str, float]:
|
||||
thresholds: dict[str, float] = dict(DEFAULT_COST_THRESHOLDS)
|
||||
if not isinstance(raw_thresholds, dict):
|
||||
return thresholds
|
||||
for key in DEFAULT_COST_THRESHOLDS:
|
||||
value = raw_thresholds.get(key)
|
||||
if isinstance(value, (int, float)) and value > 0:
|
||||
thresholds[key] = float(value)
|
||||
return thresholds
|
||||
|
||||
|
||||
def is_cost_budget_warning(
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
estimated_cost: float | None,
|
||||
thresholds: dict[str, float] | None = None,
|
||||
) -> bool:
|
||||
effective_thresholds = thresholds or DEFAULT_COST_THRESHOLDS
|
||||
total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
|
||||
token_threshold = float(effective_thresholds.get("total_tokens") or 0)
|
||||
cost_threshold = float(effective_thresholds.get("estimated_cost") or 0)
|
||||
return (
|
||||
(token_threshold > 0 and total_tokens >= token_threshold)
|
||||
or (cost_threshold > 0 and estimated_cost is not None and estimated_cost >= cost_threshold)
|
||||
)
|
||||
25
backend/app/agents/schemas/__init__.py
Normal file
25
backend/app/agents/schemas/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from app.agents.schemas.event import AgentEvent
|
||||
from app.agents.schemas.message import AgentMessage
|
||||
from app.agents.schemas.task import (
|
||||
AgentTask,
|
||||
CollaborationBudget,
|
||||
InterruptRecord,
|
||||
RecoveryRecord,
|
||||
TaskLifecycleStatus,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
VerificationStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentEvent",
|
||||
"AgentMessage",
|
||||
"AgentTask",
|
||||
"CollaborationBudget",
|
||||
"InterruptRecord",
|
||||
"RecoveryRecord",
|
||||
"TaskLifecycleStatus",
|
||||
"TaskResult",
|
||||
"TaskResultStatus",
|
||||
"VerificationStatus",
|
||||
]
|
||||
52
backend/app/agents/schemas/event.py
Normal file
52
backend/app/agents/schemas/event.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
AgentEventType = Literal[
|
||||
"agent.tool.start",
|
||||
"agent.tool.result",
|
||||
"agent.verify.started",
|
||||
"agent.verify.completed",
|
||||
"agent.created",
|
||||
"agent.spawn.blocked",
|
||||
"agent.message.sent",
|
||||
"agent.message.received",
|
||||
"agent.interrupt.requested",
|
||||
"agent.interrupt.completed",
|
||||
"agent.recovery.started",
|
||||
"agent.recovery.completed",
|
||||
"agent.task.interrupted",
|
||||
"agent.task.recovered",
|
||||
"agent.task.reassigned",
|
||||
"agent.collaboration.budget.updated",
|
||||
"agent.isolation.selected",
|
||||
"agent.isolation.fallback",
|
||||
"agent.cost.updated",
|
||||
"agent.cost.warning",
|
||||
"agent.phase.changed",
|
||||
"agent.checkpoint.recorded",
|
||||
"agent.error",
|
||||
]
|
||||
AgentEventSeverity = Literal["info", "warning", "error"]
|
||||
|
||||
|
||||
class AgentEvent(BaseModel):
|
||||
event_id: str
|
||||
event_type: AgentEventType
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
conversation_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
sub_commander_id: str | None = None
|
||||
task_id: str | None = None
|
||||
parent_task_id: str | None = None
|
||||
child_task_id: str | None = None
|
||||
thread_id: str | None = None
|
||||
message_id: str | None = None
|
||||
interrupt_id: str | None = None
|
||||
recovery_id: str | None = None
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
severity: AgentEventSeverity = "info"
|
||||
29
backend/app/agents/schemas/message.py
Normal file
29
backend/app/agents/schemas/message.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
AgentMessageType = Literal[
|
||||
"task_request",
|
||||
"task_update",
|
||||
"handoff",
|
||||
"verification_request",
|
||||
"verification_feedback",
|
||||
"interrupt_notice",
|
||||
]
|
||||
|
||||
|
||||
class AgentMessage(BaseModel):
|
||||
message_id: str
|
||||
thread_id: str
|
||||
from_agent_id: str
|
||||
to_agent_id: str
|
||||
task_id: str | None = None
|
||||
reply_to_message_id: str | None = None
|
||||
message_type: AgentMessageType = "task_update"
|
||||
content_summary: str
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
85
backend/app/agents/schemas/task.py
Normal file
85
backend/app/agents/schemas/task.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
TaskLifecycleStatus = Literal["pending", "in_progress", "completed", "failed", "blocked"]
|
||||
VerificationStatus = Literal["passed", "failed", "skipped"]
|
||||
TaskResultStatus = Literal["completed", "failed", "blocked", "passed", "skipped"]
|
||||
InterruptStatus = Literal["requested", "acknowledged", "resolved"]
|
||||
BudgetMode = Literal["direct", "collaboration"]
|
||||
|
||||
|
||||
class InterruptRecord(BaseModel):
|
||||
interrupt_id: str
|
||||
reason: str
|
||||
status: InterruptStatus = "requested"
|
||||
requested_by: str | None = None
|
||||
source_event_id: str | None = None
|
||||
requested_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RecoveryRecord(BaseModel):
|
||||
recovery_id: str
|
||||
source_interrupt_id: str | None = None
|
||||
strategy: str | None = None
|
||||
resumed_from_task_id: str | None = None
|
||||
resumed_from_thread_id: str | None = None
|
||||
recovered_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CollaborationBudget(BaseModel):
|
||||
mode: BudgetMode = "direct"
|
||||
max_parallel_tasks: int | None = None
|
||||
remaining_parallel_tasks: int | None = None
|
||||
max_tool_calls: int | None = None
|
||||
remaining_tool_calls: int | None = None
|
||||
max_iterations: int | None = None
|
||||
remaining_iterations: int | None = None
|
||||
escalation_threshold: int | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentTask(BaseModel):
|
||||
task_id: str
|
||||
title: str
|
||||
status: TaskLifecycleStatus = "pending"
|
||||
owner_agent_id: str | None = None
|
||||
role: str | None = None
|
||||
goal: str | None = None
|
||||
parent_task_id: str | None = None
|
||||
child_task_ids: list[str] = Field(default_factory=list)
|
||||
thread_id: str | None = None
|
||||
message_id: str | None = None
|
||||
message_index: int | None = None
|
||||
expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
|
||||
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
|
||||
collaboration_budget: CollaborationBudget | dict[str, Any] | None = None
|
||||
result_summary: str | None = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
task_id: str
|
||||
status: TaskResultStatus
|
||||
summary: str | None = None
|
||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
owner_agent_id: str | None = None
|
||||
parent_task_id: str | None = None
|
||||
child_task_ids: list[str] = Field(default_factory=list)
|
||||
thread_id: str | None = None
|
||||
message_id: str | None = None
|
||||
message_index: int | None = None
|
||||
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
|
||||
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
|
||||
budget_snapshot: CollaborationBudget | dict[str, Any] | None = None
|
||||
next_action: str | None = None
|
||||
output_data: dict[str, Any] | None = None
|
||||
17
backend/app/agents/session/__init__.py
Normal file
17
backend/app/agents/session/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Agent Session Management - Phase 10.3"""
|
||||
|
||||
from app.agents.session.manager import (
|
||||
AgentSession,
|
||||
SessionContext,
|
||||
SessionPersistence,
|
||||
create_agent_session,
|
||||
get_agent_session,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentSession",
|
||||
"SessionContext",
|
||||
"SessionPersistence",
|
||||
"create_agent_session",
|
||||
"get_agent_session",
|
||||
]
|
||||
238
backend/app/agents/session/manager.py
Normal file
238
backend/app/agents/session/manager.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Agent Session 管理 - Phase 10.3
|
||||
|
||||
支持会话层级管理和子会话创建。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""会话上下文"""
|
||||
|
||||
session_id: str
|
||||
parent_session_id: str | None = None
|
||||
root_session_id: str | None = None
|
||||
depth: int = 0
|
||||
user_id: str | None = None
|
||||
created_at: str | None = None
|
||||
last_active: str | None = None
|
||||
message_count: int = 0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.now().isoformat()
|
||||
if self.last_active is None:
|
||||
self.last_active = self.created_at
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionPersistence:
|
||||
"""会话持久化"""
|
||||
|
||||
def __init__(self, persistence_dir: str | None = None):
|
||||
if persistence_dir is None:
|
||||
persistence_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
|
||||
)
|
||||
self.persistence_dir = persistence_dir
|
||||
|
||||
def _get_session_path(self, session_id: str) -> str:
|
||||
return os.path.join(self.persistence_dir, f"{session_id}.json")
|
||||
|
||||
def save(self, session: "AgentSession") -> bool:
|
||||
"""保存会话"""
|
||||
try:
|
||||
os.makedirs(self.persistence_dir, exist_ok=True)
|
||||
path = self._get_session_path(session.session_id)
|
||||
data = {
|
||||
"session_id": session.session_id,
|
||||
"parent_session_id": session.context.parent_session_id,
|
||||
"root_session_id": session.context.root_session_id,
|
||||
"depth": session.context.depth,
|
||||
"user_id": session.context.user_id,
|
||||
"created_at": session.context.created_at,
|
||||
"last_active": session.context.last_active,
|
||||
"message_count": session.context.message_count,
|
||||
"metadata": session.context.metadata,
|
||||
"history": session._history,
|
||||
}
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def load(self, session_id: str) -> dict[str, Any] | None:
|
||||
"""加载会话"""
|
||||
try:
|
||||
path = self._get_session_path(session_id)
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete(self, session_id: str) -> bool:
|
||||
"""删除会话"""
|
||||
try:
|
||||
path = self._get_session_path(session_id)
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
||||
"""列出所有会话"""
|
||||
sessions = []
|
||||
try:
|
||||
os.makedirs(self.persistence_dir, exist_ok=True)
|
||||
for filename in os.listdir(self.persistence_dir):
|
||||
if filename.endswith(".json"):
|
||||
path = os.path.join(self.persistence_dir, filename)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if user_id is None or data.get("user_id") == user_id:
|
||||
sessions.append(data)
|
||||
except Exception:
|
||||
pass
|
||||
return sessions
|
||||
|
||||
|
||||
class AgentSession:
|
||||
"""Agent 会话管理器
|
||||
|
||||
支持:
|
||||
- 会话层级(parent/root/depth)
|
||||
- 子会话创建
|
||||
- 会话摘要
|
||||
- 持久化
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
):
|
||||
self.session_id = session_id or str(uuid.uuid4())[:8]
|
||||
self.context = SessionContext(
|
||||
session_id=self.session_id,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
depth=0 if parent_session_id is None else 1,
|
||||
)
|
||||
self._history: list[dict[str, Any]] = []
|
||||
self._persistence = SessionPersistence()
|
||||
|
||||
# 如果有父会话,设置 root_session_id
|
||||
if parent_session_id:
|
||||
parent_data = self._persistence.load(parent_session_id)
|
||||
if parent_data:
|
||||
self.context.root_session_id = (
|
||||
parent_data.get("root_session_id") or parent_session_id
|
||||
)
|
||||
self.context.depth = parent_data.get("depth", 0) + 1
|
||||
|
||||
async def initialize(self) -> dict[str, Any]:
|
||||
"""初始化会话"""
|
||||
self.context.last_active = datetime.now().isoformat()
|
||||
self._persistence.save(self)
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"depth": self.context.depth,
|
||||
"parent_session_id": self.context.parent_session_id,
|
||||
"root_session_id": self.context.root_session_id,
|
||||
}
|
||||
|
||||
async def process_message(self, message: str, response: str) -> None:
|
||||
"""处理消息并记录到历史"""
|
||||
self.context.message_count += 1
|
||||
self.context.last_active = datetime.now().isoformat()
|
||||
self._history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": message,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
self._history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
self._persistence.save(self)
|
||||
|
||||
async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
|
||||
"""创建子会话"""
|
||||
child = AgentSession(
|
||||
user_id=user_id or self.context.user_id,
|
||||
parent_session_id=self.session_id,
|
||||
)
|
||||
child.context.root_session_id = self.context.root_session_id or self.session_id
|
||||
await child.initialize()
|
||||
return child
|
||||
|
||||
async def get_session_summary(self) -> dict[str, Any]:
|
||||
"""获取会话摘要"""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"parent_session_id": self.context.parent_session_id,
|
||||
"root_session_id": self.context.root_session_id,
|
||||
"depth": self.context.depth,
|
||||
"user_id": self.context.user_id,
|
||||
"created_at": self.context.created_at,
|
||||
"last_active": self.context.last_active,
|
||||
"message_count": self.context.message_count,
|
||||
"history_length": len(self._history),
|
||||
}
|
||||
|
||||
async def persist(self) -> bool:
|
||||
"""持久化会话"""
|
||||
return self._persistence.save(self)
|
||||
|
||||
def get_history(self) -> list[dict[str, Any]]:
|
||||
"""获取会话历史"""
|
||||
return self._history.copy()
|
||||
|
||||
def add_metadata(self, key: str, value: Any) -> None:
|
||||
"""添加会话元数据"""
|
||||
self.context.metadata[key] = value
|
||||
|
||||
def get_metadata(self, key: str) -> Any:
|
||||
"""获取会话元数据"""
|
||||
return self.context.metadata.get(key)
|
||||
|
||||
|
||||
# 全局会话存储(内存中)
|
||||
_sessions: dict[str, AgentSession] = {}
|
||||
|
||||
|
||||
def get_agent_session(session_id: str) -> AgentSession | None:
|
||||
"""获取会话"""
|
||||
return _sessions.get(session_id)
|
||||
|
||||
|
||||
def create_agent_session(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> AgentSession:
|
||||
"""创建新会话"""
|
||||
session = AgentSession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
_sessions[session.session_id] = session
|
||||
return session
|
||||
16
backend/app/agents/skills/__init__.py
Normal file
16
backend/app/agents/skills/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Skills 注册表 - Phase 9"""
|
||||
|
||||
from app.agents.skills.registry import SkillRegistry, get_skill_registry
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
|
||||
from app.agents.skills.mcp_builder import MCPSkillBuilder
|
||||
|
||||
__all__ = [
|
||||
"SkillRegistry",
|
||||
"SkillMetadata",
|
||||
"LocalSkillLoader",
|
||||
"PluginSkillLoader",
|
||||
"MCPSkillBuilder",
|
||||
"get_skill_registry",
|
||||
]
|
||||
72
backend/app/agents/skills/bundled.py
Normal file
72
backend/app/agents/skills/bundled.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Built-in Skills - Phase 9.4
|
||||
|
||||
This module contains bundled skills that are always available
|
||||
without requiring external skill loaders.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
# SkillMetadata-compatible structure for bundled skills
|
||||
BUNDLED_SKILLS: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": "code-analysis",
|
||||
"name": "Code Analysis",
|
||||
"description": "Analyze code structure, patterns, and quality. Helps understand codebase architecture, find issues, and suggest improvements.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Analyze the code structure and identify key components, their relationships, and responsibilities.",
|
||||
"Review the code for potential issues like bugs, security vulnerabilities, or performance problems.",
|
||||
"Explain how the code works and what it does in simple terms.",
|
||||
],
|
||||
"tools": ["grep", "read", "glob", "lsp_symbols", "lsp_find_references"],
|
||||
},
|
||||
{
|
||||
"id": "git-helper",
|
||||
"name": "Git Helper",
|
||||
"description": "Assists with Git operations including commit, branch management, merge conflicts, and repository exploration.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Show me the current git status and any uncommitted changes.",
|
||||
"Help me create a meaningful commit message for these changes.",
|
||||
"Explain the git history and branch structure of this repository.",
|
||||
],
|
||||
"tools": ["bash"],
|
||||
},
|
||||
{
|
||||
"id": "web-research",
|
||||
"name": "Web Research",
|
||||
"description": "Search the web for information, documentation, and resources. Helps find answers and learn about technologies.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Search the web for information about {topic} and summarize the key findings.",
|
||||
"Find official documentation or reliable resources about {topic}.",
|
||||
"Look up the latest news or developments in {topic}.",
|
||||
],
|
||||
"tools": ["search_brave_web_search", "websearch_web_search_exa", "webfetch"],
|
||||
},
|
||||
{
|
||||
"id": "file-management",
|
||||
"name": "File Management",
|
||||
"description": "Helps with file operations like creating, editing, organizing, and managing project files and directories.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Create a new file at {path} with the following content: {content}",
|
||||
"Organize the files in the project structure and suggest improvements.",
|
||||
"Find all files related to {topic} or matching {pattern}.",
|
||||
],
|
||||
"tools": ["read", "write", "glob", "bash"],
|
||||
},
|
||||
{
|
||||
"id": "task-planning",
|
||||
"name": "Task Planning",
|
||||
"description": "Helps break down complex tasks into smaller steps, create implementation plans, and track progress.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Break down this task into smaller, manageable steps: {task}",
|
||||
"Create an implementation plan for building {feature} with clear phases.",
|
||||
"Review the current progress and suggest next steps for completing {goal}.",
|
||||
],
|
||||
"tools": ["todowrite", "read", "write"],
|
||||
},
|
||||
]
|
||||
12
backend/app/agents/skills/loaders/__init__.py
Normal file
12
backend/app/agents/skills/loaders/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Skills 加载器包"""
|
||||
|
||||
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
|
||||
from app.agents.skills.loaders.mcp_loader import MCPSkillLoader, get_mcp_skill_loader
|
||||
|
||||
__all__ = [
|
||||
"LocalSkillLoader",
|
||||
"PluginSkillLoader",
|
||||
"MCPSkillLoader",
|
||||
"get_mcp_skill_loader",
|
||||
]
|
||||
100
backend/app/agents/skills/loaders/local_loader.py
Normal file
100
backend/app/agents/skills/loaders/local_loader.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""本地 Skills 加载器 - Phase 9.2"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class LocalSkillLoader:
|
||||
"""本地 Skills 加载器
|
||||
|
||||
从 skills_dir 目录加载 SKILL.md 文件。
|
||||
"""
|
||||
|
||||
def __init__(self, skills_dir: str):
|
||||
self.skills_dir = skills_dir
|
||||
|
||||
def load_all(self) -> list[SkillMetadata]:
|
||||
"""加载所有本地 Skills
|
||||
|
||||
Returns:
|
||||
Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
if not os.path.exists(self.skills_dir):
|
||||
return skills
|
||||
|
||||
for root, dirs, files in os.walk(self.skills_dir):
|
||||
# 跳过隐藏目录
|
||||
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
||||
|
||||
if "SKILL.md" in files:
|
||||
skill = self._load_skill_from_dir(root)
|
||||
if skill:
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
|
||||
def _load_skill_from_dir(self, skill_dir: str) -> SkillMetadata | None:
|
||||
"""从目录加载 Skill
|
||||
|
||||
Args:
|
||||
skill_dir: Skill 目录
|
||||
|
||||
Returns:
|
||||
Skill 元数据
|
||||
"""
|
||||
skill_path = os.path.join(skill_dir, "SKILL.md")
|
||||
|
||||
try:
|
||||
with open(skill_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 解析 frontmatter
|
||||
metadata = self._parse_frontmatter(content)
|
||||
|
||||
# 获取 Skill 名称(目录名)
|
||||
name = os.path.basename(skill_dir)
|
||||
|
||||
return SkillMetadata(
|
||||
name=metadata.get("name", name),
|
||||
description=metadata.get("description", ""),
|
||||
version=metadata.get("version", "1.0.0"),
|
||||
author=metadata.get("author", ""),
|
||||
tags=metadata.get("tags", []),
|
||||
triggers=metadata.get("triggers", []),
|
||||
content=content,
|
||||
source="local",
|
||||
source_id=skill_dir,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _parse_frontmatter(self, content: str) -> dict[str, Any]:
|
||||
"""解析 frontmatter"""
|
||||
metadata = {}
|
||||
|
||||
# 匹配 --- 包裹的 frontmatter
|
||||
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
||||
if match:
|
||||
frontmatter = match.group(1)
|
||||
|
||||
for line in frontmatter.split("\n"):
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# 处理列表
|
||||
if value.startswith("[") and value.endswith("]"):
|
||||
value = [v.strip().strip('"').strip("'") for v in value[1:-1].split(",")]
|
||||
elif value.lower() in ("true", "false"):
|
||||
value = value.lower() == "true"
|
||||
|
||||
metadata[key] = value
|
||||
|
||||
return metadata
|
||||
169
backend/app/agents/skills/loaders/mcp_loader.py
Normal file
169
backend/app/agents/skills/loaders/mcp_loader.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""MCP Skill 加载器 - Phase 9.2
|
||||
|
||||
从 MCP (Model Context Protocol) 服务器发现和加载 Skills。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class MCPSkillLoader:
|
||||
"""MCP Skill 加载器
|
||||
|
||||
从 MCP 服务器发现可用的 Skills。
|
||||
"""
|
||||
|
||||
def __init__(self, mcp_servers: list[dict[str, Any]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
mcp_servers: MCP 服务器列表,每项包含 name, command, env 等
|
||||
"""
|
||||
self.mcp_servers = mcp_servers or []
|
||||
self._discovered_skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
def discover_skills(self) -> list[SkillMetadata]:
|
||||
"""从所有配置的 MCP 服务器发现 Skills
|
||||
|
||||
Returns:
|
||||
发现的 Skill 列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
for server in self.mcp_servers:
|
||||
server_skills = self._discover_from_server(server)
|
||||
skills.extend(server_skills)
|
||||
|
||||
return skills
|
||||
|
||||
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
|
||||
"""从单个 MCP 服务器发现 Skills
|
||||
|
||||
Args:
|
||||
server: 服务器配置
|
||||
|
||||
Returns:
|
||||
Skill 列表
|
||||
"""
|
||||
skills = []
|
||||
server_name = server.get("name", "unknown")
|
||||
|
||||
# 模拟从 MCP 服务器获取工具列表
|
||||
# 实际实现时,这里会调用 MCP 服务器的 list_tools 接口
|
||||
try:
|
||||
tools = self._call_mcp_list_tools(server)
|
||||
for tool in tools:
|
||||
skill = self._tool_to_skill(tool, server_name)
|
||||
if skill:
|
||||
skills.append(skill)
|
||||
self._discovered_skills[skill.name] = skill
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return skills
|
||||
|
||||
def _call_mcp_list_tools(self, server: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""调用 MCP 服务器的 list_tools 接口
|
||||
|
||||
Args:
|
||||
server: 服务器配置
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
# TODO: 实现实际的 MCP 协议调用
|
||||
# 目前返回空列表,实际使用时需要实现 MCP 客户端
|
||||
return []
|
||||
|
||||
def _tool_to_skill(self, tool: dict[str, Any], server: str) -> SkillMetadata | None:
|
||||
"""将 MCP 工具转换为 Skill
|
||||
|
||||
Args:
|
||||
tool: MCP 工具定义
|
||||
server: 服务器名
|
||||
|
||||
Returns:
|
||||
Skill 元数据或 None
|
||||
"""
|
||||
tool_name = tool.get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
|
||||
return SkillMetadata(
|
||||
id=f"mcp_{server}_{tool_name}",
|
||||
name=f"{server}:{tool_name}",
|
||||
description=tool.get("description", f"MCP tool: {tool_name}"),
|
||||
version="1.0.0",
|
||||
content=self._generate_skill_content(tool),
|
||||
triggers=[f"@{server}", f"/{tool_name}"],
|
||||
tools=[tool_name],
|
||||
tags=["mcp", server],
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
def _generate_skill_content(self, tool: dict[str, Any]) -> str:
|
||||
"""生成 Skill 内容
|
||||
|
||||
Args:
|
||||
tool: MCP 工具定义
|
||||
|
||||
Returns:
|
||||
Skill 内容字符串
|
||||
"""
|
||||
name = tool.get("name", "unknown")
|
||||
description = tool.get("description", "No description")
|
||||
input_schema = tool.get("inputSchema", {})
|
||||
|
||||
content = f"""# MCP Tool: {name}
|
||||
|
||||
**Description**: {description}
|
||||
|
||||
**Server**: {tool.get("server", "unknown")}
|
||||
|
||||
**Input Schema**:
|
||||
```json
|
||||
{input_schema}
|
||||
```
|
||||
|
||||
**Usage**:
|
||||
Use the `/{name}` command or `@{tool.get("server", "server")}` to invoke this tool.
|
||||
|
||||
**Examples**:
|
||||
```
|
||||
/{name} arg1=value1 arg2=value2
|
||||
@{tool.get("server", "server")} {name} --arg1 value1
|
||||
```
|
||||
"""
|
||||
return content
|
||||
|
||||
def get_skill(self, name: str) -> SkillMetadata | None:
|
||||
"""获取已发现的 Skill
|
||||
|
||||
Args:
|
||||
name: Skill 名称
|
||||
|
||||
Returns:
|
||||
Skill 元数据或 None
|
||||
"""
|
||||
return self._discovered_skills.get(name)
|
||||
|
||||
def list_skills(self) -> list[SkillMetadata]:
|
||||
"""列出所有已发现的 Skills
|
||||
|
||||
Returns:
|
||||
Skill 列表
|
||||
"""
|
||||
return list(self._discovered_skills.values())
|
||||
|
||||
|
||||
# 全局加载器
|
||||
_loader: MCPSkillLoader | None = None
|
||||
|
||||
|
||||
def get_mcp_skill_loader() -> MCPSkillLoader:
|
||||
"""获取全局 MCP Skill 加载器"""
|
||||
global _loader
|
||||
if _loader is None:
|
||||
_loader = MCPSkillLoader()
|
||||
return _loader
|
||||
53
backend/app/agents/skills/loaders/plugin_loader.py
Normal file
53
backend/app/agents/skills/loaders/plugin_loader.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""插件 Skills 加载器 - Phase 9.2"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.plugins.manager import get_plugin_manager
|
||||
|
||||
|
||||
class PluginSkillLoader:
|
||||
"""插件 Skills 加载器
|
||||
|
||||
从已安装的插件中加载 Skills。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_manager = get_plugin_manager()
|
||||
|
||||
def load_all(self) -> list[SkillMetadata]:
|
||||
"""从所有已启用的插件加载 Skills
|
||||
|
||||
Returns:
|
||||
Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
for plugin in self.plugin_manager.list_plugins():
|
||||
if not self.plugin_manager.is_enabled(plugin.id):
|
||||
continue
|
||||
|
||||
# 从插件加载 Skills
|
||||
plugin_skills = self._load_from_plugin(plugin)
|
||||
skills.extend(plugin_skills)
|
||||
|
||||
return skills
|
||||
|
||||
def _load_from_plugin(self, plugin: Any) -> list[SkillMetadata]:
|
||||
"""从单个插件加载 Skills"""
|
||||
skills = []
|
||||
|
||||
for skill_name in plugin.skills:
|
||||
skill = SkillMetadata(
|
||||
name=f"{plugin.id}/{skill_name}",
|
||||
description=f"Skill from plugin: {plugin.name}",
|
||||
version=plugin.version,
|
||||
author=plugin.author,
|
||||
tags=["plugin", plugin.id],
|
||||
content=f"# {skill_name}\n\nFrom plugin: {plugin.name}",
|
||||
source="plugin",
|
||||
source_id=plugin.id,
|
||||
)
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
100
backend/app/agents/skills/mcp_builder.py
Normal file
100
backend/app/agents/skills/mcp_builder.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""MCP Skill Builder - Phase 9.3"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class MCPSkillBuilder:
|
||||
"""MCP Skill Builder
|
||||
|
||||
从 MCP 服务器发现和构建 Skills。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
def discover_skills_from_mcp(self, mcp_servers: list[dict[str, Any]]) -> list[SkillMetadata]:
|
||||
"""从 MCP 服务器发现 Skills
|
||||
|
||||
Args:
|
||||
mcp_servers: MCP 服务器配置列表
|
||||
|
||||
Returns:
|
||||
发现的 Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
for server in mcp_servers:
|
||||
server_skills = self._discover_from_server(server)
|
||||
skills.extend(server_skills)
|
||||
|
||||
return skills
|
||||
|
||||
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
|
||||
"""从单个 MCP 服务器发现 Skills"""
|
||||
skills = []
|
||||
server_name = server.get("name", "unknown")
|
||||
tools = server.get("tools", [])
|
||||
|
||||
# 按工具分组
|
||||
tool_groups: dict[str, list[str]] = {}
|
||||
for tool in tools:
|
||||
group = tool.get("group", "default")
|
||||
if group not in tool_groups:
|
||||
tool_groups[group] = []
|
||||
tool_groups[group].append(tool)
|
||||
|
||||
# 为每个组创建一个 Skill
|
||||
for group_name, group_tools in tool_groups.items():
|
||||
skill = self._tool_to_skill(group_name, group_tools, server_name)
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
|
||||
def _tool_to_skill(self, group: str, tools: list[dict[str, Any]], server: str) -> SkillMetadata:
|
||||
"""将 MCP 工具转换为 Skill"""
|
||||
tool_summaries = []
|
||||
for tool in tools:
|
||||
name = tool.get("name", "unknown")
|
||||
description = tool.get("description", "")
|
||||
input_schema = tool.get("inputSchema", {})
|
||||
|
||||
tool_summaries.append(f"### {name}\n{description}\n\nInput: {input_schema}")
|
||||
|
||||
content = f"""# MCP Skill: {group}
|
||||
|
||||
来自 MCP 服务器: {server}
|
||||
|
||||
## 工具列表
|
||||
|
||||
{chr(10).join(tool_summaries)}
|
||||
|
||||
## 使用说明
|
||||
|
||||
使用这些工具前请确保理解每个工具的输入输出格式。
|
||||
"""
|
||||
|
||||
return SkillMetadata(
|
||||
name=f"mcp-{server}-{group}",
|
||||
description=f"MCP skill from {server}: {group}",
|
||||
version="1.0.0",
|
||||
tags=["mcp", server, group],
|
||||
triggers=[group, server],
|
||||
content=content,
|
||||
source="mcp",
|
||||
source_id=f"{server}:{group}",
|
||||
)
|
||||
|
||||
def _group_to_skill(self, group: str, tools: list[str], server: str) -> SkillMetadata:
|
||||
"""将 MCP 工具组转换为 Skill"""
|
||||
return SkillMetadata(
|
||||
name=f"mcp-{server}-{group}",
|
||||
description=f"MCP skill from {server}: {group}",
|
||||
version="1.0.0",
|
||||
tags=["mcp", server, group],
|
||||
triggers=[group, server],
|
||||
content=f"# {group}\n\nTools: {', '.join(tools)}",
|
||||
source="mcp",
|
||||
source_id=f"{server}:{group}",
|
||||
)
|
||||
42
backend/app/agents/skills/metadata.py
Normal file
42
backend/app/agents/skills/metadata.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Skill 元数据定义 - Phase 9.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillMetadata:
|
||||
"""Skill 元数据"""
|
||||
|
||||
id: str = "" # Skill ID
|
||||
name: str = "" # Skill 名称
|
||||
description: str = "" # 描述
|
||||
version: str = "1.0.0" # 版本
|
||||
author: str = "" # 作者
|
||||
tags: list[str] = field(default_factory=list) # 标签
|
||||
triggers: list[str] = field(default_factory=list) # 触发关键词
|
||||
content: str = "" # Skill 内容(markdown)
|
||||
source: str = "local" # 来源:local, plugin, mcp, bundled
|
||||
source_id: str = "" # 来源 ID
|
||||
enabled: bool = True # 是否启用
|
||||
tools: list[str] = field(default_factory=list) # 关联的工具
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"version": self.version,
|
||||
"author": self.author,
|
||||
"tags": self.tags,
|
||||
"triggers": self.triggers,
|
||||
"content": self.content,
|
||||
"source": self.source,
|
||||
"source_id": self.source_id,
|
||||
"enabled": self.enabled,
|
||||
"tools": self.tools,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SkillMetadata":
|
||||
return cls(**data)
|
||||
133
backend/app/agents/skills/registry.py
Normal file
133
backend/app/agents/skills/registry.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Skills 注册表 - Phase 9.1"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||
|
||||
|
||||
class SkillRegistry:
|
||||
"""Skills 注册表
|
||||
|
||||
管理所有 Skills 的注册、发现和加载。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, SkillMetadata] = {}
|
||||
self._loaders: list[Any] = []
|
||||
|
||||
def load_all(self, skills_dir: str | None = None) -> int:
|
||||
"""加载所有 Skills
|
||||
|
||||
Args:
|
||||
skills_dir: Skills 目录,None 则使用默认目录
|
||||
|
||||
Returns:
|
||||
加载的 Skill 数量
|
||||
"""
|
||||
if skills_dir is None:
|
||||
skills_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", ".claude", "skills"
|
||||
)
|
||||
|
||||
count = 0
|
||||
|
||||
# 本地加载器
|
||||
local_loader = LocalSkillLoader(skills_dir)
|
||||
local_skills = local_loader.load_all()
|
||||
for skill in local_skills:
|
||||
self.register(skill)
|
||||
count += 1
|
||||
|
||||
# 插件加载器
|
||||
for loader in self._loaders:
|
||||
try:
|
||||
external_skills = loader.load_all()
|
||||
for skill in external_skills:
|
||||
self.register(skill)
|
||||
count += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return count
|
||||
|
||||
def register(self, skill: SkillMetadata) -> None:
|
||||
"""注册 Skill"""
|
||||
self._skills[skill.name] = skill
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销 Skill"""
|
||||
if name in self._skills:
|
||||
del self._skills[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_skill(self, name: str) -> SkillMetadata | None:
|
||||
"""获取 Skill"""
|
||||
return self._skills.get(name)
|
||||
|
||||
def search(self, query: str) -> list[SkillMetadata]:
|
||||
"""搜索 Skills
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
|
||||
Returns:
|
||||
匹配的 Skills 列表
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
results = []
|
||||
|
||||
for skill in self._skills.values():
|
||||
if not skill.enabled:
|
||||
continue
|
||||
|
||||
# 匹配名称、描述、标签
|
||||
if (
|
||||
query_lower in skill.name.lower()
|
||||
or query_lower in skill.description.lower()
|
||||
or any(query_lower in tag.lower() for tag in skill.tags)
|
||||
or any(query_lower in trigger.lower() for trigger in skill.triggers)
|
||||
):
|
||||
results.append(skill)
|
||||
|
||||
return results
|
||||
|
||||
def get_skill_context(self, names: list[str]) -> str:
|
||||
"""获取 Skill 上下文
|
||||
|
||||
Args:
|
||||
names: Skill 名称列表
|
||||
|
||||
Returns:
|
||||
拼接的 Skill 内容
|
||||
"""
|
||||
contexts = []
|
||||
|
||||
for name in names:
|
||||
skill = self._skills.get(name)
|
||||
if skill and skill.enabled:
|
||||
contexts.append(f"# {skill.name}\n\n{skill.content}")
|
||||
|
||||
return "\n\n---\n\n".join(contexts)
|
||||
|
||||
def add_loader(self, loader: Any) -> None:
|
||||
"""添加加载器"""
|
||||
self._loaders.append(loader)
|
||||
|
||||
def list_all(self) -> list[SkillMetadata]:
|
||||
"""列出所有 Skills"""
|
||||
return list(self._skills.values())
|
||||
|
||||
|
||||
# 全局单例
|
||||
_registry: SkillRegistry | None = None
|
||||
|
||||
|
||||
def get_skill_registry() -> SkillRegistry:
|
||||
"""获取全局 Skills 注册表"""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = SkillRegistry()
|
||||
return _registry
|
||||
140
backend/app/agents/skills/trigger.py
Normal file
140
backend/app/agents/skills/trigger.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Skill 触发检测器 - Phase 9.5
|
||||
|
||||
检测消息中的 Skill 触发条件。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class SkillTriggerDetector:
|
||||
"""Skill 触发检测器
|
||||
|
||||
检测用户消息中是否触发了某个 Skill。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
def register_skill(self, skill: SkillMetadata) -> None:
|
||||
"""注册 Skill
|
||||
|
||||
Args:
|
||||
skill: Skill 元数据
|
||||
"""
|
||||
self._skills[skill.name] = skill
|
||||
|
||||
def unregister_skill(self, name: str) -> bool:
|
||||
"""注销 Skill
|
||||
|
||||
Args:
|
||||
name: Skill 名称
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if name in self._skills:
|
||||
del self._skills[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def detect_triggered_skills(self, message: str) -> list[str]:
|
||||
"""检测触发的 Skills
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
触发的 Skill 名称列表
|
||||
"""
|
||||
triggered = []
|
||||
message_lower = message.lower()
|
||||
|
||||
for skill in self._skills.values():
|
||||
if not skill.enabled:
|
||||
continue
|
||||
|
||||
if self._matches_triggers(message, message_lower, skill):
|
||||
triggered.append(skill.name)
|
||||
|
||||
return triggered
|
||||
|
||||
def _matches_triggers(self, message: str, message_lower: str, skill: SkillMetadata) -> bool:
|
||||
"""检查消息是否匹配 Skill 触发条件
|
||||
|
||||
Args:
|
||||
message: 原始消息
|
||||
message_lower: 小写消息
|
||||
skill: Skill 元数据
|
||||
|
||||
Returns:
|
||||
是否匹配
|
||||
"""
|
||||
for trigger in skill.triggers:
|
||||
trigger_lower = trigger.lower()
|
||||
|
||||
# 前缀匹配,如 "/code" 或 "@git"
|
||||
if trigger_lower.startswith("/") or trigger_lower.startswith("@"):
|
||||
if message_lower.startswith(trigger_lower):
|
||||
return True
|
||||
|
||||
# 命令格式,如 "//analyze"
|
||||
if trigger_lower.startswith("//"):
|
||||
pattern = trigger_lower[2:]
|
||||
if re.search(rf"\b{re.escape(pattern)}\b", message_lower):
|
||||
return True
|
||||
|
||||
# 关键词匹配
|
||||
if trigger_lower in message_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_skill_prompt(self, skill_name: str) -> str | None:
|
||||
"""获取 Skill 的提示词
|
||||
|
||||
Args:
|
||||
skill_name: Skill 名称
|
||||
|
||||
Returns:
|
||||
Skill 内容或 None
|
||||
"""
|
||||
skill = self._skills.get(skill_name)
|
||||
if skill:
|
||||
return skill.content
|
||||
return None
|
||||
|
||||
def get_triggered_skill_context(self, message: str) -> str:
|
||||
"""获取触发的 Skills 上下文
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
拼接的 Skill 上下文
|
||||
"""
|
||||
triggered = self.detect_triggered_skills(message)
|
||||
if not triggered:
|
||||
return ""
|
||||
|
||||
contexts = []
|
||||
for skill_name in triggered:
|
||||
skill = self._skills.get(skill_name)
|
||||
if skill:
|
||||
contexts.append(f"# {skill.name}\n\n{skill.content}")
|
||||
|
||||
return "\n\n---\n\n".join(contexts)
|
||||
|
||||
|
||||
# 全局检测器
|
||||
_detector: SkillTriggerDetector | None = None
|
||||
|
||||
|
||||
def get_skill_trigger_detector() -> SkillTriggerDetector:
|
||||
"""获取全局 Skill 触发检测器"""
|
||||
global _detector
|
||||
if _detector is None:
|
||||
_detector = SkillTriggerDetector()
|
||||
return _detector
|
||||
@@ -1,10 +1,21 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict, Annotated, Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from app.agents.schemas.event import AgentEvent
|
||||
from app.agents.schemas.message import AgentMessage
|
||||
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult, VerificationStatus
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
AgentPhase = Literal[
|
||||
"phase_0_bootstrap",
|
||||
"phase_1_routing",
|
||||
"phase_2_controlled_collaboration",
|
||||
"phase_3_dynamic_collaboration",
|
||||
"phase_4_visibility_and_verification",
|
||||
]
|
||||
|
||||
|
||||
class AgentRole(str, Enum):
|
||||
MASTER = "master"
|
||||
@@ -22,41 +33,113 @@ class ConversationTurn:
|
||||
model: str | None = None
|
||||
|
||||
|
||||
def turn_to_message(turn: ConversationTurn) -> BaseMessage:
|
||||
if turn.role == "user":
|
||||
return HumanMessage(content=turn.content)
|
||||
return AIMessage(content=turn.content)
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
# Core message history with add_messages reducer
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
|
||||
# Session identifiers
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
parent_conversation_id: str | None
|
||||
thread_id: str | None
|
||||
last_message_id: str | None
|
||||
message_sequence: int
|
||||
agent_id: str | None
|
||||
parent_agent_id: str | None
|
||||
root_agent_id: str | None
|
||||
collaboration_depth: int
|
||||
spawned_agent_ids: list[str]
|
||||
|
||||
# Agent routing state
|
||||
execution_mode: Literal["direct", "collaboration", "delegated", "verified"]
|
||||
current_agent: str | None
|
||||
next_step: str | None # For explicit graph routing
|
||||
|
||||
# Traceability
|
||||
next_step: str | None
|
||||
active_agents: list[AgentRole]
|
||||
current_sub_commander: str | None
|
||||
active_sub_commanders: list[str]
|
||||
sub_commander_trace: list[dict[str, Any]]
|
||||
agent_trace: list[str]
|
||||
|
||||
# Task & Entity Tracking (Business Logic)
|
||||
pending_tasks: list[dict]
|
||||
completed_tasks: list[dict]
|
||||
created_entities: list[dict]
|
||||
event_trace: list[AgentEvent | dict[str, Any]]
|
||||
message_trace: list[AgentMessage | dict[str, Any]]
|
||||
|
||||
pending_tasks: list[dict[str, Any]]
|
||||
completed_tasks: list[dict[str, Any]]
|
||||
active_tasks: list[AgentTask | dict[str, Any]]
|
||||
task_results: list[TaskResult | dict[str, Any]]
|
||||
task_hierarchy: dict[str, list[str]]
|
||||
interrupted_tasks: list[InterruptRecord | dict[str, Any]]
|
||||
recovery_trace: list[RecoveryRecord | dict[str, Any]]
|
||||
recovery_points: list[dict[str, Any]]
|
||||
tool_calls: list[dict[str, Any]]
|
||||
last_tool_result: str | None
|
||||
action_results: list[dict[str, Any]]
|
||||
created_entities: list[dict[str, Any]]
|
||||
tool_outcomes: list[dict[str, Any]]
|
||||
task_result_summary: dict[str, Any] | None
|
||||
verifier_hints: dict[str, Any] | None
|
||||
|
||||
verification_status: VerificationStatus | None
|
||||
verification_summary: str | None
|
||||
verification_evidence: list[dict[str, Any]]
|
||||
isolation_mode: str
|
||||
isolation_id: str | None
|
||||
isolation_workspace_path: str | None
|
||||
isolation_parent_conversation_id: str | None
|
||||
isolation_metadata: dict[str, Any]
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
estimated_cost: float | None
|
||||
budget_warning: bool
|
||||
cost_by_agent: dict[str, dict[str, Any]]
|
||||
cost_thresholds: dict[str, Any]
|
||||
budget_state: CollaborationBudget | dict[str, Any] | None
|
||||
collaboration_budget_history: list[CollaborationBudget | dict[str, Any]]
|
||||
current_phase: AgentPhase
|
||||
phase_history: list[dict[str, Any]]
|
||||
current_checkpoint: str | None
|
||||
checkpoint_history: list[dict[str, Any]]
|
||||
|
||||
tool_strategy_used: str | None
|
||||
tool_round_count: int
|
||||
max_tool_rounds: int
|
||||
retry_count: int
|
||||
max_retries: int
|
||||
iteration_count: int
|
||||
max_iterations: int
|
||||
routing_hops: int
|
||||
max_routing_hops: int
|
||||
terminated_due_to_loop_guard: bool
|
||||
retrieval_trace: list[dict[str, Any]]
|
||||
stop_reason: str | None
|
||||
|
||||
clarification_needed: bool
|
||||
clarification_question: str | None
|
||||
fallback_parse_error: str | None
|
||||
should_respond: bool
|
||||
|
||||
# Context summaries (for long-term or cross-agent context)
|
||||
knowledge_context: str | None
|
||||
graph_context: str | None
|
||||
schedule_context_summary: str | None
|
||||
plan: str | None
|
||||
plan_steps: list[dict[str, Any]]
|
||||
analysis_report: str | None
|
||||
|
||||
# Output control
|
||||
final_response: str | None
|
||||
|
||||
# Memory & Environment
|
||||
|
||||
memory_context: str | None
|
||||
current_datetime_context: str | None
|
||||
|
||||
# Configuration
|
||||
user_llm_config: dict | None
|
||||
provider_capabilities: dict | None
|
||||
current_datetime_reference: dict[str, str] | None
|
||||
|
||||
turn_context: dict[str, Any] | None
|
||||
routing_decision: dict[str, Any] | None
|
||||
continuity_state: dict[str, Any] | None
|
||||
pending_action: dict[str, Any] | None
|
||||
last_completed_action: dict[str, Any] | None
|
||||
clarification_context: dict[str, Any] | None
|
||||
|
||||
user_llm_config: dict[str, Any] | None
|
||||
provider_capabilities: dict[str, Any] | None
|
||||
|
||||
|
||||
def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
@@ -64,18 +147,103 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
messages=[],
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
parent_conversation_id=None,
|
||||
thread_id=None,
|
||||
last_message_id=None,
|
||||
message_sequence=0,
|
||||
agent_id=AgentRole.MASTER.value,
|
||||
parent_agent_id=None,
|
||||
root_agent_id=AgentRole.MASTER.value,
|
||||
collaboration_depth=0,
|
||||
spawned_agent_ids=[],
|
||||
execution_mode="direct",
|
||||
current_agent=AgentRole.MASTER.value,
|
||||
next_step=None,
|
||||
active_agents=[AgentRole.MASTER],
|
||||
current_sub_commander=None,
|
||||
active_sub_commanders=[],
|
||||
sub_commander_trace=[],
|
||||
agent_trace=[AgentRole.MASTER.value],
|
||||
event_trace=[],
|
||||
message_trace=[],
|
||||
pending_tasks=[],
|
||||
completed_tasks=[],
|
||||
active_tasks=[],
|
||||
task_results=[],
|
||||
task_hierarchy={},
|
||||
interrupted_tasks=[],
|
||||
recovery_trace=[],
|
||||
recovery_points=[],
|
||||
tool_calls=[],
|
||||
last_tool_result=None,
|
||||
action_results=[],
|
||||
created_entities=[],
|
||||
tool_outcomes=[],
|
||||
task_result_summary=None,
|
||||
verifier_hints=None,
|
||||
verification_status=None,
|
||||
verification_summary=None,
|
||||
verification_evidence=[],
|
||||
isolation_mode="none",
|
||||
isolation_id=None,
|
||||
isolation_workspace_path=None,
|
||||
isolation_parent_conversation_id=None,
|
||||
isolation_metadata={},
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
estimated_cost=None,
|
||||
budget_warning=False,
|
||||
cost_by_agent={},
|
||||
cost_thresholds={},
|
||||
budget_state=None,
|
||||
collaboration_budget_history=[],
|
||||
current_phase="phase_0_bootstrap",
|
||||
phase_history=[
|
||||
{
|
||||
"phase": "phase_0_bootstrap",
|
||||
"reason": "initial_state_created",
|
||||
}
|
||||
],
|
||||
current_checkpoint="bootstrap.initialized",
|
||||
checkpoint_history=[
|
||||
{
|
||||
"checkpoint": "bootstrap.initialized",
|
||||
"phase": "phase_0_bootstrap",
|
||||
"reason": "initial_state_created",
|
||||
}
|
||||
],
|
||||
tool_strategy_used=None,
|
||||
tool_round_count=0,
|
||||
max_tool_rounds=2,
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
iteration_count=0,
|
||||
max_iterations=3,
|
||||
routing_hops=0,
|
||||
max_routing_hops=2,
|
||||
terminated_due_to_loop_guard=False,
|
||||
retrieval_trace=[],
|
||||
stop_reason=None,
|
||||
clarification_needed=False,
|
||||
clarification_question=None,
|
||||
fallback_parse_error=None,
|
||||
should_respond=True,
|
||||
knowledge_context=None,
|
||||
graph_context=None,
|
||||
schedule_context_summary=None,
|
||||
plan=None,
|
||||
plan_steps=[],
|
||||
analysis_report=None,
|
||||
final_response=None,
|
||||
memory_context=None,
|
||||
current_datetime_context=None,
|
||||
current_datetime_reference=None,
|
||||
turn_context=None,
|
||||
routing_decision=None,
|
||||
continuity_state=None,
|
||||
pending_action=None,
|
||||
last_completed_action=None,
|
||||
clarification_context=None,
|
||||
user_llm_config=None,
|
||||
provider_capabilities=None,
|
||||
)
|
||||
|
||||
13
backend/app/agents/team/__init__.py
Normal file
13
backend/app/agents/team/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Team 多 Agent 协作"""
|
||||
|
||||
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
|
||||
from app.agents.team.member import TeamMember, MemberStatus, MemberTask
|
||||
|
||||
__all__ = [
|
||||
"TeamLeader",
|
||||
"TeamTask",
|
||||
"TaskStatus",
|
||||
"TeamMember",
|
||||
"MemberStatus",
|
||||
"MemberTask",
|
||||
]
|
||||
121
backend/app/agents/team/leader.py
Normal file
121
backend/app/agents/team/leader.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Team 多 Agent 协作 - Phase 10.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeamTask:
|
||||
"""团队任务"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
assignee: str | None = None
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class TeamLeader:
|
||||
"""团队领导者
|
||||
|
||||
协调多个 Agent 成员执行任务。
|
||||
"""
|
||||
|
||||
def __init__(self, team_id: str, members: list[str]):
|
||||
"""
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
members: 成员 ID 列表
|
||||
"""
|
||||
self.team_id = team_id
|
||||
self.members = members
|
||||
self._tasks: dict[str, TeamTask] = {}
|
||||
|
||||
def create_task(self, description: str) -> str:
|
||||
"""创建任务
|
||||
|
||||
Args:
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
任务 ID
|
||||
"""
|
||||
import uuid
|
||||
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
self._tasks[task_id] = TeamTask(
|
||||
id=task_id,
|
||||
description=description,
|
||||
)
|
||||
return task_id
|
||||
|
||||
def assign_task(self, task_id: str, member: str) -> bool:
|
||||
"""分配任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
member: 成员 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
if member not in self.members:
|
||||
return False
|
||||
|
||||
self._tasks[task_id].assignee = member
|
||||
self._tasks[task_id].status = TaskStatus.IN_PROGRESS
|
||||
return True
|
||||
|
||||
def broadcast_task(self, description: str) -> list[str]:
|
||||
"""广播任务给所有成员
|
||||
|
||||
Args:
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
创建的任务 ID 列表
|
||||
"""
|
||||
task_ids = []
|
||||
for member in self.members:
|
||||
task_id = self.create_task(description)
|
||||
self.assign_task(task_id, member)
|
||||
task_ids.append(task_id)
|
||||
return task_ids
|
||||
|
||||
def collect_results(self) -> dict[str, Any]:
|
||||
"""收集所有任务结果
|
||||
|
||||
Returns:
|
||||
任务 ID -> 结果的映射
|
||||
"""
|
||||
return {
|
||||
task_id: task.result
|
||||
for task_id, task in self._tasks.items()
|
||||
if task.status == TaskStatus.COMPLETED
|
||||
}
|
||||
|
||||
def get_team_status(self) -> dict[str, Any]:
|
||||
"""获取团队状态
|
||||
|
||||
Returns:
|
||||
团队状态摘要
|
||||
"""
|
||||
return {
|
||||
"team_id": self.team_id,
|
||||
"members": self.members,
|
||||
"task_count": len(self._tasks),
|
||||
"completed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.COMPLETED),
|
||||
"failed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.FAILED),
|
||||
}
|
||||
166
backend/app/agents/team/member.py
Normal file
166
backend/app/agents/team/member.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""TeamMember 实现 - Phase 10.1
|
||||
|
||||
团队成员实现,负责执行分配的任务。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MemberStatus(Enum):
|
||||
"""成员状态"""
|
||||
|
||||
IDLE = "idle"
|
||||
BUSY = "busy"
|
||||
OFFLINE = "offline"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemberTask:
|
||||
"""成员任务"""
|
||||
|
||||
task_id: str
|
||||
description: str
|
||||
status: str = "pending" # pending, in_progress, completed, failed
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class TeamMember:
|
||||
"""团队成员
|
||||
|
||||
代表团队中的一个 Agent 成员,负责执行分配的任务。
|
||||
"""
|
||||
|
||||
def __init__(self, member_id: str, name: str, capabilities: list[str] | None = None):
|
||||
"""
|
||||
Args:
|
||||
member_id: 成员 ID
|
||||
name: 成员名称
|
||||
capabilities: 成员能力列表
|
||||
"""
|
||||
self.member_id = member_id
|
||||
self.name = name
|
||||
self.capabilities = capabilities or []
|
||||
self.status = MemberStatus.IDLE
|
||||
self._tasks: dict[str, MemberTask] = {}
|
||||
self._metadata: dict[str, Any] = {}
|
||||
|
||||
def assign_task(self, task_id: str, description: str) -> MemberTask:
|
||||
"""接收任务分配
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
创建的任务对象
|
||||
"""
|
||||
task = MemberTask(task_id=task_id, description=description)
|
||||
self._tasks[task_id] = task
|
||||
self.status = MemberStatus.BUSY
|
||||
return task
|
||||
|
||||
def update_task_status(
|
||||
self, task_id: str, status: str, result: Any = None, error: str | None = None
|
||||
) -> bool:
|
||||
"""更新任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
status: 新状态
|
||||
result: 任务结果
|
||||
error: 错误信息
|
||||
|
||||
Returns:
|
||||
是否更新成功
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
task = self._tasks[task_id]
|
||||
task.status = status
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
|
||||
if status in ("completed", "failed"):
|
||||
self.status = MemberStatus.IDLE
|
||||
|
||||
return True
|
||||
|
||||
def get_task(self, task_id: str) -> MemberTask | None:
|
||||
"""获取任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
任务对象或 None
|
||||
"""
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def get_pending_tasks(self) -> list[MemberTask]:
|
||||
"""获取待处理任务
|
||||
|
||||
Returns:
|
||||
待处理任务列表
|
||||
"""
|
||||
return [t for t in self._tasks.values() if t.status == "pending"]
|
||||
|
||||
def get_active_task(self) -> MemberTask | None:
|
||||
"""获取当前执行中的任务
|
||||
|
||||
Returns:
|
||||
当前任务或 None
|
||||
"""
|
||||
for task in self._tasks.values():
|
||||
if task.status == "in_progress":
|
||||
return task
|
||||
return None
|
||||
|
||||
def get_completed_tasks(self) -> list[MemberTask]:
|
||||
"""获取已完成任务
|
||||
|
||||
Returns:
|
||||
已完成任务列表
|
||||
"""
|
||||
return [t for t in self._tasks.values() if t.status == "completed"]
|
||||
|
||||
def set_metadata(self, key: str, value: Any) -> None:
|
||||
"""设置元数据
|
||||
|
||||
Args:
|
||||
key: 元数据键
|
||||
value: 元数据值
|
||||
"""
|
||||
self._metadata[key] = value
|
||||
|
||||
def get_metadata(self, key: str) -> Any:
|
||||
"""获取元数据
|
||||
|
||||
Args:
|
||||
key: 元数据键
|
||||
|
||||
Returns:
|
||||
元数据值或 None
|
||||
"""
|
||||
return self._metadata.get(key)
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""获取成员状态
|
||||
|
||||
Returns:
|
||||
状态字典
|
||||
"""
|
||||
return {
|
||||
"member_id": self.member_id,
|
||||
"name": self.name,
|
||||
"status": self.status.value,
|
||||
"capabilities": self.capabilities,
|
||||
"task_count": len(self._tasks),
|
||||
"pending_count": len(self.get_pending_tasks()),
|
||||
"active_task": self.get_active_task().__dict__ if self.get_active_task() else None,
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search, web_search,
|
||||
search_knowledge,
|
||||
get_knowledge_graph_context,
|
||||
build_knowledge_graph,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
@@ -13,6 +16,58 @@ from app.agents.tools.schedule import (
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
|
||||
# Phase 6.1: Tool Registry exports
|
||||
from app.agents.tools.registry import (
|
||||
ToolRegistry,
|
||||
get_tool_registry,
|
||||
reset_tool_registry,
|
||||
)
|
||||
from app.agents.tools.manifest import (
|
||||
HookConfig,
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
from app.agents.tools.migration import (
|
||||
migrate_tool,
|
||||
migrate_all_tools,
|
||||
get_tool_executor,
|
||||
BackwardCompatTool,
|
||||
)
|
||||
|
||||
# Phase 6.2: Hook System exports
|
||||
from app.agents.tools.hooks import (
|
||||
HookManager,
|
||||
HookExecutor,
|
||||
HookType,
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
ExecutionContext,
|
||||
get_hook_manager,
|
||||
get_hook_executor,
|
||||
)
|
||||
|
||||
# Phase 6.3: Streaming Executor exports
|
||||
from app.agents.tools.streaming import (
|
||||
StreamingToolExecutor,
|
||||
get_streaming_executor,
|
||||
)
|
||||
|
||||
# Phase 6.4: Builtin Tools exports
|
||||
from app.agents.tools.builtins import (
|
||||
GlobTool,
|
||||
GrepTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
BashTool,
|
||||
PowerShellTool,
|
||||
LSPTools,
|
||||
GitTool,
|
||||
TeamAgentTool,
|
||||
TaskBroadcastTool,
|
||||
)
|
||||
|
||||
TASK_TOOLS = [
|
||||
get_tasks,
|
||||
create_task,
|
||||
|
||||
18
backend/app/agents/tools/async_bridge.py
Normal file
18
backend/app/agents/tools/async_bridge.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def run_async(coro: Any, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
__all__ = ["run_async"]
|
||||
161
backend/app/agents/tools/base.py
Normal file
161
backend/app/agents/tools/base.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""工具基类 - 工具系统重构 Phase 6.1"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseTool(ABC, Generic[T]):
|
||||
"""工具基类
|
||||
|
||||
提供工具的标准接口和默认实现。
|
||||
所有自定义工具应继承此类。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
category: ToolCategory,
|
||||
permission_class: PermissionClass,
|
||||
side_effect_scope: SideEffectScope = SideEffectScope.NONE,
|
||||
requires_confirmation: bool = False,
|
||||
is_streaming: bool = False,
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.category = category
|
||||
self.permission_class = permission_class
|
||||
self.side_effect_scope = side_effect_scope
|
||||
self.requires_confirmation = requires_confirmation
|
||||
self.is_streaming = is_streaming
|
||||
self.tags = tags or []
|
||||
|
||||
def get_manifest(self) -> ToolManifest:
|
||||
"""获取工具元数据
|
||||
|
||||
Returns:
|
||||
工具元数据
|
||||
"""
|
||||
return ToolManifest(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
category=self.category,
|
||||
parameters=self.get_parameters(),
|
||||
return_schema=self.get_return_schema(),
|
||||
permission_class=self.permission_class,
|
||||
side_effect_scope=self.side_effect_scope,
|
||||
requires_confirmation=self.requires_confirmation,
|
||||
is_streaming=self.is_streaming,
|
||||
tags=self.tags,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
"""获取参数 Schema(JSON Schema 格式)
|
||||
|
||||
Returns:
|
||||
参数 schema
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
"""获取返回值 Schema
|
||||
|
||||
Returns:
|
||||
返回值 schema
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> T:
|
||||
"""执行工具
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
pass
|
||||
|
||||
async def execute_safe(self, **kwargs) -> dict[str, Any]:
|
||||
"""安全执行工具,捕获异常
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
包含 success 和 result/error 的字典
|
||||
"""
|
||||
try:
|
||||
result = await self.execute(**kwargs)
|
||||
return {"success": True, "result": result}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self.name!r})>"
|
||||
|
||||
|
||||
class ReadTool(BaseTool):
|
||||
"""只读工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.READ)
|
||||
kwargs.setdefault("permission_class", PermissionClass.READ)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NONE)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class WriteTool(BaseTool):
|
||||
"""写入工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.WRITE)
|
||||
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.LOCAL_STATE)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class DBWriteTool(BaseTool):
|
||||
"""数据库写入工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.DB_WRITE)
|
||||
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.DB_WRITE)
|
||||
kwargs.setdefault("requires_confirmation", True)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class ExternalTool(BaseTool):
|
||||
"""外部工具基类(执行外部命令等)"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.EXTERNAL)
|
||||
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||
kwargs.setdefault("requires_confirmation", True)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class NetworkTool(BaseTool):
|
||||
"""网络工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.NETWORK)
|
||||
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||
super().__init__(**kwargs)
|
||||
43
backend/app/agents/tools/builtins/__init__.py
Normal file
43
backend/app/agents/tools/builtins/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""内置工具集 - Phase 6.4
|
||||
|
||||
新的内置工具,使用 BaseTool 基类。
|
||||
"""
|
||||
|
||||
from app.agents.tools.builtins.file_tools import (
|
||||
GlobTool,
|
||||
GrepTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.system_tools import (
|
||||
BashTool,
|
||||
PowerShellTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.dev_tools import (
|
||||
LSPTools,
|
||||
GitTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.collaboration_tools import (
|
||||
TeamAgentTool,
|
||||
TaskBroadcastTool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# File tools
|
||||
"GlobTool",
|
||||
"GrepTool",
|
||||
"ReadFileTool",
|
||||
"WriteFileTool",
|
||||
# System tools
|
||||
"BashTool",
|
||||
"PowerShellTool",
|
||||
# Dev tools
|
||||
"LSPTools",
|
||||
"GitTool",
|
||||
# Collaboration tools
|
||||
"TeamAgentTool",
|
||||
"TaskBroadcastTool",
|
||||
]
|
||||
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""协作工具 - Phase 6.4"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class TeamAgentTool(WriteTool):
|
||||
"""团队 Agent 通信工具
|
||||
|
||||
用于与其他 Agent 进行消息传递和协作。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="team_agent",
|
||||
description="向团队 Agent 发送消息或请求协作",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
tags=["collaboration", "team", "agent"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "目标 Agent 名称",
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "要发送的消息",
|
||||
},
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["send", "request", "delegate"],
|
||||
"description": "操作类型",
|
||||
},
|
||||
},
|
||||
"required": ["agent_name", "message"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"response": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, agent_name: str, message: str, action: str = "send") -> dict[str, Any]:
|
||||
# 注意:实际实现需要通过 Agent 通信协议
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": True,
|
||||
"response": f"Message '{action}' to agent '{agent_name}': {message}",
|
||||
"agent_name": agent_name,
|
||||
"action": action,
|
||||
}
|
||||
|
||||
|
||||
class TaskBroadcastTool(WriteTool):
|
||||
"""任务广播工具
|
||||
|
||||
向多个 Agent 广播任务。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="task_broadcast",
|
||||
description="向多个 Agent 广播任务",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
tags=["collaboration", "broadcast", "task"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_names": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "目标 Agent 列表",
|
||||
},
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "要广播的任务描述",
|
||||
},
|
||||
"priority": {
|
||||
"type": "string",
|
||||
"enum": ["low", "normal", "high", "urgent"],
|
||||
"description": "任务优先级",
|
||||
},
|
||||
},
|
||||
"required": ["agent_names", "task"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"broadcast_to": {"type": "array", "items": {"type": "string"}},
|
||||
"responses": {"type": "array"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
agent_names: list[str],
|
||||
task: str,
|
||||
priority: str = "normal",
|
||||
) -> dict[str, Any]:
|
||||
# 注意:实际实现需要通过 Agent 通信协议
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": True,
|
||||
"broadcast_to": agent_names,
|
||||
"task": task,
|
||||
"priority": priority,
|
||||
"responses": [f"Acknowledged by {agent}" for agent in agent_names],
|
||||
}
|
||||
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""开发工具 - Phase 6.4"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ReadTool, WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class LSPTools(ReadTool):
|
||||
"""语言服务器协议工具集
|
||||
|
||||
提供代码导航、查找引用等 LSP 功能。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="lsp_tools",
|
||||
description="LSP 代码导航和查找引用",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["development", "lsp", "code"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["goto_definition", "find_references", "document_symbols"],
|
||||
"description": "LSP 操作类型",
|
||||
},
|
||||
"file": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"line": {
|
||||
"type": "integer",
|
||||
"description": "行号(1-based)",
|
||||
},
|
||||
"character": {
|
||||
"type": "integer",
|
||||
"description": "列号(0-based)",
|
||||
},
|
||||
},
|
||||
"required": ["action", "file"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"results": {"type": "array"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
action: str,
|
||||
file: str,
|
||||
line: int = 1,
|
||||
character: int = 0,
|
||||
) -> dict[str, Any]:
|
||||
# 注意:实际 LSP 调用需要通过 lsp-utils 或类似库
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"LSP action '{action}' not fully implemented - requires LSP server integration",
|
||||
"action": action,
|
||||
"file": file,
|
||||
"position": {"line": line, "character": character},
|
||||
}
|
||||
|
||||
|
||||
class GitTool(ReadTool):
|
||||
"""Git 操作工具
|
||||
|
||||
提供常用的 Git 操作。
|
||||
"""
|
||||
|
||||
def __init__(self, repo_path: str = "."):
|
||||
super().__init__(
|
||||
name="git",
|
||||
description="执行 Git 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["development", "git", "version-control"],
|
||||
)
|
||||
self.repo_path = repo_path
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Git 子命令和参数,如 'status' 或 'log --oneline -10'",
|
||||
},
|
||||
"repo_path": {
|
||||
"type": "string",
|
||||
"description": "仓库路径(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, command: str, repo_path: str | None = None) -> dict[str, Any]:
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
|
||||
repo = repo_path or self.repo_path
|
||||
|
||||
# 构建完整的 git 命令
|
||||
if platform.system() == "Windows":
|
||||
full_command = f'git -C "{repo}" {command}'
|
||||
else:
|
||||
full_command = f"git -C '{repo}' {command}"
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
full_command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""文件操作工具 - Phase 6.4"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ExternalTool, ReadTool, WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
)
|
||||
|
||||
|
||||
class GlobTool(ReadTool):
|
||||
"""文件路径匹配工具
|
||||
|
||||
使用 glob 模式查找文件。
|
||||
"""
|
||||
|
||||
def __init__(self, root_dir: str = "."):
|
||||
super().__init__(
|
||||
name="glob",
|
||||
description="使用 glob 模式查找文件路径",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "search", "glob"],
|
||||
)
|
||||
self.root_dir = root_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob 模式,如 **/*.py",
|
||||
},
|
||||
"root_dir": {
|
||||
"type": "string",
|
||||
"description": "搜索根目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
|
||||
async def execute(self, pattern: str, root_dir: str | None = None) -> list[str]:
|
||||
import glob as glob_module
|
||||
|
||||
root = root_dir or self.root_dir
|
||||
return glob_module.glob(pattern, root_dir=root, recursive=True)
|
||||
|
||||
|
||||
class GrepTool(ReadTool):
|
||||
"""文件内容搜索工具
|
||||
|
||||
在文件中搜索匹配的行。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="grep",
|
||||
description="在文件中搜索匹配的文本行",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "search", "text"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "正则表达式模式",
|
||||
},
|
||||
"paths": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "要搜索的文件路径列表",
|
||||
},
|
||||
"case_sensitive": {
|
||||
"type": "boolean",
|
||||
"description": "是否区分大小写",
|
||||
},
|
||||
},
|
||||
"required": ["pattern", "paths"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file": {"type": "string"},
|
||||
"line": {"type": "integer"},
|
||||
"content": {"type": "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, pattern: str, paths: list[str], case_sensitive: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
import re
|
||||
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
regex = re.compile(pattern, flags)
|
||||
results = []
|
||||
|
||||
for path in paths:
|
||||
if not os.path.isfile(path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
if regex.search(line):
|
||||
results.append(
|
||||
{
|
||||
"file": path,
|
||||
"line": line_num,
|
||||
"content": line.rstrip(),
|
||||
}
|
||||
)
|
||||
except (UnicodeDecodeError, PermissionError):
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class ReadFileTool(ReadTool):
|
||||
"""文件读取工具"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="read_file",
|
||||
description="读取文件内容",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "read"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "最大行数",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "起始行号",
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string"},
|
||||
"lines": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, path: str, limit: int | None = None, offset: int = 0) -> dict[str, Any]:
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
total_lines = len(lines)
|
||||
start = max(0, offset)
|
||||
end = len(lines) if limit is None else min(start + limit, len(lines))
|
||||
|
||||
content = "".join(lines[start:end])
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"lines": total_lines,
|
||||
"truncated": limit is not None and end < len(lines),
|
||||
}
|
||||
|
||||
|
||||
class WriteFileTool(WriteTool):
|
||||
"""文件写入工具"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="write_file",
|
||||
description="写入文件内容",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["file", "write"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "文件内容",
|
||||
},
|
||||
"append": {
|
||||
"type": "boolean",
|
||||
"description": "是否追加模式",
|
||||
},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"bytes_written": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, path: str, content: str, append: bool = False) -> dict[str, Any]:
|
||||
mode = "a" if append else "w"
|
||||
|
||||
# 确保目录存在
|
||||
directory = os.path.dirname(path)
|
||||
if directory and not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
with open(path, mode, encoding="utf-8") as f:
|
||||
bytes_written = f.write(content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"bytes_written": bytes_written,
|
||||
}
|
||||
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""系统工具 - Phase 6.4"""
|
||||
|
||||
import asyncio
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ExternalTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class BashTool(ExternalTool):
|
||||
"""Bash 命令执行工具"""
|
||||
|
||||
def __init__(self, working_dir: str = "."):
|
||||
super().__init__(
|
||||
name="bash",
|
||||
description="执行 Bash 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["system", "bash", "shell"],
|
||||
)
|
||||
self.working_dir = working_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的 Bash 命令",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "超时时间(秒)",
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "工作目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
import os
|
||||
|
||||
cwd = working_dir or self.working_dir
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout} seconds",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
|
||||
class PowerShellTool(ExternalTool):
|
||||
"""PowerShell 命令执行工具"""
|
||||
|
||||
def __init__(self, working_dir: str = "."):
|
||||
super().__init__(
|
||||
name="powershell",
|
||||
description="执行 PowerShell 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["system", "powershell", "shell"],
|
||||
)
|
||||
self.working_dir = working_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的 PowerShell 命令",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "超时时间(秒)",
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "工作目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
import platform
|
||||
|
||||
# 检测是否是 Windows 平台
|
||||
is_windows = platform.system() == "Windows"
|
||||
|
||||
if not is_windows:
|
||||
# 非 Windows 平台,可能没有 PowerShell
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": "PowerShell is not available on this platform",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
cwd = working_dir or self.working_dir
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"powershell.exe",
|
||||
"-NoProfile",
|
||||
"-Command",
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout} seconds",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
@@ -4,19 +4,12 @@ from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
@tool
|
||||
|
||||
46
backend/app/agents/tools/hooks/__init__.py
Normal file
46
backend/app/agents/tools/hooks/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Hook 系统 - Phase 6.2"""
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookStage,
|
||||
HookTrigger,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
HookHandler,
|
||||
PreToolHook,
|
||||
PostToolHook,
|
||||
ErrorToolHook,
|
||||
SkipToolHook,
|
||||
)
|
||||
from app.agents.tools.hooks.manager import (
|
||||
HookManager,
|
||||
get_hook_manager,
|
||||
reset_hook_manager,
|
||||
)
|
||||
from app.agents.tools.hooks.executor import (
|
||||
HookExecutor,
|
||||
get_hook_executor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
"HookType",
|
||||
"HookStage",
|
||||
"HookTrigger",
|
||||
"HookDefinition",
|
||||
"HookResult",
|
||||
"ExecutionContext",
|
||||
"HookHandler",
|
||||
"PreToolHook",
|
||||
"PostToolHook",
|
||||
"ErrorToolHook",
|
||||
"SkipToolHook",
|
||||
# Manager
|
||||
"HookManager",
|
||||
"get_hook_manager",
|
||||
"reset_hook_manager",
|
||||
# Executor
|
||||
"HookExecutor",
|
||||
"get_hook_executor",
|
||||
]
|
||||
11
backend/app/agents/tools/hooks/builtins/__init__.py
Normal file
11
backend/app/agents/tools/hooks/builtins/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""内置 Hook 集合 - Phase 7"""
|
||||
|
||||
from app.agents.tools.hooks.builtins.audit_log import AuditLogHook
|
||||
from app.agents.tools.hooks.builtins.dangerous_confirmation import DangerousConfirmationHook
|
||||
from app.agents.tools.hooks.builtins.security_scan import SecurityScanHook
|
||||
|
||||
__all__ = [
|
||||
"AuditLogHook",
|
||||
"DangerousConfirmationHook",
|
||||
"SecurityScanHook",
|
||||
]
|
||||
115
backend/app/agents/tools/hooks/builtins/audit_log.py
Normal file
115
backend/app/agents/tools/hooks/builtins/audit_log.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""审计日志 Hook - Phase 7.2
|
||||
|
||||
记录所有工具调用到审计日志。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
HookType,
|
||||
)
|
||||
from app.agents.tools.manifest import ToolCategory
|
||||
|
||||
|
||||
class AuditLogHook:
|
||||
"""审计日志 Hook
|
||||
|
||||
记录所有工具调用的详细信息,包括:
|
||||
- 调用时间
|
||||
- 工具名称
|
||||
- 输入参数
|
||||
- 执行结果
|
||||
- 执行时长
|
||||
- 用户 ID
|
||||
"""
|
||||
|
||||
def __init__(self, log_path: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
log_path: 日志文件路径,None 则输出到 stdout
|
||||
"""
|
||||
self.log_path = log_path
|
||||
self._logs: list[dict[str, Any]] = []
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""工具执行前记录"""
|
||||
log_entry = {
|
||||
"event": "pre_tool",
|
||||
"tool_name": context.tool_name,
|
||||
"input": context.tool_input,
|
||||
"user_id": context.user_id,
|
||||
"session_id": context.session_id,
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
|
||||
"""工具执行后记录"""
|
||||
log_entry = {
|
||||
"event": "post_tool",
|
||||
"tool_name": context.tool_name,
|
||||
"result": str(result)[:500] if result else None,
|
||||
"duration_ms": (
|
||||
(context.end_time - context.start_time) * 1000
|
||||
if context.start_time and context.end_time
|
||||
else None
|
||||
),
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=result,
|
||||
)
|
||||
|
||||
async def tool_error(self, context: ExecutionContext, error: Exception) -> HookResult:
|
||||
"""工具出错时记录"""
|
||||
log_entry = {
|
||||
"event": "tool_error",
|
||||
"tool_name": context.tool_name,
|
||||
"error": str(error),
|
||||
"error_type": type(error).__name__,
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=False,
|
||||
continue_execution=True,
|
||||
error=str(error),
|
||||
)
|
||||
|
||||
def _write_log(self, entry: dict[str, Any]) -> None:
|
||||
"""写入日志"""
|
||||
import json
|
||||
import datetime
|
||||
|
||||
entry["timestamp"] = datetime.datetime.now().isoformat()
|
||||
|
||||
if self.log_path:
|
||||
try:
|
||||
with open(self.log_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
except Exception:
|
||||
# 日志写入失败不影响主流程
|
||||
pass
|
||||
else:
|
||||
# 输出到 stdout
|
||||
print(f"[AUDIT] {json.dumps(entry, ensure_ascii=False)}")
|
||||
|
||||
def get_logs(self) -> list[dict[str, Any]]:
|
||||
"""获取所有日志"""
|
||||
return self._logs.copy()
|
||||
|
||||
def clear_logs(self) -> None:
|
||||
"""清空日志"""
|
||||
self._logs.clear()
|
||||
@@ -0,0 +1,142 @@
|
||||
"""危险操作确认 Hook - Phase 7.2
|
||||
|
||||
对危险操作要求用户确认。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
)
|
||||
from app.agents.tools.manifest import SideEffectScope
|
||||
|
||||
|
||||
# 危险操作关键词
|
||||
DANGEROUS_PATTERNS = [
|
||||
# 文件操作
|
||||
"delete",
|
||||
"remove",
|
||||
"rm ",
|
||||
"rmdir",
|
||||
"unlink",
|
||||
"format",
|
||||
"truncate",
|
||||
# 系统操作
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"kill",
|
||||
"pkill",
|
||||
"sudo",
|
||||
"chmod",
|
||||
"chown",
|
||||
# 数据操作
|
||||
"drop",
|
||||
"truncate",
|
||||
"delete from",
|
||||
"delete.*where",
|
||||
"insert into.*select",
|
||||
"update.*set",
|
||||
# 网络操作
|
||||
"curl",
|
||||
"wget",
|
||||
"nc ",
|
||||
"netcat",
|
||||
"ssh ",
|
||||
"scp ",
|
||||
"sftp ",
|
||||
# 环境变量
|
||||
"export.*secret",
|
||||
"export.*key",
|
||||
"export.*token",
|
||||
]
|
||||
|
||||
|
||||
class DangerousConfirmationHook:
|
||||
"""危险操作确认 Hook
|
||||
|
||||
检查工具调用是否包含危险操作,如是则要求确认。
|
||||
"""
|
||||
|
||||
def __init__(self, auto_block: bool = False):
|
||||
"""
|
||||
Args:
|
||||
auto_block: True 表示自动拦截危险操作,False 表示仅警告
|
||||
"""
|
||||
self.auto_block = auto_block
|
||||
self._pending_confirmations: dict[str, bool] = {}
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""检查是否为危险操作"""
|
||||
is_dangerous = self._check_dangerous(context.tool_name, context.tool_input)
|
||||
|
||||
if is_dangerous:
|
||||
if self.auto_block:
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=False,
|
||||
continue_execution=False,
|
||||
error=f"危险操作被自动拦截: {context.tool_name}",
|
||||
metadata={"dangerous": True, "auto_blocked": True},
|
||||
)
|
||||
else:
|
||||
# 标记需要确认
|
||||
context.metadata["requires_confirmation"] = True
|
||||
context.metadata["dangerous_operation"] = True
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
metadata={"dangerous": True, "requires_confirmation": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
def _check_dangerous(self, tool_name: str, tool_input: dict[str, Any]) -> bool:
|
||||
"""检查是否为危险操作"""
|
||||
# 检查工具名称
|
||||
dangerous_tools = [
|
||||
"delete",
|
||||
"remove",
|
||||
"drop",
|
||||
"truncate",
|
||||
"kill",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"bash",
|
||||
"powershell",
|
||||
"shell",
|
||||
]
|
||||
|
||||
if tool_name.lower() in dangerous_tools:
|
||||
return True
|
||||
|
||||
# 检查输入参数
|
||||
input_str = str(tool_input).lower()
|
||||
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if pattern.lower() in input_str:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def confirm(self, session_id: str, confirmed: bool) -> None:
|
||||
"""确认危险操作
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
confirmed: True 表示用户确认,False 表示取消
|
||||
"""
|
||||
self._pending_confirmations[session_id] = confirmed
|
||||
|
||||
def is_confirmed(self, session_id: str) -> bool:
|
||||
"""检查是否已确认"""
|
||||
return self._pending_confirmations.get(session_id, False)
|
||||
|
||||
def clear_confirmation(self, session_id: str) -> None:
|
||||
"""清除确认状态"""
|
||||
self._pending_confirmations.pop(session_id, None)
|
||||
183
backend/app/agents/tools/hooks/builtins/security_scan.py
Normal file
183
backend/app/agents/tools/hooks/builtins/security_scan.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""安全扫描 Hook - Phase 7.2
|
||||
|
||||
扫描工具调用和结果中的敏感信息。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
)
|
||||
|
||||
|
||||
# 敏感信息模式
|
||||
SENSITIVE_PATTERNS = {
|
||||
"api_key": [
|
||||
r"api[_-]?key['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
r"apikey['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
],
|
||||
"password": [
|
||||
r"password['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
|
||||
r"passwd['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
|
||||
r"secret['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
],
|
||||
"token": [
|
||||
r"token['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-\.]{20,}",
|
||||
r"bearer\s+[a-zA-Z0-9_\-\.]+",
|
||||
r"ghp_[a-zA-Z0-9]{36}",
|
||||
r"sk-[a-zA-Z0-9]{48}",
|
||||
],
|
||||
"private_key": [
|
||||
r"-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
|
||||
r"-----END (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
|
||||
],
|
||||
"ip_address": [
|
||||
r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
|
||||
],
|
||||
"email": [
|
||||
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class SecurityScanHook:
|
||||
"""安全扫描 Hook
|
||||
|
||||
扫描工具输入和输出中的敏感信息,进行脱敏处理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redact: bool = True,
|
||||
block_on_detect: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
redact: 是否对敏感信息进行脱敏
|
||||
block_on_detect: 检测到敏感信息时是否阻止执行
|
||||
"""
|
||||
self.redact = redact
|
||||
self.block_on_detect = block_on_detect
|
||||
self._compiled_patterns = {
|
||||
name: [re.compile(p, re.IGNORECASE) for p in patterns]
|
||||
for name, patterns in SENSITIVE_PATTERNS.items()
|
||||
}
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""扫描输入参数"""
|
||||
detected = self._scan_dict(context.tool_input)
|
||||
|
||||
if detected:
|
||||
context.metadata["security_detected"] = detected
|
||||
|
||||
if self.block_on_detect:
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=False,
|
||||
continue_execution=False,
|
||||
error=f"检测到敏感信息: {', '.join(detected.keys())}",
|
||||
metadata={"detected": detected, "blocked": True},
|
||||
)
|
||||
|
||||
if self.redact:
|
||||
redacted_input = self._redact_dict(context.tool_input.copy())
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_input=redacted_input,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
|
||||
"""扫描输出结果"""
|
||||
if isinstance(result, dict):
|
||||
detected = self._scan_dict(result)
|
||||
|
||||
if detected:
|
||||
context.metadata["security_detected_output"] = detected
|
||||
|
||||
if self.redact:
|
||||
redacted_result = self._redact_dict(result.copy())
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=redacted_result,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
elif isinstance(result, str):
|
||||
detected = self._scan_string(result)
|
||||
if detected:
|
||||
context.metadata["security_detected_output"] = detected
|
||||
|
||||
if self.redact:
|
||||
redacted_result = self._redact_string(result)
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=redacted_result,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=result,
|
||||
)
|
||||
|
||||
def _scan_dict(self, data: dict[str, Any]) -> dict[str, list[str]]:
|
||||
"""扫描字典中的敏感信息"""
|
||||
result: dict[str, list[str]] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
found = self._scan_string(value)
|
||||
if found:
|
||||
result[key] = found
|
||||
|
||||
return result
|
||||
|
||||
def _scan_string(self, text: str) -> list[str]:
|
||||
"""扫描字符串中的敏感信息"""
|
||||
found_types = []
|
||||
|
||||
for name, patterns in self._compiled_patterns.items():
|
||||
for pattern in patterns:
|
||||
if pattern.search(text):
|
||||
if name not in found_types:
|
||||
found_types.append(name)
|
||||
break
|
||||
|
||||
return found_types
|
||||
|
||||
def _redact_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""脱敏字典中的敏感信息"""
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
data[key] = self._redact_string(value)
|
||||
elif isinstance(value, dict):
|
||||
data[key] = self._redact_dict(value)
|
||||
elif isinstance(value, list):
|
||||
data[key] = [self._redact_string(v) if isinstance(v, str) else v for v in value]
|
||||
|
||||
return data
|
||||
|
||||
def _redact_string(self, text: str) -> str:
|
||||
"""脱敏字符串中的敏感信息"""
|
||||
for name, patterns in self._compiled_patterns.items():
|
||||
for pattern in patterns:
|
||||
text = pattern.sub(f"[REDACTED:{name}]", text)
|
||||
|
||||
return text
|
||||
105
backend/app/agents/tools/hooks/config.py
Normal file
105
backend/app/agents/tools/hooks/config.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Hook 配置持久化 - Phase 7.3"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.manager import get_hook_manager
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookConfigEntry:
|
||||
"""Hook 配置条目"""
|
||||
|
||||
name: str
|
||||
hook_type: str
|
||||
enabled: bool
|
||||
tool_names: list[str] | None = None
|
||||
categories: list[str] | None = None
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class HookConfigPersistence:
|
||||
"""Hook 配置持久化"""
|
||||
|
||||
def __init__(self, config_path: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
config_path: 配置文件路径,None 则使用默认路径
|
||||
"""
|
||||
if config_path is None:
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "..", "config", "hooks.json"
|
||||
)
|
||||
self.config_path = config_path
|
||||
|
||||
def load_config(self) -> list[HookConfigEntry]:
|
||||
"""从文件加载 Hook 配置"""
|
||||
if not os.path.exists(self.config_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return [HookConfigEntry(**entry) for entry in data]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def save_config(self, entries: list[HookConfigEntry]) -> bool:
|
||||
"""保存 Hook 配置到文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump([asdict(e) for e in entries], f, indent=2, ensure_ascii=False)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def apply_config(self) -> int:
|
||||
"""应用配置到 HookManager
|
||||
|
||||
Returns:
|
||||
应用的 Hook 数量
|
||||
"""
|
||||
from app.agents.tools.hooks.types import HookType
|
||||
|
||||
manager = get_hook_manager()
|
||||
entries = self.load_config()
|
||||
count = 0
|
||||
|
||||
for entry in entries:
|
||||
if entry.enabled:
|
||||
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
|
||||
|
||||
trigger = HookTrigger(
|
||||
tool_names=entry.tool_names,
|
||||
categories=entry.categories,
|
||||
)
|
||||
|
||||
# 创建空的 handler,只是注册配置
|
||||
hook_def = HookDefinition(
|
||||
name=entry.name,
|
||||
hook_type=HookType(entry.hook_type),
|
||||
trigger=trigger,
|
||||
handler=lambda ctx, *args: ctx,
|
||||
priority=entry.priority,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
manager.register(hook_def)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
# 全局单例
|
||||
_persistence: HookConfigPersistence | None = None
|
||||
|
||||
|
||||
def get_hook_config_persistence() -> HookConfigPersistence:
|
||||
"""获取全局 Hook 配置持久化实例"""
|
||||
global _persistence
|
||||
if _persistence is None:
|
||||
_persistence = HookConfigPersistence()
|
||||
return _persistence
|
||||
5
backend/app/agents/tools/hooks/custom/__init__.py
Normal file
5
backend/app/agents/tools/hooks/custom/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""自定义 Hook 加载器包"""
|
||||
|
||||
from app.agents.tools.hooks.custom.loader import CustomHookLoader, get_custom_hook_loader
|
||||
|
||||
__all__ = ["CustomHookLoader", "get_custom_hook_loader"]
|
||||
153
backend/app/agents/tools/hooks/custom/loader.py
Normal file
153
backend/app/agents/tools/hooks/custom/loader.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""自定义 Hook 加载器 - Phase 7.4
|
||||
|
||||
支持动态加载用户自定义的 Hook。
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import HookDefinition, HookType, HookTrigger, HookResult
|
||||
|
||||
|
||||
class CustomHookLoader:
|
||||
"""自定义 Hook 加载器
|
||||
|
||||
从指定目录动态加载自定义 Hook 模块。
|
||||
"""
|
||||
|
||||
def __init__(self, hooks_dir: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
hooks_dir: Hook 目录,None 则使用默认目录
|
||||
"""
|
||||
if hooks_dir is None:
|
||||
hooks_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "data", "custom_hooks"
|
||||
)
|
||||
self.hooks_dir = hooks_dir
|
||||
self._loaded_hooks: dict[str, HookDefinition] = {}
|
||||
|
||||
def load_all(self) -> list[HookDefinition]:
|
||||
"""加载所有自定义 Hook
|
||||
|
||||
Returns:
|
||||
Hook 定义列表
|
||||
"""
|
||||
hooks = []
|
||||
|
||||
if not os.path.exists(self.hooks_dir):
|
||||
return hooks
|
||||
|
||||
for filename in os.listdir(self.hooks_dir):
|
||||
if filename.endswith(".py") and not filename.startswith("_"):
|
||||
hook_path = os.path.join(self.hooks_dir, filename)
|
||||
hook_def = self._load_hook_from_file(hook_path, filename[:-3])
|
||||
if hook_def:
|
||||
hooks.append(hook_def)
|
||||
self._loaded_hooks[hook_def.name] = hook_def
|
||||
|
||||
return hooks
|
||||
|
||||
def _load_hook_from_file(self, hook_path: str, module_name: str) -> HookDefinition | None:
|
||||
"""从文件加载 Hook
|
||||
|
||||
Args:
|
||||
hook_path: Hook 文件路径
|
||||
module_name: 模块名
|
||||
|
||||
Returns:
|
||||
Hook 定义或 None
|
||||
"""
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, hook_path)
|
||||
if not spec or not spec.loader:
|
||||
return None
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 查找 HOOK_DEFINITION 或 hook_definition
|
||||
hook_def = getattr(module, "HOOK_DEFINITION", None) or getattr(
|
||||
module, "hook_definition", None
|
||||
)
|
||||
|
||||
if hook_def and isinstance(hook_def, HookDefinition):
|
||||
return hook_def
|
||||
|
||||
# 如果没有定义,尝试从函数自动推断
|
||||
if hasattr(module, "pre_tool_hook") or hasattr(module, "post_tool_hook"):
|
||||
return self._infer_hook_definition(module, module_name)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _infer_hook_definition(self, module: Any, module_name: str) -> HookDefinition | None:
|
||||
"""从模块函数推断 Hook 定义
|
||||
|
||||
Args:
|
||||
module: 模块对象
|
||||
module_name: 模块名
|
||||
|
||||
Returns:
|
||||
Hook 定义或 None
|
||||
"""
|
||||
hook_type = None
|
||||
handler = None
|
||||
|
||||
if hasattr(module, "pre_tool_hook"):
|
||||
handler = module.pre_tool_hook
|
||||
hook_type = HookType.PRE_TOOL_USE
|
||||
elif hasattr(module, "post_tool_hook"):
|
||||
handler = module.post_tool_hook
|
||||
hook_type = HookType.POST_TOOL_USE
|
||||
elif hasattr(module, "error_tool_hook"):
|
||||
handler = module.error_tool_hook
|
||||
hook_type = HookType.TOOL_ERROR
|
||||
|
||||
if not handler or not hook_type:
|
||||
return None
|
||||
|
||||
return HookDefinition(
|
||||
name=module_name,
|
||||
hook_type=hook_type,
|
||||
trigger=HookTrigger(),
|
||||
handler=handler,
|
||||
priority=0,
|
||||
enabled=True,
|
||||
description=f"Auto-loaded hook from {module_name}",
|
||||
)
|
||||
|
||||
def get_hook(self, name: str) -> HookDefinition | None:
|
||||
"""获取已加载的 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
Hook 定义或 None
|
||||
"""
|
||||
return self._loaded_hooks.get(name)
|
||||
|
||||
def reload(self) -> list[HookDefinition]:
|
||||
"""重新加载所有 Hook
|
||||
|
||||
Returns:
|
||||
重新加载的 Hook 列表
|
||||
"""
|
||||
self._loaded_hooks.clear()
|
||||
return self.load_all()
|
||||
|
||||
|
||||
# 全局加载器
|
||||
_loader: CustomHookLoader | None = None
|
||||
|
||||
|
||||
def get_custom_hook_loader() -> CustomHookLoader:
|
||||
"""获取全局自定义 Hook 加载器"""
|
||||
global _loader
|
||||
if _loader is None:
|
||||
_loader = CustomHookLoader()
|
||||
return _loader
|
||||
170
backend/app/agents/tools/hooks/executor.py
Normal file
170
backend/app/agents/tools/hooks/executor.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Hook 执行器 - Phase 6.2
|
||||
|
||||
执行 Hook 拦截逻辑。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.manager import get_hook_manager
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
)
|
||||
|
||||
|
||||
class HookExecutor:
|
||||
"""Hook 执行器
|
||||
|
||||
负责在工具执行前后执行 Hook 逻辑。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._manager = get_hook_manager()
|
||||
|
||||
async def execute_pre_hooks(
|
||||
self, context: ExecutionContext
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""执行 pre-tool Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
(是否继续执行, 修改后的输入)
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name)
|
||||
modified_input = context.tool_input
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
# 调用 hook handler
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context)
|
||||
if result and not result.continue_execution:
|
||||
# Hook 决定中断执行
|
||||
return False, modified_input
|
||||
if result.modified_input is not None:
|
||||
modified_input = result.modified_input
|
||||
except Exception as e:
|
||||
# Hook 出错,默认继续执行
|
||||
pass
|
||||
|
||||
return True, modified_input
|
||||
|
||||
async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any:
|
||||
"""执行 post-tool Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
result: 工具执行结果
|
||||
|
||||
Returns:
|
||||
修改后的结果
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name)
|
||||
modified_result = result
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
hook_result = await self._call_hook(handler, context, modified_result)
|
||||
if hook_result and hook_result.modified_output is not None:
|
||||
modified_result = hook_result.modified_output
|
||||
except Exception:
|
||||
# Hook 出错,默认保留原结果
|
||||
pass
|
||||
|
||||
return modified_result
|
||||
|
||||
async def execute_error_hooks(
|
||||
self, context: ExecutionContext, error: Exception
|
||||
) -> HookResult | None:
|
||||
"""执行 error Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
error: 异常
|
||||
|
||||
Returns:
|
||||
Hook 结果,如果返回 None 则继续传播错误
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name)
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context, error)
|
||||
if result is not None and result.continue_execution:
|
||||
return result
|
||||
except Exception:
|
||||
# Hook 出错,继续执行其他 error hooks
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def execute_skip_check(self, context: ExecutionContext) -> bool:
|
||||
"""检查是否应跳过工具执行
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
True 表示跳过,False 表示执行
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name)
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context)
|
||||
if result is not None and isinstance(result, bool):
|
||||
return result
|
||||
except Exception:
|
||||
# Hook 出错,默认不跳过
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
async def _call_hook(
|
||||
self, handler: Any, context: ExecutionContext, *args: Any
|
||||
) -> HookResult | None:
|
||||
"""调用 Hook 处理函数
|
||||
|
||||
Args:
|
||||
handler: Hook 处理函数
|
||||
context: 执行上下文
|
||||
*args: 额外参数
|
||||
|
||||
Returns:
|
||||
Hook 结果
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 如果是普通函数,直接调用
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
return await handler(context, *args)
|
||||
else:
|
||||
return handler(context, *args)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_executor: HookExecutor | None = None
|
||||
|
||||
|
||||
def get_hook_executor() -> HookExecutor:
|
||||
"""获取全局 Hook 执行器
|
||||
|
||||
Returns:
|
||||
全局 HookExecutor 实例
|
||||
"""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = HookExecutor()
|
||||
return _executor
|
||||
174
backend/app/agents/tools/hooks/manager.py
Normal file
174
backend/app/agents/tools/hooks/manager.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Hook 管理器 - Phase 6.2
|
||||
|
||||
管理 Hook 的注册、查找和配置。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookTrigger,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
)
|
||||
|
||||
|
||||
class HookManager:
|
||||
"""Hook 管理器
|
||||
|
||||
管理全局 Hook 的注册和配置。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks: dict[HookType, list[HookDefinition]] = {
|
||||
HookType.PRE_TOOL_USE: [],
|
||||
HookType.POST_TOOL_USE: [],
|
||||
HookType.TOOL_ERROR: [],
|
||||
HookType.TOOL_SKIP: [],
|
||||
}
|
||||
self._global_hooks: list[HookDefinition] = [] # 全局 Hook(对所有工具生效)
|
||||
|
||||
def register(self, definition: HookDefinition) -> None:
|
||||
"""注册 Hook
|
||||
|
||||
Args:
|
||||
definition: Hook 定义
|
||||
"""
|
||||
if definition.trigger.tool_names is None and definition.trigger.categories is None:
|
||||
# 全局 Hook
|
||||
self._global_hooks.append(definition)
|
||||
else:
|
||||
# 特定工具 Hook
|
||||
self._hooks[definition.hook_type].append(definition)
|
||||
|
||||
# 按优先级排序
|
||||
self._hooks[definition.hook_type].sort(key=lambda h: h.priority, reverse=True)
|
||||
self._global_hooks.sort(key=lambda h: h.priority, reverse=True)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功注销
|
||||
"""
|
||||
# 从特定工具 Hook 中移除
|
||||
for hooks in self._hooks.values():
|
||||
for i, hook in enumerate(hooks):
|
||||
if hook.name == name:
|
||||
hooks.pop(i)
|
||||
return True
|
||||
|
||||
# 从全局 Hook 中移除
|
||||
for i, hook in enumerate(self._global_hooks):
|
||||
if hook.name == name:
|
||||
self._global_hooks.pop(i)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_hooks(self, hook_type: HookType, tool_name: str | None = None) -> list[HookDefinition]:
|
||||
"""获取指定类型和工具的 Hook
|
||||
|
||||
Args:
|
||||
hook_type: Hook 类型
|
||||
tool_name: 工具名称(可选)
|
||||
|
||||
Returns:
|
||||
匹配的 Hook 列表
|
||||
"""
|
||||
result: list[HookDefinition] = []
|
||||
|
||||
# 添加全局 Hook
|
||||
for hook in self._global_hooks:
|
||||
if hook.hook_type == hook_type and hook.enabled:
|
||||
result.append(hook)
|
||||
|
||||
# 添加特定工具 Hook
|
||||
for hook in self._hooks[hook_type]:
|
||||
if not hook.enabled:
|
||||
continue
|
||||
|
||||
if hook.trigger.tool_names is None and hook.trigger.categories is None:
|
||||
continue
|
||||
|
||||
# 检查是否匹配
|
||||
if hook.trigger.tool_names and tool_name not in hook.trigger.tool_names:
|
||||
continue
|
||||
|
||||
result.append(hook)
|
||||
|
||||
return result
|
||||
|
||||
def list_all(self) -> list[HookDefinition]:
|
||||
"""列出所有已注册的 Hook
|
||||
|
||||
Returns:
|
||||
Hook 列表
|
||||
"""
|
||||
all_hooks = list(self._global_hooks)
|
||||
for hooks in self._hooks.values():
|
||||
all_hooks.extend(hooks)
|
||||
return all_hooks
|
||||
|
||||
def enable(self, name: str) -> bool:
|
||||
"""启用 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功启用
|
||||
"""
|
||||
for hook in self.list_all():
|
||||
if hook.name == name:
|
||||
hook.enabled = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable(self, name: str) -> bool:
|
||||
"""禁用 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功禁用
|
||||
"""
|
||||
for hook in self.list_all():
|
||||
if hook.name == name:
|
||||
hook.enabled = False
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有 Hook"""
|
||||
self._hooks = {ht: [] for ht in HookType}
|
||||
self._global_hooks = []
|
||||
|
||||
|
||||
# 全局单例
|
||||
_global_hook_manager: HookManager | None = None
|
||||
|
||||
|
||||
def get_hook_manager() -> HookManager:
|
||||
"""获取全局 Hook 管理器
|
||||
|
||||
Returns:
|
||||
全局 HookManager 实例
|
||||
"""
|
||||
global _global_hook_manager
|
||||
if _global_hook_manager is None:
|
||||
_global_hook_manager = HookManager()
|
||||
return _global_hook_manager
|
||||
|
||||
|
||||
def reset_hook_manager() -> None:
|
||||
"""重置全局 Hook 管理器(用于测试)"""
|
||||
global _global_hook_manager
|
||||
if _global_hook_manager is not None:
|
||||
_global_hook_manager.clear()
|
||||
_global_hook_manager = None
|
||||
90
backend/app/agents/tools/hooks/types.py
Normal file
90
backend/app/agents/tools/hooks/types.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Hook 类型定义 - Phase 6.2
|
||||
|
||||
Hook 拦截系统类型定义。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class HookType(Enum):
|
||||
"""Hook 类型"""
|
||||
|
||||
PRE_TOOL_USE = "pre_tool_use" # 工具执行前
|
||||
POST_TOOL_USE = "post_tool_use" # 工具执行后
|
||||
TOOL_ERROR = "tool_error" # 工具执行出错
|
||||
TOOL_SKIP = "tool_skip" # 工具跳过(条件执行)
|
||||
|
||||
|
||||
class HookStage(Enum):
|
||||
"""Hook 执行阶段"""
|
||||
|
||||
BEFORE = "before"
|
||||
AFTER = "after"
|
||||
ON_ERROR = "on_error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookTrigger:
|
||||
"""Hook 触发条件"""
|
||||
|
||||
tool_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||
categories: list[str] | None = None # 只对特定类别生效
|
||||
conditions: dict[str, Any] | None = None # 自定义条件
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookDefinition:
|
||||
"""Hook 定义"""
|
||||
|
||||
name: str
|
||||
hook_type: HookType
|
||||
trigger: HookTrigger
|
||||
handler: Callable[..., Any] # Hook 处理函数
|
||||
priority: int = 0 # 优先级,数字越大越先执行
|
||||
enabled: bool = True
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""Hook 执行结果"""
|
||||
|
||||
hook_name: str
|
||||
success: bool
|
||||
continue_execution: bool = True # False 表示中断执行
|
||||
modified_input: Any = None # 修改后的输入
|
||||
modified_output: Any = None # 修改后的输出
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""工具执行上下文"""
|
||||
|
||||
tool_name: str
|
||||
tool_input: dict[str, Any]
|
||||
user_id: str | None = None
|
||||
session_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 执行结果(由 HookExecutor 填充)
|
||||
result: Any = None
|
||||
error: Exception | None = None
|
||||
start_time: float | None = None
|
||||
end_time: float | None = None
|
||||
|
||||
|
||||
# Hook 处理函数类型
|
||||
HookHandler = Callable[[ExecutionContext, HookDefinition], HookResult]
|
||||
|
||||
# Pre-hook: 在工具执行前调用,可以修改输入或决定是否跳过
|
||||
PreToolHook = Callable[[ExecutionContext], tuple[bool, dict[str, Any] | None]]
|
||||
# post-hook: 在工具执行后调用,可以修改输出
|
||||
PostToolHook = Callable[[ExecutionContext, Any], Any]
|
||||
# Error hook: 在工具出错时调用
|
||||
ErrorToolHook = Callable[[ExecutionContext, Exception], HookResult | None]
|
||||
# Skip hook: 决定是否跳过工具执行
|
||||
SkipToolHook = Callable[[ExecutionContext], bool]
|
||||
77
backend/app/agents/tools/manifest.py
Normal file
77
backend/app/agents/tools/manifest.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""工具元数据和数据类型定义"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolCategory(Enum):
|
||||
"""工具类别"""
|
||||
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXTERNAL = "external"
|
||||
DB_WRITE = "db_write"
|
||||
NETWORK = "network"
|
||||
|
||||
|
||||
class SideEffectScope(Enum):
|
||||
"""副作用范围"""
|
||||
|
||||
NONE = "none"
|
||||
LOCAL_STATE = "local_state"
|
||||
DB_WRITE = "db_write"
|
||||
NETWORK = "network"
|
||||
|
||||
|
||||
class PermissionClass(Enum):
|
||||
"""权限级别"""
|
||||
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXTERNAL = "external"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolManifest:
|
||||
"""工具元数据"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
category: ToolCategory
|
||||
parameters: dict[str, Any] # JSON Schema
|
||||
return_schema: dict[str, Any]
|
||||
permission_class: PermissionClass
|
||||
side_effect_scope: SideEffectScope
|
||||
requires_confirmation: bool = False
|
||||
is_streaming: bool = False
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"category": self.category.value,
|
||||
"parameters": self.parameters,
|
||||
"return_schema": self.return_schema,
|
||||
"permission_class": self.permission_class.value,
|
||||
"side_effect_scope": self.side_effect_scope.value,
|
||||
"requires_confirmation": self.requires_confirmation,
|
||||
"is_streaming": self.is_streaming,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookConfig:
|
||||
"""Hook 配置"""
|
||||
|
||||
name: str
|
||||
hook_type: str # "pre_tool_use", "post_tool_use", "tool_error", "tool_skip"
|
||||
filter_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||
|
||||
def matches_tool(self, tool_name: str) -> bool:
|
||||
"""检查 Hook 是否对指定工具生效"""
|
||||
if self.filter_names is None:
|
||||
return True
|
||||
return tool_name in self.filter_names
|
||||
251
backend/app/agents/tools/migration.py
Normal file
251
backend/app/agents/tools/migration.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""工具迁移和向后兼容层 - Phase 6.1
|
||||
|
||||
将现有 @tool 装饰的工具迁移到 ToolRegistry,同时保持向后兼容。
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
from app.agents.tools.registry import get_tool_registry
|
||||
|
||||
|
||||
# 现有工具的类别映射
|
||||
_TOOL_CATEGORY_MAP: dict[str, tuple[ToolCategory, PermissionClass, SideEffectScope]] = {
|
||||
# 知识检索 - 只读
|
||||
"search_knowledge": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"get_knowledge_graph_context": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"hybrid_search": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"web_search": (ToolCategory.NETWORK, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||
# 知识构建 - 写入
|
||||
"build_knowledge_graph": (
|
||||
ToolCategory.WRITE,
|
||||
PermissionClass.WRITE,
|
||||
SideEffectScope.LOCAL_STATE,
|
||||
),
|
||||
# 任务工具
|
||||
"get_tasks": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_task": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"update_task_status": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
# 日程工具
|
||||
"get_schedule_day": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_todo": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"create_schedule_task": (
|
||||
ToolCategory.WRITE,
|
||||
PermissionClass.WRITE,
|
||||
SideEffectScope.LOCAL_STATE,
|
||||
),
|
||||
"create_reminder": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"create_goal": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"resolve_time_expression": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
# 论坛工具
|
||||
"get_forum_posts": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_forum_post": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"scan_forum_for_instructions": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
}
|
||||
|
||||
|
||||
def get_tool_category(name: str) -> tuple[ToolCategory, PermissionClass, SideEffectScope]:
|
||||
"""获取工具的类别信息"""
|
||||
return _TOOL_CATEGORY_MAP.get(
|
||||
name,
|
||||
(ToolCategory.EXTERNAL, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||
)
|
||||
|
||||
|
||||
def infer_tags_from_docstring(docstring: str | None) -> list[str]:
|
||||
"""从 docstring 推断工具标签"""
|
||||
if not docstring:
|
||||
return []
|
||||
tags = []
|
||||
doc_lower = docstring.lower()
|
||||
if "搜索" in docstring or "查询" in docstring or "search" in doc_lower:
|
||||
tags.append("search")
|
||||
if "创建" in docstring or "新建" in docstring or "create" in doc_lower:
|
||||
tags.append("create")
|
||||
if "获取" in docstring or "读取" in docstring or "get" in doc_lower:
|
||||
tags.append("read")
|
||||
if "更新" in docstring or "修改" in docstring or "update" in doc_lower:
|
||||
tags.append("update")
|
||||
return tags
|
||||
|
||||
|
||||
def migrate_tool(tool_func: Callable) -> Callable:
|
||||
"""将现有 @tool 装饰的函数迁移到 ToolRegistry
|
||||
|
||||
Args:
|
||||
tool_func: LangChain @tool 装饰的函数
|
||||
|
||||
Returns:
|
||||
原函数(已注册到 registry)
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
|
||||
# 如果已经注册,跳过
|
||||
if registry.get(tool_func.name):
|
||||
return tool_func
|
||||
|
||||
# 获取类别信息
|
||||
category, permission, side_effect = get_tool_category(tool_func.name)
|
||||
|
||||
# 从 docstring 提取 description
|
||||
description = tool_func.description if hasattr(tool_func, "description") else ""
|
||||
|
||||
# 推断 tags
|
||||
tags = infer_tags_from_docstring(description)
|
||||
tags.append("migrated")
|
||||
|
||||
# 创建 manifest
|
||||
manifest = ToolManifest(
|
||||
name=tool_func.name,
|
||||
description=description,
|
||||
category=category,
|
||||
parameters={}, # LangChain @tool 动态处理参数
|
||||
return_schema={},
|
||||
permission_class=permission,
|
||||
side_effect_scope=side_effect,
|
||||
requires_confirmation=side_effect != SideEffectScope.NONE,
|
||||
is_streaming=False,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# 注册到 registry
|
||||
registry.register(manifest, tool_func)
|
||||
|
||||
return tool_func
|
||||
|
||||
|
||||
def migrate_all_tools() -> int:
|
||||
"""迁移所有现有工具到 ToolRegistry
|
||||
|
||||
Returns:
|
||||
迁移的工具数量
|
||||
"""
|
||||
from app.agents.tools import (
|
||||
ALL_TOOLS,
|
||||
KNOWLEDGE_GRAPH_TOOLS,
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
SCHEDULE_READ_TOOLS,
|
||||
SCHEDULE_WRITE_TOOLS,
|
||||
TASK_TOOLS,
|
||||
FORUM_TOOLS,
|
||||
)
|
||||
|
||||
all_tools = (
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||
+ KNOWLEDGE_GRAPH_TOOLS
|
||||
+ TASK_TOOLS
|
||||
+ SCHEDULE_READ_TOOLS
|
||||
+ SCHEDULE_WRITE_TOOLS
|
||||
+ FORUM_TOOLS
|
||||
)
|
||||
|
||||
count = 0
|
||||
for tool in all_tools:
|
||||
try:
|
||||
migrate_tool(tool)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
print(f"Failed to migrate tool {getattr(tool, 'name', 'unknown')}: {e}")
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class BackwardCompatTool:
|
||||
"""向后兼容工具包装器
|
||||
|
||||
确保现有代码通过 registry.get_executor() 仍能正常调用工具。
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self._registry = get_tool_registry()
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
executor = self._registry.get_executor(self.name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
def invoke(self, tool_input: dict[str, Any]) -> Any:
|
||||
"""LangChain 风格的 invoke 调用"""
|
||||
executor = self._registry.get_executor(self.name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||
|
||||
# 处理位置参数
|
||||
if isinstance(tool_input, dict):
|
||||
return executor(**tool_input)
|
||||
return executor(tool_input)
|
||||
|
||||
|
||||
def create_compat_layer() -> dict[str, BackwardCompatTool]:
|
||||
"""创建向后兼容层
|
||||
|
||||
返回一个字典,允许通过名称访问兼容的工具包装器。
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
tools = registry.list_all()
|
||||
|
||||
return {tool.name: BackwardCompatTool(tool.name) for tool in tools}
|
||||
|
||||
|
||||
# 自动迁移装饰器
|
||||
def auto_migrate(func: Callable) -> Callable:
|
||||
"""自动迁移装饰器
|
||||
|
||||
用于装饰新的 @tool 函数,自动注册到 registry。
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 迁移到 registry
|
||||
migrate_tool(wrapper)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# 便捷函数:获取兼容的工具执行器
|
||||
def get_tool_executor(name: str) -> Callable | None:
|
||||
"""获取工具执行器(兼容层)
|
||||
|
||||
优先从 registry 获取,fallback 到直接导入。
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
executor = registry.get_executor(name)
|
||||
|
||||
if executor is not None:
|
||||
return executor
|
||||
|
||||
# Fallback: 直接从模块导入(仅用于迁移期间)
|
||||
try:
|
||||
from app.agents.tools import (
|
||||
TASK_TOOLS,
|
||||
SCHEDULE_READ_TOOLS,
|
||||
SCHEDULE_WRITE_TOOLS,
|
||||
FORUM_TOOLS,
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
)
|
||||
|
||||
all_tools = (
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||
+ TASK_TOOLS
|
||||
+ SCHEDULE_READ_TOOLS
|
||||
+ SCHEDULE_WRITE_TOOLS
|
||||
+ FORUM_TOOLS
|
||||
)
|
||||
|
||||
for tool in all_tools:
|
||||
if hasattr(tool, "name") and tool.name == name:
|
||||
return tool
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return None
|
||||
206
backend/app/agents/tools/registry.py
Normal file
206
backend/app/agents/tools/registry.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""工具注册表 - 工具系统重构 Phase 6.1"""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.agents.tools.manifest import HookConfig, ToolManifest
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表
|
||||
|
||||
统一管理所有工具的注册、发现和调用。
|
||||
支持工具元数据、权限分类、Hook 拦截。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, ToolManifest] = {}
|
||||
self._executors: dict[str, Callable] = {}
|
||||
self._hooks: dict[str, list[HookConfig]] = defaultdict(list)
|
||||
|
||||
def register(
|
||||
self, manifest: ToolManifest, executor: Callable, hooks: list[HookConfig] | None = None
|
||||
) -> None:
|
||||
"""注册工具
|
||||
|
||||
Args:
|
||||
manifest: 工具元数据
|
||||
executor: 工具执行函数
|
||||
hooks: 可选的 Hook 配置列表
|
||||
"""
|
||||
if manifest.name in self._tools:
|
||||
raise ValueError(f"Tool already registered: {manifest.name}")
|
||||
|
||||
self._tools[manifest.name] = manifest
|
||||
self._executors[manifest.name] = executor
|
||||
|
||||
if hooks:
|
||||
for hook in hooks:
|
||||
self._hooks[manifest.name].append(hook)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销工具
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
是否成功注销
|
||||
"""
|
||||
if name not in self._tools:
|
||||
return False
|
||||
|
||||
del self._tools[name]
|
||||
del self._executors[name]
|
||||
if name in self._hooks:
|
||||
del self._hooks[name]
|
||||
return True
|
||||
|
||||
def get(self, name: str) -> ToolManifest | None:
|
||||
"""获取工具元数据
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具元数据,不存在返回 None
|
||||
"""
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_executor(self, name: str) -> Callable | None:
|
||||
"""获取工具执行器
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具执行函数,不存在返回 None
|
||||
"""
|
||||
return self._executors.get(name)
|
||||
|
||||
def get_hooks(self, name: str) -> list[HookConfig]:
|
||||
"""获取工具的 Hook 配置
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
Hook 配置列表
|
||||
"""
|
||||
return self._hooks.get(name, [])
|
||||
|
||||
def list_all(self) -> list[ToolManifest]:
|
||||
"""列出所有已注册的工具
|
||||
|
||||
Returns:
|
||||
工具元数据列表
|
||||
"""
|
||||
return list(self._tools.values())
|
||||
|
||||
def list_by_category(self, category: Any) -> list[ToolManifest]:
|
||||
"""按类别列出工具
|
||||
|
||||
Args:
|
||||
category: 工具类别
|
||||
|
||||
Returns:
|
||||
该类别下的所有工具
|
||||
"""
|
||||
return [t for t in self._tools.values() if t.category == category]
|
||||
|
||||
def list_by_permission(self, permission: Any) -> list[ToolManifest]:
|
||||
"""按权限级别列出工具
|
||||
|
||||
Args:
|
||||
permission: 权限级别
|
||||
|
||||
Returns:
|
||||
该权限级别下的所有工具
|
||||
"""
|
||||
return [t for t in self._tools.values() if t.permission_class == permission]
|
||||
|
||||
def search_by_tag(self, tag: str) -> list[ToolManifest]:
|
||||
"""按标签搜索工具
|
||||
|
||||
Args:
|
||||
tag: 标签
|
||||
|
||||
Returns:
|
||||
包含该标签的工具
|
||||
"""
|
||||
return [t for t in self._tools.values() if tag in t.tags]
|
||||
|
||||
def search_by_name(self, keyword: str) -> list[ToolManifest]:
|
||||
"""按名称关键词搜索工具
|
||||
|
||||
Args:
|
||||
keyword: 关键词
|
||||
|
||||
Returns:
|
||||
名称包含关键词的工具
|
||||
"""
|
||||
keyword = keyword.lower()
|
||||
return [t for t in self._tools.values() if keyword in t.name.lower()]
|
||||
|
||||
def get_requires_confirmation(self, name: str) -> bool:
|
||||
"""检查工具是否需要确认
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
是否需要确认
|
||||
"""
|
||||
manifest = self._tools.get(name)
|
||||
return manifest.requires_confirmation if manifest else False
|
||||
|
||||
def get_is_streaming(self, name: str) -> bool:
|
||||
"""检查工具是否支持流式执行
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
|
||||
Returns:
|
||||
是否支持流式
|
||||
"""
|
||||
manifest = self._tools.get(name)
|
||||
return manifest.is_streaming if manifest else False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空注册表"""
|
||||
self._tools.clear()
|
||||
self._executors.clear()
|
||||
self._hooks.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._tools)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self._tools
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._tools.values())
|
||||
|
||||
|
||||
# 全局单例实例
|
||||
_global_registry: ToolRegistry | None = None
|
||||
|
||||
|
||||
def get_tool_registry() -> ToolRegistry:
|
||||
"""获取全局工具注册表单例
|
||||
|
||||
Returns:
|
||||
全局 ToolRegistry 实例
|
||||
"""
|
||||
global _global_registry
|
||||
if _global_registry is None:
|
||||
_global_registry = ToolRegistry()
|
||||
return _global_registry
|
||||
|
||||
|
||||
def reset_tool_registry() -> None:
|
||||
"""重置全局工具注册表(用于测试)"""
|
||||
global _global_registry
|
||||
if _global_registry is not None:
|
||||
_global_registry.clear()
|
||||
_global_registry = None
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import date, datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
@@ -11,21 +9,16 @@ from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from app.database import async_session
|
||||
from app.models.goal import Goal, GoalStatus
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
def _parse_date(value: str | None) -> date:
|
||||
|
||||
@@ -5,25 +5,16 @@ Agent 工具集 - 知识库 & 图谱相关
|
||||
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
|
||||
"""
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import asyncio
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from app.database import async_session
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
"""在同步上下文中运行 async 代码"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(_executor, lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
@tool
|
||||
|
||||
210
backend/app/agents/tools/streaming.py
Normal file
210
backend/app/agents/tools/streaming.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""流式工具执行器 - Phase 6.3
|
||||
|
||||
支持流式输出的工具执行器。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from app.agents.tools.hooks.executor import get_hook_executor
|
||||
from app.agents.tools.hooks.types import ExecutionContext
|
||||
from app.agents.tools.registry import get_tool_registry
|
||||
|
||||
|
||||
class StreamingToolExecutor:
|
||||
"""流式工具执行器
|
||||
|
||||
支持:
|
||||
- 普通工具的同步/异步执行
|
||||
- 流式工具的流式输出
|
||||
- Hook 拦截(pre/post/error)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._registry = get_tool_registry()
|
||||
self._hook_executor = get_hook_executor()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
) -> Any:
|
||||
"""执行工具(非流式)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_input: 工具输入参数
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 获取工具和执行器
|
||||
manifest = self._registry.get(tool_name)
|
||||
if manifest is None:
|
||||
raise ValueError(f"Tool not found: {tool_name}")
|
||||
|
||||
executor = self._registry.get_executor(tool_name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||
|
||||
# 检查是否跳过
|
||||
if await self._hook_executor.execute_skip_check(context):
|
||||
return {"skipped": True, "tool": tool_name}
|
||||
|
||||
# 执行 pre-hooks
|
||||
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||
if not continue_execution:
|
||||
return {"pre_hook_aborted": True, "tool": tool_name}
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
context.start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 判断是同步还是异步执行
|
||||
if asyncio.iscoroutinefunction(executor):
|
||||
result = await executor(**modified_input)
|
||||
else:
|
||||
# 同步函数在线程池中执行
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(None, lambda: executor(**modified_input))
|
||||
|
||||
context.result = result
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 执行 post-hooks
|
||||
result = await self._hook_executor.execute_post_hooks(context, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
context.error = e
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 执行 error-hooks
|
||||
error_result = await self._hook_executor.execute_error_hooks(context, e)
|
||||
if error_result is not None:
|
||||
return error_result
|
||||
|
||||
raise
|
||||
|
||||
async def execute_streaming(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""执行流式工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_input: 工具输入参数
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Yields:
|
||||
流式输出片段
|
||||
"""
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 获取工具和执行器
|
||||
manifest = self._registry.get(tool_name)
|
||||
if manifest is None:
|
||||
raise ValueError(f"Tool not found: {tool_name}")
|
||||
|
||||
if not manifest.is_streaming:
|
||||
raise ValueError(f"Tool is not streaming: {tool_name}")
|
||||
|
||||
executor = self._registry.get_executor(tool_name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||
|
||||
# 检查是否跳过
|
||||
if await self._hook_executor.execute_skip_check(context):
|
||||
yield {"type": "skipped", "tool": tool_name}
|
||||
return
|
||||
|
||||
# 执行 pre-hooks
|
||||
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||
if not continue_execution:
|
||||
yield {"type": "pre_hook_aborted", "tool": tool_name}
|
||||
return
|
||||
|
||||
# 执行流式工具
|
||||
try:
|
||||
context.start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 调用执行器(应该返回 AsyncGenerator)
|
||||
result = executor(**modified_input)
|
||||
|
||||
# 如果是 async generator
|
||||
if asyncio.isasyncgen(result):
|
||||
async for chunk in result:
|
||||
yield {"type": "chunk", "data": chunk}
|
||||
else:
|
||||
# 普通协程
|
||||
data = await result
|
||||
yield {"type": "chunk", "data": data}
|
||||
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
yield {"type": "done", "tool": tool_name}
|
||||
|
||||
except Exception as e:
|
||||
context.error = e
|
||||
context.end_time = asyncio.get_event_loop().time()
|
||||
|
||||
yield {"type": "error", "error": str(e), "tool": tool_name}
|
||||
|
||||
# 执行 error-hooks
|
||||
await self._hook_executor.execute_error_hooks(context, e)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
user_id: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""批量执行工具
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input
|
||||
user_id: 用户 ID(可选)
|
||||
|
||||
Returns:
|
||||
执行结果列表
|
||||
"""
|
||||
tasks = []
|
||||
for call in tool_calls:
|
||||
tool_name = call.get("tool_name") or call.get("name")
|
||||
tool_input = call.get("tool_input") or call.get("input") or {}
|
||||
tasks.append(self.execute(tool_name, tool_input, user_id))
|
||||
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_executor: StreamingToolExecutor | None = None
|
||||
|
||||
|
||||
def get_streaming_executor() -> StreamingToolExecutor:
|
||||
"""获取全局流式执行器
|
||||
|
||||
Returns:
|
||||
全局 StreamingToolExecutor 实例
|
||||
"""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = StreamingToolExecutor()
|
||||
return _executor
|
||||
@@ -8,21 +8,13 @@ from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from app.database import async_session
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
|
||||
@@ -241,6 +241,10 @@ def normalize_tool_time_arguments(tool_name: str, args: dict, current_datetime_c
|
||||
if raw_value and not _is_iso_datetime(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
|
||||
normalized["reminder_at"] = payload["resolved_datetime"]
|
||||
raw_date = normalized.get("date")
|
||||
if isinstance(raw_date, str) and raw_date.strip() and not _is_iso_date(raw_date):
|
||||
payload = resolve_time_expression_data(raw_date, current_datetime_context=current_datetime_context, prefer="date")
|
||||
normalized["date"] = payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_schedule_task", "create_task"}:
|
||||
|
||||
113
backend/app/agents/transport/remote.py
Normal file
113
backend/app/agents/transport/remote.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""远程传输层 - Phase 10.2"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructuredMessage:
|
||||
"""结构化消息"""
|
||||
|
||||
type: str # response, event, tool_call, error
|
||||
data: dict[str, Any]
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
class RemoteTransport:
|
||||
"""远程传输层
|
||||
|
||||
处理与远程 Agent 的通信。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._connections: dict[str, Any] = {}
|
||||
self._handlers: dict[str, Any] = {}
|
||||
|
||||
async def send_response(self, session_id: str, response: dict[str, Any]) -> bool:
|
||||
"""发送响应
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
response: 响应数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
message = StructuredMessage(
|
||||
type="response",
|
||||
data=response,
|
||||
session_id=session_id,
|
||||
)
|
||||
return await self._send(session_id, message)
|
||||
|
||||
async def send_event(self, session_id: str, event: dict[str, Any]) -> bool:
|
||||
"""发送事件
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
event: 事件数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
message = StructuredMessage(
|
||||
type="event",
|
||||
data=event,
|
||||
session_id=session_id,
|
||||
)
|
||||
return await self._send(session_id, message)
|
||||
|
||||
async def send_tool_call(self, session_id: str, tool_call: dict[str, Any]) -> bool:
|
||||
"""发送工具调用
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
tool_call: 工具调用数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
message = StructuredMessage(
|
||||
type="tool_call",
|
||||
data=tool_call,
|
||||
session_id=session_id,
|
||||
)
|
||||
return await self._send(session_id, message)
|
||||
|
||||
async def _send(self, session_id: str, message: StructuredMessage) -> bool:
|
||||
"""内部发送方法"""
|
||||
if session_id not in self._connections:
|
||||
return False
|
||||
|
||||
try:
|
||||
connection = self._connections[session_id]
|
||||
if hasattr(connection, "send"):
|
||||
await connection.send(json.dumps(message.__dict__))
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def register_handler(self, event_type: str, handler: Any) -> None:
|
||||
"""注册消息处理器
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
handler: 处理函数
|
||||
"""
|
||||
self._handlers[event_type] = handler
|
||||
|
||||
async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
"""处理收到的消息
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
message: 消息数据
|
||||
"""
|
||||
msg_type = message.get("type")
|
||||
handler = self._handlers.get(msg_type)
|
||||
if handler:
|
||||
await handler(session_id, message.get("data"))
|
||||
86
backend/app/agents/transport/structured_io.py
Normal file
86
backend/app/agents/transport/structured_io.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Structured IO for typed input/output - Phase 10.2"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructuredInput:
|
||||
"""Structured input wrapper"""
|
||||
|
||||
skill_name: str
|
||||
parameters: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructuredOutput:
|
||||
"""Structured output wrapper"""
|
||||
|
||||
skill_name: str
|
||||
result: Any
|
||||
success: bool
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class StructuredIO:
|
||||
"""Handles structured input/output for agent communication"""
|
||||
|
||||
def parse_input(self, data: dict[str, Any]) -> StructuredInput:
|
||||
"""Parse structured input from dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing skill_name, parameters, and metadata
|
||||
|
||||
Returns:
|
||||
StructuredInput instance
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Input data must be a dictionary")
|
||||
|
||||
skill_name = data.get("skill_name")
|
||||
if not skill_name:
|
||||
raise ValueError("Missing required field: skill_name")
|
||||
if not isinstance(skill_name, str):
|
||||
raise ValueError("skill_name must be a string")
|
||||
|
||||
parameters = data.get("parameters")
|
||||
if parameters is None:
|
||||
raise ValueError("Missing required field: parameters")
|
||||
if not isinstance(parameters, dict):
|
||||
raise ValueError("parameters must be a dictionary")
|
||||
|
||||
metadata = data.get("metadata", {})
|
||||
if not isinstance(metadata, dict):
|
||||
raise ValueError("metadata must be a dictionary")
|
||||
|
||||
return StructuredInput(skill_name=skill_name, parameters=parameters, metadata=metadata)
|
||||
|
||||
def format_output(self, output: StructuredOutput) -> dict[str, Any]:
|
||||
"""Format structured output to dictionary.
|
||||
|
||||
Args:
|
||||
output: StructuredOutput instance
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the output
|
||||
"""
|
||||
result = {
|
||||
"skill_name": output.skill_name,
|
||||
"result": output.result,
|
||||
"success": output.success,
|
||||
}
|
||||
|
||||
if output.error is not None:
|
||||
result["error"] = output.error
|
||||
|
||||
if output.metadata is not None:
|
||||
result["metadata"] = output.metadata
|
||||
|
||||
return result
|
||||
207
backend/app/agents/transport/websocket.py
Normal file
207
backend/app/agents/transport/websocket.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""WebSocket 连接管理 - Phase 10.2
|
||||
|
||||
管理 WebSocket 连接的生命周期。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class WSConnection:
|
||||
"""WebSocket 连接"""
|
||||
|
||||
session_id: str
|
||||
websocket: Any # WebSocket 连接
|
||||
user_id: str | None = None
|
||||
created_at: float | None = None
|
||||
last_ping: float | None = None
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""WebSocket 连接管理器
|
||||
|
||||
管理所有 WebSocket 连接的生命周期。
|
||||
"""
|
||||
|
||||
def __init__(self, ping_interval: float = 30.0):
|
||||
"""
|
||||
Args:
|
||||
ping_interval: 心跳间隔(秒)
|
||||
"""
|
||||
self._connections: dict[str, WSConnection] = {}
|
||||
self._handlers: dict[str, Callable] = {}
|
||||
self._ping_interval = ping_interval
|
||||
self._ping_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
async def connect(self, session_id: str, websocket: Any, user_id: str | None = None) -> bool:
|
||||
"""建立连接
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
websocket: WebSocket 连接
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
import time
|
||||
|
||||
if session_id in self._connections:
|
||||
return False
|
||||
|
||||
conn = WSConnection(
|
||||
session_id=session_id,
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
created_at=time.time(),
|
||||
last_ping=time.time(),
|
||||
)
|
||||
self._connections[session_id] = conn
|
||||
|
||||
# 启动心跳
|
||||
self._ping_tasks[session_id] = asyncio.create_task(self._ping_loop(session_id))
|
||||
|
||||
return True
|
||||
|
||||
async def disconnect(self, session_id: str) -> bool:
|
||||
"""断开连接
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
是否断开成功
|
||||
"""
|
||||
if session_id not in self._connections:
|
||||
return False
|
||||
|
||||
# 停止心跳
|
||||
if session_id in self._ping_tasks:
|
||||
self._ping_tasks[session_id].cancel()
|
||||
del self._ping_tasks[session_id]
|
||||
|
||||
del self._connections[session_id]
|
||||
return True
|
||||
|
||||
async def send(self, session_id: str, message: dict[str, Any]) -> bool:
|
||||
"""发送消息
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
message: 消息内容
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
if session_id not in self._connections:
|
||||
return False
|
||||
|
||||
try:
|
||||
conn = self._connections[session_id]
|
||||
await conn.websocket.send_json(message)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def broadcast(self, message: dict[str, Any]) -> int:
|
||||
"""广播消息
|
||||
|
||||
Args:
|
||||
message: 消息内容
|
||||
|
||||
Returns:
|
||||
发送成功的数量
|
||||
"""
|
||||
count = 0
|
||||
for session_id in list(self._connections.keys()):
|
||||
if await self.send(session_id, message):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
async def _ping_loop(self, session_id: str) -> None:
|
||||
"""心跳循环
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
"""
|
||||
import time
|
||||
|
||||
while session_id in self._connections:
|
||||
await asyncio.sleep(self._ping_interval)
|
||||
|
||||
if session_id not in self._connections:
|
||||
break
|
||||
|
||||
try:
|
||||
conn = self._connections[session_id]
|
||||
await conn.websocket.send_json({"type": "ping", "timestamp": time.time()})
|
||||
conn.last_ping = time.time()
|
||||
except Exception:
|
||||
await self.disconnect(session_id)
|
||||
break
|
||||
|
||||
def register_handler(self, event_type: str, handler: Callable) -> None:
|
||||
"""注册消息处理器
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
handler: 处理函数
|
||||
"""
|
||||
self._handlers[event_type] = handler
|
||||
|
||||
async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
"""处理消息
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
message: 消息内容
|
||||
"""
|
||||
msg_type = message.get("type")
|
||||
handler = self._handlers.get(msg_type)
|
||||
if handler:
|
||||
await handler(session_id, message.get("data"))
|
||||
|
||||
def get_connection(self, session_id: str) -> WSConnection | None:
|
||||
"""获取连接
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
连接信息或 None
|
||||
"""
|
||||
return self._connections.get(session_id)
|
||||
|
||||
def list_connections(self) -> list[WSConnection]:
|
||||
"""列出所有连接
|
||||
|
||||
Returns:
|
||||
连接列表
|
||||
"""
|
||||
return list(self._connections.values())
|
||||
|
||||
def is_connected(self, session_id: str) -> bool:
|
||||
"""检查是否连接
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
是否已连接
|
||||
"""
|
||||
return session_id in self._connections
|
||||
|
||||
|
||||
# 全局单例
|
||||
_ws_manager: WebSocketManager | None = None
|
||||
|
||||
|
||||
def get_websocket_manager() -> WebSocketManager:
|
||||
"""获取全局 WebSocket 管理器"""
|
||||
global _ws_manager
|
||||
if _ws_manager is None:
|
||||
_ws_manager = WebSocketManager()
|
||||
return _ws_manager
|
||||
93
backend/app/agents/verifier.py
Normal file
93
backend/app/agents/verifier.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agents.schemas.task import AgentTask, TaskResult, TaskResultStatus, VerificationStatus
|
||||
from app.agents.state import AgentState
|
||||
|
||||
|
||||
class VerificationVerdict(BaseModel):
|
||||
status: VerificationStatus
|
||||
summary: str | None = None
|
||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
def normalize_task_result(
|
||||
task_result: TaskResult | dict[str, Any],
|
||||
*,
|
||||
default_task_id: str | None = None,
|
||||
) -> TaskResult:
|
||||
payload = task_result.model_dump(mode="json") if isinstance(task_result, TaskResult) else dict(task_result or {})
|
||||
normalized_status = payload.get("status")
|
||||
if normalized_status not in {"completed", "failed", "blocked", "passed", "skipped"}:
|
||||
normalized_status = "failed"
|
||||
return TaskResult(
|
||||
task_id=str(payload.get("task_id") or default_task_id or "unknown-task"),
|
||||
status=cast(TaskResultStatus, normalized_status),
|
||||
summary=payload.get("summary"),
|
||||
evidence=list(payload.get("evidence") or []),
|
||||
owner_agent_id=payload.get("owner_agent_id"),
|
||||
parent_task_id=payload.get("parent_task_id"),
|
||||
child_task_ids=list(payload.get("child_task_ids") or []),
|
||||
thread_id=payload.get("thread_id"),
|
||||
message_id=payload.get("message_id"),
|
||||
message_index=payload.get("message_index") if isinstance(payload.get("message_index"), int) else None,
|
||||
interrupt_records=list(payload.get("interrupt_records") or []),
|
||||
recovery_records=list(payload.get("recovery_records") or []),
|
||||
budget_snapshot=payload.get("budget_snapshot") if isinstance(payload.get("budget_snapshot"), dict) else None,
|
||||
next_action=payload.get("next_action"),
|
||||
output_data=payload.get("output_data") if isinstance(payload.get("output_data"), dict) else None,
|
||||
)
|
||||
|
||||
|
||||
def verify_task_result(
|
||||
*,
|
||||
task: AgentTask | dict[str, Any] | None = None,
|
||||
result: TaskResult | dict[str, Any] | None = None,
|
||||
summary: str | None = None,
|
||||
evidence: list[dict[str, Any]] | None = None,
|
||||
status: VerificationStatus | None = None,
|
||||
) -> VerificationVerdict:
|
||||
normalized_result = result.model_dump() if isinstance(result, TaskResult) else dict(result or {})
|
||||
normalized_task = task.model_dump() if isinstance(task, AgentTask) else dict(task or {})
|
||||
normalized_summary = summary or normalized_result.get("summary") or normalized_task.get("result_summary")
|
||||
normalized_evidence = list(evidence or normalized_result.get("evidence") or normalized_task.get("evidence") or [])
|
||||
|
||||
if status is not None:
|
||||
return VerificationVerdict(status=status, summary=normalized_summary, evidence=normalized_evidence)
|
||||
|
||||
normalized_status = normalized_result.get("status")
|
||||
if normalized_status in {"passed", "failed", "skipped"}:
|
||||
inferred_status = normalized_status
|
||||
elif normalized_status == "completed":
|
||||
inferred_status = "passed"
|
||||
elif normalized_status == "blocked":
|
||||
inferred_status = "skipped"
|
||||
elif normalized_result.get("success") is True:
|
||||
inferred_status = "passed"
|
||||
elif normalized_result.get("success") is False:
|
||||
inferred_status = "failed"
|
||||
elif normalized_summary or normalized_evidence:
|
||||
inferred_status = "skipped"
|
||||
else:
|
||||
inferred_status = "failed"
|
||||
normalized_summary = "No verification input available."
|
||||
|
||||
return VerificationVerdict(
|
||||
status=inferred_status,
|
||||
summary=normalized_summary,
|
||||
evidence=normalized_evidence,
|
||||
)
|
||||
|
||||
|
||||
def apply_verification_verdict(state: AgentState, verdict: VerificationVerdict) -> AgentState:
|
||||
next_state = dict(state)
|
||||
next_state["verification_status"] = verdict.status
|
||||
next_state["verification_summary"] = verdict.summary
|
||||
next_state["verification_evidence"] = list(verdict.evidence)
|
||||
return AgentState(**next_state)
|
||||
|
||||
|
||||
__all__ = ["VerificationVerdict", "apply_verification_verdict", "normalize_task_result", "verify_task_result"]
|
||||
@@ -1,10 +1,13 @@
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from app.config import settings
|
||||
from collections.abc import AsyncGenerator
|
||||
import os
|
||||
import re
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import settings
|
||||
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
|
||||
engine = create_async_engine(
|
||||
@@ -24,12 +27,9 @@ class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
yield session
|
||||
|
||||
|
||||
async def init_db():
|
||||
@@ -37,6 +37,7 @@ async def init_db():
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await ensure_log_columns(conn)
|
||||
await ensure_message_columns(conn)
|
||||
await ensure_conversation_columns(conn)
|
||||
await ensure_document_columns(conn)
|
||||
await ensure_user_columns(conn)
|
||||
await ensure_forum_columns(conn)
|
||||
@@ -79,6 +80,20 @@ async def ensure_message_columns(conn):
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
|
||||
async def ensure_conversation_columns(conn):
|
||||
rows = await _get_table_info(conn, 'conversations')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
'agent_state': "ALTER TABLE conversations ADD COLUMN agent_state JSON",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
|
||||
async def ensure_document_columns(conn):
|
||||
result = await conn.execute(text("PRAGMA table_info(documents)"))
|
||||
rows = result.fetchall()
|
||||
|
||||
@@ -23,6 +23,11 @@ from app.routers import (
|
||||
log_router,
|
||||
system_router,
|
||||
brain_router,
|
||||
hooks_router,
|
||||
plugins_router,
|
||||
marketplace_router,
|
||||
agent_skills_router,
|
||||
agent_sessions_router,
|
||||
)
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||
@@ -40,15 +45,15 @@ import os
|
||||
|
||||
|
||||
INSECURE_SECRET_KEYS = {
|
||||
'change-me-in-production',
|
||||
'change-me-to-a-random-secret-key',
|
||||
'jarvis-secret-key-change-in-production',
|
||||
"change-me-in-production",
|
||||
"change-me-to-a-random-secret-key",
|
||||
"jarvis-secret-key-change-in-production",
|
||||
}
|
||||
|
||||
|
||||
def validate_startup_security() -> None:
|
||||
if not settings.DEBUG and settings.SECRET_KEY in INSECURE_SECRET_KEYS:
|
||||
raise RuntimeError('SECRET_KEY must be changed before running with DEBUG disabled')
|
||||
raise RuntimeError("SECRET_KEY must be changed before running with DEBUG disabled")
|
||||
|
||||
|
||||
async def run_startup() -> None:
|
||||
@@ -117,6 +122,11 @@ app.include_router(log_router)
|
||||
app.include_router(system_router)
|
||||
app.include_router(brain_router)
|
||||
app.include_router(scheduler_router)
|
||||
app.include_router(hooks_router)
|
||||
app.include_router(plugins_router)
|
||||
app.include_router(marketplace_router)
|
||||
app.include_router(agent_skills_router)
|
||||
app.include_router(agent_sessions_router)
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
|
||||
@@ -9,6 +9,7 @@ class Conversation(BaseModel):
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = Column(String(500), nullable=True)
|
||||
message_count = Column(Integer, default=0)
|
||||
agent_state = Column(JSON, nullable=True)
|
||||
|
||||
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
@@ -15,3 +15,8 @@ from app.routers.skill import router as skill_router
|
||||
from app.routers.log import router as log_router
|
||||
from app.routers.system import router as system_router
|
||||
from app.routers.brain import router as brain_router
|
||||
from app.routers.hooks import router as hooks_router
|
||||
from app.routers.plugins import router as plugins_router
|
||||
from app.routers.plugins import _marketplace_router as marketplace_router
|
||||
from app.routers.agent_skills import router as agent_skills_router
|
||||
from app.routers.agent_sessions import router as agent_sessions_router
|
||||
|
||||
@@ -1,12 +1,42 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.agents.registry import load_builtin_registry_indexes
|
||||
from app.agents.runtime_metrics import coerce_cost_thresholds, estimate_token_cost, is_cost_budget_warning
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
|
||||
from app.schemas.agent import (
|
||||
AgentConfigOut,
|
||||
AgentConfigUpdate,
|
||||
AgentCreate,
|
||||
AgentOut,
|
||||
AgentStats,
|
||||
AgentVisibilityCostByAgentOut,
|
||||
AgentVisibilityCostOut,
|
||||
AgentVisibilityCostSummaryOut,
|
||||
AgentVisibilityEvidenceOut,
|
||||
AgentVisibilityEventsResponse,
|
||||
AgentVisibilityEventOut,
|
||||
AgentVisibilityIsolationOut,
|
||||
AgentVisibilityRuntimeSummaryOut,
|
||||
AgentVisibilityTaskSummaryOut,
|
||||
AgentVisibilityThreadMessageOut,
|
||||
AgentVisibilityThreadOut,
|
||||
AgentVisibilityTopologyNodeOut,
|
||||
AgentVisibilityTopologyOut,
|
||||
AgentVisibilityToolGovernanceItemOut,
|
||||
AgentVisibilityToolGovernanceOut,
|
||||
AgentVisibilityVerifierOut,
|
||||
)
|
||||
from app.services.agent_service import _extract_continuity_snapshot
|
||||
|
||||
router = APIRouter(prefix="/api/agents", tags=["Agent"])
|
||||
|
||||
@@ -21,6 +51,295 @@ SUB_COMMANDERS_BY_ROLE = {
|
||||
"librarian": ["librarian_retrieval", "librarian_graph"],
|
||||
"analyst": ["analyst_progress", "analyst_insights"],
|
||||
}
|
||||
ALLOWED_AGENT_ROLES = set(DEFAULT_AGENT_ROLES) | {
|
||||
role
|
||||
for sub_roles in SUB_COMMANDERS_BY_ROLE.values()
|
||||
for role in sub_roles
|
||||
}
|
||||
|
||||
|
||||
def _parse_visibility_datetime(value: str | None) -> datetime | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="时间参数必须是 ISO 8601 格式") from exc
|
||||
|
||||
|
||||
async def _get_visibility_state(
|
||||
conversation_id: str,
|
||||
*,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
) -> dict[str, Any]:
|
||||
result = await db.execute(
|
||||
select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
conversation = result.scalar_one_or_none()
|
||||
if conversation is None:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
snapshot = _extract_continuity_snapshot(conversation.agent_state)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="当前会话暂无可视化运行时数据")
|
||||
return snapshot
|
||||
|
||||
|
||||
def _coerce_event_payload(event: dict[str, Any]) -> AgentVisibilityEventOut:
|
||||
return AgentVisibilityEventOut.model_validate(event)
|
||||
|
||||
|
||||
def _filter_events(
|
||||
events: list[dict[str, Any]],
|
||||
*,
|
||||
agent_id: str | None,
|
||||
thread_id: str | None,
|
||||
event_type: str | None,
|
||||
started_after: datetime | None,
|
||||
ended_before: datetime | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
filtered: list[dict[str, Any]] = []
|
||||
for event in events:
|
||||
if agent_id and event.get("agent_id") != agent_id:
|
||||
continue
|
||||
if thread_id and event.get("thread_id") != thread_id:
|
||||
continue
|
||||
if event_type and event.get("event_type") != event_type:
|
||||
continue
|
||||
timestamp_raw = event.get("timestamp")
|
||||
timestamp = None
|
||||
if isinstance(timestamp_raw, str):
|
||||
try:
|
||||
timestamp = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
timestamp = None
|
||||
if started_after and timestamp and timestamp < started_after:
|
||||
continue
|
||||
if ended_before and timestamp and timestamp > ended_before:
|
||||
continue
|
||||
filtered.append(event)
|
||||
return filtered
|
||||
|
||||
|
||||
def _summarize_tasks(tasks: list[dict[str, Any]], task_results: list[dict[str, Any]]) -> list[AgentVisibilityTaskSummaryOut]:
|
||||
result_by_task_id = {item.get("task_id"): item for item in task_results}
|
||||
summaries: list[AgentVisibilityTaskSummaryOut] = []
|
||||
for task in tasks:
|
||||
task_id = str(task.get("task_id") or "")
|
||||
result = result_by_task_id.get(task_id) or {}
|
||||
evidence = result.get("evidence") or task.get("evidence") or []
|
||||
summaries.append(
|
||||
AgentVisibilityTaskSummaryOut(
|
||||
task_id=task_id,
|
||||
role=task.get("role"),
|
||||
owner_agent_id=task.get("owner_agent_id") or result.get("owner_agent_id"),
|
||||
status=result.get("status") or task.get("status"),
|
||||
summary=result.get("summary") or task.get("result_summary"),
|
||||
evidence_count=len(evidence),
|
||||
)
|
||||
)
|
||||
return summaries
|
||||
|
||||
|
||||
def _build_topology_nodes(
|
||||
state: dict[str, Any],
|
||||
tasks: list[dict[str, Any]],
|
||||
task_results: list[dict[str, Any]],
|
||||
) -> list[AgentVisibilityTopologyNodeOut]:
|
||||
task_counts: dict[str, int] = {}
|
||||
completed_counts: dict[str, int] = {}
|
||||
for task in tasks:
|
||||
owner = str(task.get("owner_agent_id") or "")
|
||||
if owner:
|
||||
task_counts[owner] = task_counts.get(owner, 0) + 1
|
||||
for result in task_results:
|
||||
owner = str(result.get("owner_agent_id") or "")
|
||||
if owner and result.get("status") == "completed":
|
||||
completed_counts[owner] = completed_counts.get(owner, 0) + 1
|
||||
|
||||
root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None
|
||||
current_agent = str(state.get("current_agent") or "") or None
|
||||
parent_agent_id = str(state.get("parent_agent_id") or "") or None
|
||||
nodes: dict[str, AgentVisibilityTopologyNodeOut] = {}
|
||||
if root_agent_id:
|
||||
nodes[root_agent_id] = AgentVisibilityTopologyNodeOut(
|
||||
agent_id=root_agent_id,
|
||||
role=root_agent_id.split("-")[0],
|
||||
parent_agent_id=parent_agent_id if root_agent_id != state.get("agent_id") else None,
|
||||
source="root",
|
||||
task_count=task_counts.get(root_agent_id, 0),
|
||||
completed_task_count=completed_counts.get(root_agent_id, 0),
|
||||
)
|
||||
for agent_id in state.get("spawned_agent_ids") or []:
|
||||
agent_id = str(agent_id)
|
||||
nodes[agent_id] = AgentVisibilityTopologyNodeOut(
|
||||
agent_id=agent_id,
|
||||
role=agent_id.split("-")[0],
|
||||
parent_agent_id=root_agent_id,
|
||||
source="spawned",
|
||||
task_count=task_counts.get(agent_id, 0),
|
||||
completed_task_count=completed_counts.get(agent_id, 0),
|
||||
)
|
||||
if current_agent and current_agent not in nodes:
|
||||
nodes[current_agent] = AgentVisibilityTopologyNodeOut(
|
||||
agent_id=current_agent,
|
||||
role=current_agent.split("-")[0],
|
||||
parent_agent_id=None if current_agent == root_agent_id else root_agent_id,
|
||||
source="current",
|
||||
task_count=task_counts.get(current_agent, 0),
|
||||
completed_task_count=completed_counts.get(current_agent, 0),
|
||||
)
|
||||
return list(nodes.values())
|
||||
|
||||
|
||||
def _estimate_runtime_cost(input_tokens: int, output_tokens: int) -> float | None:
|
||||
return estimate_token_cost(input_tokens, output_tokens)
|
||||
|
||||
|
||||
def _build_cost_summary(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
conversation_id: str,
|
||||
) -> AgentVisibilityCostSummaryOut:
|
||||
input_tokens = int(state.get("input_tokens") or 0)
|
||||
output_tokens = int(state.get("output_tokens") or 0)
|
||||
estimated_cost = _estimate_runtime_cost(input_tokens, output_tokens)
|
||||
thresholds = coerce_cost_thresholds(state.get("cost_thresholds"))
|
||||
total_budget_warning = bool(state.get("budget_warning") or False) or is_cost_budget_warning(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
estimated_cost,
|
||||
thresholds,
|
||||
)
|
||||
|
||||
by_agent_items: list[AgentVisibilityCostByAgentOut] = []
|
||||
for agent_id, payload in dict(state.get("cost_by_agent") or {}).items():
|
||||
payload_dict = dict(payload or {})
|
||||
agent_input_tokens = int(payload_dict.get("input_tokens") or 0)
|
||||
agent_output_tokens = int(payload_dict.get("output_tokens") or 0)
|
||||
agent_estimated_cost = payload_dict.get("estimated_cost")
|
||||
if agent_estimated_cost is None:
|
||||
agent_estimated_cost = _estimate_runtime_cost(agent_input_tokens, agent_output_tokens)
|
||||
by_agent_items.append(
|
||||
AgentVisibilityCostByAgentOut(
|
||||
agent_id=str(payload_dict.get("agent_id") or agent_id),
|
||||
input_tokens=agent_input_tokens,
|
||||
output_tokens=agent_output_tokens,
|
||||
total_tokens=int(payload_dict.get("total_tokens") or (agent_input_tokens + agent_output_tokens)),
|
||||
estimated_cost=agent_estimated_cost,
|
||||
budget_warning=bool(payload_dict.get("budget_warning") or False),
|
||||
)
|
||||
)
|
||||
by_agent_items.sort(key=lambda item: item.total_tokens, reverse=True)
|
||||
|
||||
return AgentVisibilityCostSummaryOut(
|
||||
conversation_id=conversation_id,
|
||||
total=AgentVisibilityCostOut(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=input_tokens + output_tokens,
|
||||
estimated_cost=estimated_cost,
|
||||
budget_warning=total_budget_warning,
|
||||
),
|
||||
thresholds=thresholds,
|
||||
by_agent=by_agent_items,
|
||||
)
|
||||
|
||||
|
||||
def _build_tool_governance(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
conversation_id: str,
|
||||
) -> AgentVisibilityToolGovernanceOut:
|
||||
indexes = load_builtin_registry_indexes()
|
||||
tool_outcomes = [dict(item) for item in state.get("tool_outcomes") or [] if isinstance(item, dict)]
|
||||
usage_count_by_tool: dict[str, int] = {}
|
||||
last_result_preview_by_tool: dict[str, str | None] = {}
|
||||
for item in tool_outcomes:
|
||||
tool_name = str(item.get("tool_name") or "")
|
||||
if tool_name == "search_web":
|
||||
tool_name = "web_search"
|
||||
if not tool_name:
|
||||
continue
|
||||
usage_count_by_tool[tool_name] = usage_count_by_tool.get(tool_name, 0) + 1
|
||||
preview = item.get("result_preview")
|
||||
if isinstance(preview, str) and preview:
|
||||
last_result_preview_by_tool[tool_name] = preview
|
||||
|
||||
items = [
|
||||
AgentVisibilityToolGovernanceItemOut(
|
||||
capability_id=capability.capability_id,
|
||||
tool_name=capability.tool_name,
|
||||
permission_class=capability.permission_class.value,
|
||||
side_effect_scope=capability.side_effect_scope.value,
|
||||
supports_retry=capability.supports_retry,
|
||||
idempotent=capability.idempotent,
|
||||
safe_for_parallel_use=capability.safe_for_parallel_use,
|
||||
requires_confirmation=capability.requires_confirmation,
|
||||
usage_count=usage_count_by_tool.get(capability.tool_name, 0),
|
||||
last_result_preview=last_result_preview_by_tool.get(capability.tool_name),
|
||||
)
|
||||
for capability in indexes.capability_by_id.values()
|
||||
]
|
||||
items.sort(key=lambda item: (-item.usage_count, item.tool_name))
|
||||
|
||||
return AgentVisibilityToolGovernanceOut(
|
||||
conversation_id=conversation_id,
|
||||
total_tools=len(items),
|
||||
used_tools=sum(1 for item in items if item.usage_count > 0),
|
||||
items=items,
|
||||
upgrade_candidates=[
|
||||
"worktree_manager",
|
||||
"cost_inspector",
|
||||
"runtime_event_drilldown",
|
||||
"tool_policy_explorer",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _build_runtime_summary(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
conversation_id: str,
|
||||
) -> AgentVisibilityRuntimeSummaryOut:
|
||||
tasks = [dict(item) for item in state.get("active_tasks") or []]
|
||||
task_results = [dict(item) for item in state.get("task_results") or []]
|
||||
topology_nodes = _build_topology_nodes(state, tasks, task_results)
|
||||
cost_summary = _build_cost_summary(state, conversation_id=conversation_id)
|
||||
input_tokens = cost_summary.total.input_tokens
|
||||
output_tokens = cost_summary.total.output_tokens
|
||||
recent_events_raw = [dict(item) for item in (state.get("event_trace") or [])[-10:]]
|
||||
isolation_mode = str(state.get("isolation_mode") or "none")
|
||||
|
||||
return AgentVisibilityRuntimeSummaryOut(
|
||||
conversation_id=conversation_id,
|
||||
execution_mode=state.get("execution_mode"),
|
||||
current_phase=state.get("current_phase"),
|
||||
current_checkpoint=state.get("current_checkpoint"),
|
||||
phase_history=list(state.get("phase_history") or []),
|
||||
checkpoint_history=list(state.get("checkpoint_history") or []),
|
||||
verifier=AgentVisibilityVerifierOut(
|
||||
conversation_id=conversation_id,
|
||||
status=state.get("verification_status"),
|
||||
summary=state.get("verification_summary"),
|
||||
evidence=list(state.get("verification_evidence") or []),
|
||||
),
|
||||
isolation=AgentVisibilityIsolationOut(
|
||||
mode=isolation_mode,
|
||||
isolation_id=state.get("isolation_id"),
|
||||
workspace_path=state.get("isolation_workspace_path"),
|
||||
parent_conversation_id=state.get("isolation_parent_conversation_id") or state.get("parent_conversation_id"),
|
||||
metadata=dict(state.get("isolation_metadata") or {}),
|
||||
),
|
||||
cost=cost_summary.total,
|
||||
topology_node_count=len(topology_nodes),
|
||||
active_task_count=len(tasks),
|
||||
completed_task_count=sum(1 for item in task_results if item.get("status") == "completed"),
|
||||
recent_events=[_coerce_event_payload(item) for item in recent_events_raw],
|
||||
)
|
||||
|
||||
|
||||
def record_agent_call(agent_id: str):
|
||||
@@ -83,6 +402,7 @@ async def get_agent_hierarchy_stats(
|
||||
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
|
||||
async def get_agent_config(
|
||||
agent_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Agent).where(Agent.role == agent_id))
|
||||
@@ -172,12 +492,189 @@ async def update_agent_config(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/visibility/events", response_model=AgentVisibilityEventsResponse)
|
||||
async def get_visibility_events(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
agent_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
event_type: str | None = None,
|
||||
started_after: str | None = None,
|
||||
ended_before: str | None = None,
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
):
|
||||
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||
events = [dict(item) for item in state.get("event_trace") or []]
|
||||
filtered = _filter_events(
|
||||
events,
|
||||
agent_id=agent_id,
|
||||
thread_id=thread_id,
|
||||
event_type=event_type,
|
||||
started_after=_parse_visibility_datetime(started_after),
|
||||
ended_before=_parse_visibility_datetime(ended_before),
|
||||
)
|
||||
paged = filtered[offset:offset + limit]
|
||||
return AgentVisibilityEventsResponse(
|
||||
conversation_id=conversation_id,
|
||||
total=len(filtered),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
items=[_coerce_event_payload(item) for item in paged],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/visibility/topology", response_model=AgentVisibilityTopologyOut)
|
||||
async def get_visibility_topology(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||
tasks = [dict(item) for item in state.get("active_tasks") or []]
|
||||
task_results = [dict(item) for item in state.get("task_results") or []]
|
||||
nodes = _build_topology_nodes(state, tasks, task_results)
|
||||
root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None
|
||||
edges = [
|
||||
{"parent_agent_id": root_agent_id, "child_agent_id": node.agent_id}
|
||||
for node in nodes
|
||||
if node.parent_agent_id and root_agent_id and node.agent_id != root_agent_id
|
||||
]
|
||||
return AgentVisibilityTopologyOut(
|
||||
conversation_id=conversation_id,
|
||||
root_agent_id=root_agent_id,
|
||||
current_agent=str(state.get("current_agent") or "") or None,
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
tasks=_summarize_tasks(tasks, task_results),
|
||||
task_hierarchy=dict(state.get("task_hierarchy") or {}),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/visibility/tasks/{task_id}/evidence", response_model=AgentVisibilityEvidenceOut)
|
||||
async def get_visibility_task_evidence(
|
||||
task_id: str,
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||
tasks = [dict(item) for item in state.get("active_tasks") or []]
|
||||
task = next((item for item in tasks if item.get("task_id") == task_id), None)
|
||||
task_results = [dict(item) for item in state.get("task_results") or []]
|
||||
result = next((item for item in task_results if item.get("task_id") == task_id), None)
|
||||
if task is None and result is None:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
tool_outcomes = [
|
||||
dict(evidence)
|
||||
for evidence in (result or {}).get("evidence") or []
|
||||
if isinstance(evidence, dict) and evidence.get("tool_name")
|
||||
]
|
||||
verification_entry = next(
|
||||
(
|
||||
dict(evidence)
|
||||
for evidence in (result or {}).get("evidence") or []
|
||||
if isinstance(evidence, dict) and evidence.get("type") == "verification"
|
||||
),
|
||||
None,
|
||||
)
|
||||
verifier = {
|
||||
"status": (verification_entry or {}).get("status"),
|
||||
"summary": (verification_entry or {}).get("summary"),
|
||||
"evidence": [dict(item) for item in state.get("verification_evidence") or [] if item.get("task_id") == task_id],
|
||||
}
|
||||
return AgentVisibilityEvidenceOut(
|
||||
conversation_id=conversation_id,
|
||||
task_id=task_id,
|
||||
task=task,
|
||||
result=result,
|
||||
tool_outcomes=tool_outcomes,
|
||||
verifier=verifier,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/visibility/threads/{thread_id}/messages", response_model=AgentVisibilityThreadOut)
|
||||
async def get_visibility_thread_messages(
|
||||
thread_id: str,
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||
items = [
|
||||
AgentVisibilityThreadMessageOut.model_validate(item)
|
||||
for item in state.get("message_trace") or []
|
||||
if isinstance(item, dict) and item.get("thread_id") == thread_id
|
||||
]
|
||||
if not items:
|
||||
raise HTTPException(status_code=404, detail="线程不存在")
|
||||
return AgentVisibilityThreadOut(
|
||||
conversation_id=conversation_id,
|
||||
thread_id=thread_id,
|
||||
total=len(items),
|
||||
items=items,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/visibility/verifier", response_model=AgentVisibilityVerifierOut)
|
||||
async def get_visibility_verifier(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||
return AgentVisibilityVerifierOut(
|
||||
conversation_id=conversation_id,
|
||||
status=state.get("verification_status"),
|
||||
summary=state.get("verification_summary"),
|
||||
evidence=list(state.get("verification_evidence") or []),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/visibility/runtime-summary", response_model=AgentVisibilityRuntimeSummaryOut)
|
||||
async def get_visibility_runtime_summary(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||
return _build_runtime_summary(state, conversation_id=conversation_id)
|
||||
|
||||
|
||||
@router.get("/visibility/cost", response_model=AgentVisibilityCostSummaryOut)
|
||||
async def get_visibility_cost(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||
return _build_cost_summary(state, conversation_id=conversation_id)
|
||||
|
||||
|
||||
@router.get("/visibility/tools", response_model=AgentVisibilityToolGovernanceOut)
|
||||
async def get_visibility_tools(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||
return _build_tool_governance(state, conversation_id=conversation_id)
|
||||
|
||||
|
||||
@router.post("", response_model=AgentOut, status_code=201)
|
||||
async def create_agent(
|
||||
data: AgentCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(status_code=403, detail="仅管理员可创建 Agent")
|
||||
if not data.spawn_permission:
|
||||
raise HTTPException(status_code=400, detail="缺少 spawn_permission,禁止直接创建 runtime agent")
|
||||
if data.role not in ALLOWED_AGENT_ROLES:
|
||||
raise HTTPException(status_code=400, detail="不支持的 Agent 角色")
|
||||
|
||||
agent = Agent(
|
||||
name=data.name,
|
||||
role=data.role,
|
||||
@@ -193,6 +690,7 @@ async def create_agent(
|
||||
@router.get("/{agent_id}", response_model=AgentOut)
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Agent).where(Agent.id == agent_id))
|
||||
|
||||
113
backend/app/routers/agent_sessions.py
Normal file
113
backend/app/routers/agent_sessions.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Agent Session API 路由 - Phase 10.3"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
|
||||
|
||||
router = APIRouter(prefix="/api/agent/sessions", tags=["Agent Sessions"])
|
||||
|
||||
|
||||
@router.post("", response_model=dict[str, Any])
|
||||
async def create_session(
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""创建新会话"""
|
||||
session = create_agent_session(
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
return await session.initialize()
|
||||
|
||||
|
||||
@router.get("/{session_id}", response_model=dict[str, Any])
|
||||
async def get_session(session_id: str) -> dict[str, Any]:
|
||||
"""获取会话信息"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
return await session.get_session_summary()
|
||||
|
||||
|
||||
@router.post("/{session_id}/message", response_model=dict[str, str])
|
||||
async def process_message(
|
||||
session_id: str,
|
||||
message: str,
|
||||
response: str,
|
||||
) -> dict[str, str]:
|
||||
"""处理消息"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
await session.process_message(message, response)
|
||||
return {"status": "recorded", "session_id": session_id}
|
||||
|
||||
|
||||
@router.post("/{session_id}/spawn", response_model=dict[str, Any])
|
||||
async def spawn_child_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""创建子会话"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
child = await session.spawn_child_session(user_id=user_id)
|
||||
return await child.get_session_summary()
|
||||
|
||||
|
||||
@router.get("/{session_id}/history", response_model=dict[str, Any])
|
||||
async def get_session_history(session_id: str) -> dict[str, Any]:
|
||||
"""获取会话历史"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"history": session.get_history(),
|
||||
"count": len(session._history),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{session_id}/persist", response_model=dict[str, str])
|
||||
async def persist_session(session_id: str) -> dict[str, str]:
|
||||
"""持久化会话"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
success = await session.persist()
|
||||
if success:
|
||||
return {"status": "persisted", "session_id": session_id}
|
||||
raise HTTPException(status_code=500, detail="Failed to persist session")
|
||||
|
||||
|
||||
@router.post("/{session_id}/metadata", response_model=dict[str, Any])
|
||||
async def set_session_metadata(
|
||||
session_id: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""设置会话元数据"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
session.add_metadata(key, value)
|
||||
await session.persist()
|
||||
return {"key": key, "value": value}
|
||||
|
||||
|
||||
@router.get("/{session_id}/metadata/{key}", response_model=dict[str, Any])
|
||||
async def get_session_metadata(
|
||||
session_id: str,
|
||||
key: str,
|
||||
) -> dict[str, Any]:
|
||||
"""获取会话元数据"""
|
||||
session = get_agent_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||
value = session.get_metadata(key)
|
||||
if value is None:
|
||||
raise HTTPException(status_code=404, detail=f"Metadata key '{key}' not found")
|
||||
return {"key": key, "value": value}
|
||||
126
backend/app/routers/agent_skills.py
Normal file
126
backend/app/routers/agent_skills.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Agent Skills API 路由 - Phase 9.6
|
||||
|
||||
使用新的 SkillRegistry (file-based) 而不是 DB-based skill 系统。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.agents.skills.registry import get_skill_registry, SkillRegistry
|
||||
|
||||
router = APIRouter(prefix="/api/agent/skills", tags=["Agent Skills"])
|
||||
|
||||
|
||||
def _skill_to_dict(skill) -> dict[str, Any]:
|
||||
"""将 SkillMetadata 转换为字典"""
|
||||
return {
|
||||
"name": skill.name,
|
||||
"description": skill.description,
|
||||
"tags": skill.tags,
|
||||
"triggers": skill.triggers,
|
||||
"enabled": skill.enabled,
|
||||
"content_preview": skill.content[:200] + "..."
|
||||
if len(skill.content) > 200
|
||||
else skill.content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=dict[str, Any])
|
||||
async def list_agent_skills() -> dict[str, Any]:
|
||||
"""列出所有已加载的 Agent Skills"""
|
||||
registry = get_skill_registry()
|
||||
skills = registry.list_all()
|
||||
return {
|
||||
"skills": [_skill_to_dict(s) for s in skills],
|
||||
"count": len(skills),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/search", response_model=dict[str, Any])
|
||||
async def search_agent_skills(
|
||||
query: str,
|
||||
) -> dict[str, Any]:
|
||||
"""搜索 Skills"""
|
||||
registry = get_skill_registry()
|
||||
results = registry.search(query)
|
||||
return {
|
||||
"skills": [_skill_to_dict(s) for s in results],
|
||||
"count": len(results),
|
||||
"query": query,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{skill_name}", response_model=dict[str, Any])
|
||||
async def get_agent_skill(skill_name: str) -> dict[str, Any]:
|
||||
"""获取指定 Skill 详情"""
|
||||
registry = get_skill_registry()
|
||||
skill = registry.get_skill(skill_name)
|
||||
if not skill:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
return {
|
||||
"name": skill.name,
|
||||
"description": skill.description,
|
||||
"tags": skill.tags,
|
||||
"triggers": skill.triggers,
|
||||
"enabled": skill.enabled,
|
||||
"content": skill.content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{skill_name}/context", response_model=dict[str, str])
|
||||
async def get_skill_context(skill_name: str) -> dict[str, str]:
|
||||
"""获取 Skill 上下文字符串"""
|
||||
registry = get_skill_registry()
|
||||
context = registry.get_skill_context([skill_name])
|
||||
if not context:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Skill '{skill_name}' not found or not enabled"
|
||||
)
|
||||
return {"skill_name": skill_name, "context": context}
|
||||
|
||||
|
||||
@router.post("/context/batch", response_model=dict[str, str])
|
||||
async def get_batch_skill_context(
|
||||
skill_names: list[str],
|
||||
) -> dict[str, str]:
|
||||
"""批量获取多个 Skill 的上下文"""
|
||||
registry = get_skill_registry()
|
||||
context = registry.get_skill_context(skill_names)
|
||||
return {"skills": skill_names, "context": context}
|
||||
|
||||
|
||||
@router.post("/reload", response_model=dict[str, Any])
|
||||
async def reload_skills(
|
||||
skills_dir: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""重新加载所有 Skills"""
|
||||
registry = get_skill_registry()
|
||||
# 清除旧 skills
|
||||
for name in list(registry._skills.keys()):
|
||||
registry.unregister(name)
|
||||
# 重新加载
|
||||
count = registry.load_all(skills_dir)
|
||||
return {"loaded": count, "message": f"Loaded {count} skills"}
|
||||
|
||||
|
||||
@router.post("/{skill_name}/enable", response_model=dict[str, str])
|
||||
async def enable_skill(skill_name: str) -> dict[str, str]:
|
||||
"""启用 Skill"""
|
||||
registry = get_skill_registry()
|
||||
skill = registry.get_skill(skill_name)
|
||||
if not skill:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
skill.enabled = True
|
||||
return {"status": "enabled", "skill_name": skill_name}
|
||||
|
||||
|
||||
@router.post("/{skill_name}/disable", response_model=dict[str, str])
|
||||
async def disable_skill(skill_name: str) -> dict[str, str]:
|
||||
"""禁用 Skill"""
|
||||
registry = get_skill_registry()
|
||||
skill = registry.get_skill(skill_name)
|
||||
if not skill:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
skill.enabled = False
|
||||
return {"status": "disabled", "skill_name": skill_name}
|
||||
241
backend/app/routers/hooks.py
Normal file
241
backend/app/routers/hooks.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Hook API 路由 - Phase 7.5"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agents.tools.hooks import HookType
|
||||
from app.agents.tools.hooks.builtins import (
|
||||
AuditLogHook,
|
||||
DangerousConfirmationHook,
|
||||
SecurityScanHook,
|
||||
)
|
||||
from app.agents.tools.hooks.config import (
|
||||
HookConfigEntry,
|
||||
get_hook_config_persistence,
|
||||
)
|
||||
from app.agents.tools.hooks.manager import get_hook_manager
|
||||
|
||||
router = APIRouter(prefix="/api/hooks", tags=["Hooks"])
|
||||
|
||||
|
||||
class HookInfo(BaseModel):
|
||||
"""Hook 信息"""
|
||||
|
||||
name: str
|
||||
hook_type: str
|
||||
description: str
|
||||
builtin: bool
|
||||
|
||||
|
||||
class HookConfigUpdate(BaseModel):
|
||||
"""更新 Hook 配置"""
|
||||
|
||||
entries: list[HookConfigEntry]
|
||||
|
||||
|
||||
class HookConfigResponse(BaseModel):
|
||||
"""Hook 配置响应"""
|
||||
|
||||
entries: list[dict[str, Any]]
|
||||
count: int
|
||||
|
||||
|
||||
class HookStatusResponse(BaseModel):
|
||||
"""Hook 状态响应"""
|
||||
|
||||
name: str
|
||||
enabled: bool
|
||||
hook_type: str
|
||||
registered: bool
|
||||
|
||||
|
||||
# 内置 Hook 注册表
|
||||
BUILTIN_HOOKS: dict[str, dict[str, str]] = {
|
||||
"audit_log": {
|
||||
"name": "audit_log",
|
||||
"hook_type": "pre_tool_use,post_tool_use,tool_error",
|
||||
"description": "审计日志 Hook - 记录所有工具调用",
|
||||
"class": "AuditLogHook",
|
||||
},
|
||||
"dangerous_confirmation": {
|
||||
"name": "dangerous_confirmation",
|
||||
"hook_type": "pre_tool_use",
|
||||
"description": "危险操作确认 Hook - 拦截危险工具调用",
|
||||
"class": "DangerousConfirmationHook",
|
||||
},
|
||||
"security_scan": {
|
||||
"name": "security_scan",
|
||||
"hook_type": "post_tool_use",
|
||||
"description": "安全扫描 Hook - 检测敏感信息泄露",
|
||||
"class": "SecurityScanHook",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/available", response_model=list[HookInfo])
|
||||
async def list_available_hooks() -> list[HookInfo]:
|
||||
"""列出所有可用的内置 Hook"""
|
||||
return [
|
||||
HookInfo(
|
||||
name=info["name"],
|
||||
hook_type=info["hook_type"],
|
||||
description=info["description"],
|
||||
builtin=True,
|
||||
)
|
||||
for info in BUILTIN_HOOKS.values()
|
||||
]
|
||||
|
||||
|
||||
@router.get("/config", response_model=HookConfigResponse)
|
||||
async def get_hook_config() -> HookConfigResponse:
|
||||
"""获取当前 Hook 配置"""
|
||||
persistence = get_hook_config_persistence()
|
||||
entries = persistence.load_config()
|
||||
return HookConfigResponse(
|
||||
entries=[vars(e) if isinstance(e, HookConfigEntry) else e for e in entries],
|
||||
count=len(entries),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config", response_model=HookConfigResponse)
|
||||
async def update_hook_config(
|
||||
entries: list[HookConfigEntry],
|
||||
) -> HookConfigResponse:
|
||||
"""更新 Hook 配置"""
|
||||
persistence = get_hook_config_persistence()
|
||||
success = persistence.save_config(entries)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to save hook config")
|
||||
|
||||
# 应用配置到 HookManager
|
||||
manager = get_hook_manager()
|
||||
manager.clear() # 清除旧配置
|
||||
persistence.apply_config() # 应用新配置
|
||||
|
||||
return HookConfigResponse(
|
||||
entries=[vars(e) if isinstance(e, HookConfigEntry) else e for e in entries],
|
||||
count=len(entries),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/apply-config", response_model=dict[str, Any])
|
||||
async def apply_hook_config() -> dict[str, Any]:
|
||||
"""应用配置文件到 HookManager"""
|
||||
persistence = get_hook_config_persistence()
|
||||
manager = get_hook_manager()
|
||||
manager.clear()
|
||||
count = persistence.apply_config()
|
||||
return {"applied": count, "message": f"Applied {count} hook configurations"}
|
||||
|
||||
|
||||
@router.get("/status", response_model=list[HookStatusResponse])
|
||||
async def get_hook_status() -> list[HookStatusResponse]:
|
||||
"""获取所有已注册 Hook 的状态"""
|
||||
manager = get_hook_manager()
|
||||
all_hooks = manager.list_all()
|
||||
|
||||
# 按名称索引已注册的 hooks
|
||||
registered: dict[str, dict[str, Any]] = {}
|
||||
for hook in all_hooks:
|
||||
registered[hook.name] = {
|
||||
"name": hook.name,
|
||||
"enabled": hook.enabled,
|
||||
"hook_type": hook.hook_type.value,
|
||||
"registered": True,
|
||||
}
|
||||
|
||||
# 合并内置 Hook 信息
|
||||
result: list[HookStatusResponse] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
# 先添加已注册的
|
||||
for hook in all_hooks:
|
||||
result.append(
|
||||
HookStatusResponse(
|
||||
name=hook.name,
|
||||
enabled=hook.enabled,
|
||||
hook_type=hook.hook_type.value,
|
||||
registered=True,
|
||||
)
|
||||
)
|
||||
seen.add(hook.name)
|
||||
|
||||
# 再添加内置但未注册的
|
||||
for name, info in BUILTIN_HOOKS.items():
|
||||
if name not in seen:
|
||||
result.append(
|
||||
HookStatusResponse(
|
||||
name=name,
|
||||
enabled=False,
|
||||
hook_type=info["hook_type"],
|
||||
registered=False,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/{name}/enable", response_model=dict[str, str])
|
||||
async def enable_hook(name: str) -> dict[str, str]:
|
||||
"""启用指定 Hook"""
|
||||
manager = get_hook_manager()
|
||||
if manager.enable(name):
|
||||
return {"status": "enabled", "name": name}
|
||||
raise HTTPException(status_code=404, detail=f"Hook '{name}' not found")
|
||||
|
||||
|
||||
@router.post("/{name}/disable", response_model=dict[str, str])
|
||||
async def disable_hook(name: str) -> dict[str, str]:
|
||||
"""禁用指定 Hook"""
|
||||
manager = get_hook_manager()
|
||||
if manager.disable(name):
|
||||
return {"status": "disabled", "name": name}
|
||||
raise HTTPException(status_code=404, detail=f"Hook '{name}' not found")
|
||||
|
||||
|
||||
@router.post("/register-builtin", response_model=dict[str, str])
|
||||
async def register_builtin_hook(
|
||||
name: str,
|
||||
hook_type: str = "pre_tool_use",
|
||||
) -> dict[str, str]:
|
||||
"""注册内置 Hook 到 HookManager"""
|
||||
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
|
||||
|
||||
manager = get_hook_manager()
|
||||
|
||||
if name == "audit_log":
|
||||
hook_instance = AuditLogHook()
|
||||
handler = hook_instance.pre_tool_use
|
||||
hook_types = [HookType.PRE_TOOL_USE, HookType.POST_TOOL_USE, HookType.TOOL_ERROR]
|
||||
elif name == "dangerous_confirmation":
|
||||
hook_instance = DangerousConfirmationHook()
|
||||
handler = hook_instance.pre_tool_use
|
||||
hook_types = [HookType.PRE_TOOL_USE]
|
||||
elif name == "security_scan":
|
||||
hook_instance = SecurityScanHook()
|
||||
handler = hook_instance.post_tool_use
|
||||
hook_types = [HookType.POST_TOOL_USE]
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown builtin hook: {name}")
|
||||
|
||||
registered = []
|
||||
for ht in hook_types:
|
||||
hook_def = HookDefinition(
|
||||
name=f"{name}_{ht.value}",
|
||||
hook_type=ht,
|
||||
trigger=HookTrigger(),
|
||||
handler=handler,
|
||||
priority=0,
|
||||
enabled=True,
|
||||
description=f"Built-in {name} hook",
|
||||
)
|
||||
manager.register(hook_def)
|
||||
registered.append(ht.value)
|
||||
|
||||
return {
|
||||
"status": "registered",
|
||||
"name": name,
|
||||
"hook_types": ", ".join(registered),
|
||||
}
|
||||
222
backend/app/routers/plugins.py
Normal file
222
backend/app/routers/plugins.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Plugin API 路由 - Phase 8.6"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agents.plugins import get_plugin_manager, PluginManifest
|
||||
|
||||
router = APIRouter(prefix="/api/plugins", tags=["Plugins"])
|
||||
|
||||
|
||||
class PluginInfo(BaseModel):
|
||||
"""插件信息"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
version: str
|
||||
description: str
|
||||
author: str
|
||||
enabled: bool
|
||||
main: str
|
||||
|
||||
|
||||
class PluginInstallRequest(BaseModel):
|
||||
"""插件安装请求"""
|
||||
|
||||
plugin_path: str
|
||||
|
||||
|
||||
class PluginListResponse(BaseModel):
|
||||
"""插件列表响应"""
|
||||
|
||||
plugins: list[dict[str, Any]]
|
||||
count: int
|
||||
|
||||
|
||||
# 全局插件市场(简单内存实现)
|
||||
_plugin_marketplace: list[dict[str, str]] = []
|
||||
|
||||
|
||||
def _manifest_to_dict(manifest: PluginManifest, enabled: bool) -> dict[str, Any]:
|
||||
"""将 PluginManifest 转换为字典"""
|
||||
return {
|
||||
"id": manifest.id,
|
||||
"name": manifest.name,
|
||||
"version": manifest.version,
|
||||
"description": manifest.description,
|
||||
"author": manifest.author,
|
||||
"enabled": enabled,
|
||||
"main": manifest.main,
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=PluginListResponse)
|
||||
async def list_plugins() -> PluginListResponse:
|
||||
"""列出所有已安装的插件"""
|
||||
manager = get_plugin_manager()
|
||||
plugins = manager.list_plugins()
|
||||
result = []
|
||||
for p in plugins:
|
||||
enabled = manager.is_enabled(p.id)
|
||||
result.append(_manifest_to_dict(p, enabled))
|
||||
return PluginListResponse(plugins=result, count=len(result))
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", response_model=dict[str, Any])
|
||||
async def get_plugin(plugin_id: str) -> dict[str, Any]:
|
||||
"""获取指定插件信息"""
|
||||
manager = get_plugin_manager()
|
||||
manifest = manager.get_plugin(plugin_id)
|
||||
if not manifest:
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
enabled = manager.is_enabled(plugin_id)
|
||||
return _manifest_to_dict(manifest, enabled)
|
||||
|
||||
|
||||
@router.post("/install", response_model=dict[str, str])
|
||||
async def install_plugin(request: PluginInstallRequest) -> dict[str, str]:
|
||||
"""安装插件"""
|
||||
manager = get_plugin_manager()
|
||||
if not os.path.exists(request.plugin_path):
|
||||
raise HTTPException(status_code=400, detail="Plugin path does not exist")
|
||||
|
||||
if manager.install(request.plugin_path):
|
||||
return {"status": "installed", "path": request.plugin_path}
|
||||
raise HTTPException(status_code=500, detail="Failed to install plugin")
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/enable", response_model=dict[str, str])
|
||||
async def enable_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""启用插件"""
|
||||
manager = get_plugin_manager()
|
||||
if manager.enable(plugin_id):
|
||||
return {"status": "enabled", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/disable", response_model=dict[str, str])
|
||||
async def disable_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""禁用插件"""
|
||||
manager = get_plugin_manager()
|
||||
if manager.disable(plugin_id):
|
||||
return {"status": "disabled", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}", response_model=dict[str, str])
|
||||
async def uninstall_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""卸载插件"""
|
||||
manager = get_plugin_manager()
|
||||
if manager.uninstall(plugin_id):
|
||||
return {"status": "uninstalled", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/reload", response_model=dict[str, str])
|
||||
async def reload_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""重新加载插件"""
|
||||
manager = get_plugin_manager()
|
||||
if manager.reload(plugin_id):
|
||||
return {"status": "reloaded", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
|
||||
# === Plugin Marketplace ===
|
||||
|
||||
_marketplace_router = APIRouter(prefix="/api/marketplace", tags=["Plugin Marketplace"])
|
||||
|
||||
|
||||
@_marketplace_router.get("/plugins", response_model=dict[str, Any])
|
||||
async def search_marketplace_plugins(
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""搜索插件市场"""
|
||||
results = _plugin_marketplace
|
||||
if query:
|
||||
results = [
|
||||
p
|
||||
for p in results
|
||||
if query.lower() in p.get("name", "").lower()
|
||||
or query.lower() in p.get("description", "").lower()
|
||||
]
|
||||
if category:
|
||||
results = [p for p in results if p.get("category") == category]
|
||||
return {"plugins": results, "count": len(results)}
|
||||
|
||||
|
||||
@_marketplace_router.get("/plugins/{plugin_id}", response_model=dict[str, Any])
|
||||
async def get_marketplace_plugin(plugin_id: str) -> dict[str, Any]:
|
||||
"""获取市场中的插件详情"""
|
||||
for plugin in _plugin_marketplace:
|
||||
if plugin.get("id") == plugin_id:
|
||||
return plugin
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found in marketplace")
|
||||
|
||||
|
||||
@_marketplace_router.post("/plugins", response_model=dict[str, str])
|
||||
async def add_to_marketplace(plugin: dict[str, str]) -> dict[str, str]:
|
||||
"""添加插件到市场(仅供测试/开发)"""
|
||||
if "id" not in plugin or "name" not in plugin:
|
||||
raise HTTPException(status_code=400, detail="Plugin must have id and name")
|
||||
# 移除已存在的同 ID 插件
|
||||
global _plugin_marketplace
|
||||
_plugin_marketplace = [p for p in _plugin_marketplace if p.get("id") != plugin["id"]]
|
||||
_plugin_marketplace.append(plugin)
|
||||
return {"status": "added", "id": plugin["id"]}
|
||||
|
||||
|
||||
@_marketplace_router.post("/plugins/{plugin_id}/download", response_model=dict[str, str])
|
||||
async def download_plugin(plugin_id: str) -> dict[str, str]:
|
||||
"""从市场下载并安装插件"""
|
||||
# Find plugin in marketplace
|
||||
plugin = None
|
||||
for p in _plugin_marketplace:
|
||||
if p.get("id") == plugin_id:
|
||||
plugin = p
|
||||
break
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Plugin '{plugin_id}' not found in marketplace"
|
||||
)
|
||||
|
||||
download_url = plugin.get("download_url")
|
||||
if not download_url:
|
||||
raise HTTPException(status_code=400, detail="Plugin has no download URL")
|
||||
|
||||
try:
|
||||
# Download the plugin archive
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url, timeout=60.0)
|
||||
response.raise_for_status()
|
||||
archive_content = response.content
|
||||
|
||||
# Extract to temp directory and install
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
archive_path = os.path.join(temp_dir, "plugin.zip")
|
||||
with open(archive_path, "wb") as f:
|
||||
f.write(archive_content)
|
||||
|
||||
extract_dir = os.path.join(temp_dir, "extracted")
|
||||
with zipfile.ZipFile(archive_path, "r") as zf:
|
||||
zf.extractall(extract_dir)
|
||||
|
||||
# Install the plugin
|
||||
manager = get_plugin_manager()
|
||||
if manager.install(extract_dir):
|
||||
return {"status": "installed", "plugin_id": plugin_id}
|
||||
raise HTTPException(status_code=500, detail="Failed to install plugin")
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"Download failed: {str(e)}")
|
||||
except zipfile.BadZipFile:
|
||||
raise HTTPException(status_code=502, detail="Invalid plugin archive")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Installation failed: {str(e)}")
|
||||
@@ -1,4 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentCreate(BaseModel):
|
||||
@@ -6,6 +9,7 @@ class AgentCreate(BaseModel):
|
||||
role: str
|
||||
description: str | None = None
|
||||
system_prompt: str
|
||||
spawn_permission: bool = False
|
||||
|
||||
|
||||
class AgentOut(BaseModel):
|
||||
@@ -55,3 +59,163 @@ class AgentConfigOut(BaseModel):
|
||||
selected_skill_ids: list[str]
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AgentVisibilityEventOut(BaseModel):
|
||||
event_id: str
|
||||
event_type: str
|
||||
timestamp: datetime
|
||||
conversation_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
sub_commander_id: str | None = None
|
||||
task_id: str | None = None
|
||||
parent_task_id: str | None = None
|
||||
child_task_id: str | None = None
|
||||
thread_id: str | None = None
|
||||
message_id: str | None = None
|
||||
interrupt_id: str | None = None
|
||||
recovery_id: str | None = None
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
severity: str = "info"
|
||||
|
||||
|
||||
class AgentVisibilityEventsResponse(BaseModel):
|
||||
conversation_id: str
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
items: list[AgentVisibilityEventOut]
|
||||
|
||||
|
||||
class AgentVisibilityTaskSummaryOut(BaseModel):
|
||||
task_id: str
|
||||
role: str | None = None
|
||||
owner_agent_id: str | None = None
|
||||
status: str | None = None
|
||||
summary: str | None = None
|
||||
evidence_count: int = 0
|
||||
|
||||
|
||||
class AgentVisibilityTopologyNodeOut(BaseModel):
|
||||
agent_id: str
|
||||
role: str | None = None
|
||||
parent_agent_id: str | None = None
|
||||
source: str
|
||||
task_count: int = 0
|
||||
completed_task_count: int = 0
|
||||
|
||||
|
||||
class AgentVisibilityTopologyOut(BaseModel):
|
||||
conversation_id: str
|
||||
root_agent_id: str | None = None
|
||||
current_agent: str | None = None
|
||||
nodes: list[AgentVisibilityTopologyNodeOut]
|
||||
edges: list[dict[str, str]]
|
||||
tasks: list[AgentVisibilityTaskSummaryOut]
|
||||
task_hierarchy: dict[str, list[str]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentVisibilityEvidenceOut(BaseModel):
|
||||
conversation_id: str
|
||||
task_id: str
|
||||
task: dict[str, Any] | None = None
|
||||
result: dict[str, Any] | None = None
|
||||
tool_outcomes: list[dict[str, Any]] = Field(default_factory=list)
|
||||
verifier: dict[str, Any]
|
||||
|
||||
|
||||
class AgentVisibilityThreadMessageOut(BaseModel):
|
||||
message_id: str
|
||||
thread_id: str
|
||||
from_agent_id: str
|
||||
to_agent_id: str
|
||||
task_id: str | None = None
|
||||
reply_to_message_id: str | None = None
|
||||
message_type: str
|
||||
content_summary: str
|
||||
created_at: datetime
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentVisibilityThreadOut(BaseModel):
|
||||
conversation_id: str
|
||||
thread_id: str
|
||||
total: int
|
||||
items: list[AgentVisibilityThreadMessageOut]
|
||||
|
||||
|
||||
class AgentVisibilityVerifierOut(BaseModel):
|
||||
conversation_id: str
|
||||
status: str | None = None
|
||||
summary: str | None = None
|
||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentVisibilityIsolationOut(BaseModel):
|
||||
mode: str = "none"
|
||||
isolation_id: str | None = None
|
||||
workspace_path: str | None = None
|
||||
parent_conversation_id: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentVisibilityCostOut(BaseModel):
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
estimated_cost: float | None = None
|
||||
budget_warning: bool = False
|
||||
currency: str = "USD"
|
||||
|
||||
|
||||
class AgentVisibilityCostByAgentOut(BaseModel):
|
||||
agent_id: str
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
estimated_cost: float | None = None
|
||||
budget_warning: bool = False
|
||||
|
||||
|
||||
class AgentVisibilityCostSummaryOut(BaseModel):
|
||||
conversation_id: str
|
||||
total: AgentVisibilityCostOut
|
||||
thresholds: dict[str, float] = Field(default_factory=dict)
|
||||
by_agent: list[AgentVisibilityCostByAgentOut] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentVisibilityToolGovernanceItemOut(BaseModel):
|
||||
capability_id: str
|
||||
tool_name: str
|
||||
permission_class: str
|
||||
side_effect_scope: str
|
||||
supports_retry: bool = False
|
||||
idempotent: bool = False
|
||||
safe_for_parallel_use: bool = False
|
||||
requires_confirmation: bool = False
|
||||
usage_count: int = 0
|
||||
last_result_preview: str | None = None
|
||||
|
||||
|
||||
class AgentVisibilityToolGovernanceOut(BaseModel):
|
||||
conversation_id: str
|
||||
total_tools: int = 0
|
||||
used_tools: int = 0
|
||||
items: list[AgentVisibilityToolGovernanceItemOut] = Field(default_factory=list)
|
||||
upgrade_candidates: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentVisibilityRuntimeSummaryOut(BaseModel):
|
||||
conversation_id: str
|
||||
execution_mode: str | None = None
|
||||
current_phase: str | None = None
|
||||
current_checkpoint: str | None = None
|
||||
phase_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
checkpoint_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
verifier: AgentVisibilityVerifierOut
|
||||
isolation: AgentVisibilityIsolationOut
|
||||
cost: AgentVisibilityCostOut
|
||||
topology_node_count: int = 0
|
||||
active_task_count: int = 0
|
||||
completed_task_count: int = 0
|
||||
recent_events: list[AgentVisibilityEventOut] = Field(default_factory=list)
|
||||
|
||||
@@ -21,6 +21,7 @@ from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
from app.agents.graph import get_agent_graph
|
||||
from app.agents.context import set_current_user, clear_current_user
|
||||
from app.agents.skills.registry import get_skill_registry
|
||||
from app.services import memory_service
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
|
||||
@@ -30,6 +31,56 @@ from app.agents.state import initial_state
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MEMORY_SECTION_HEADERS = (
|
||||
"【用户记忆】",
|
||||
"【之前对话摘要】",
|
||||
"【知识大脑】",
|
||||
)
|
||||
|
||||
|
||||
def _split_memory_context_sections(memory_context: str | None) -> dict[str, str]:
|
||||
text = (memory_context or "").strip()
|
||||
if not text:
|
||||
return {}
|
||||
|
||||
sections: dict[str, str] = {}
|
||||
current_header: str | None = None
|
||||
current_lines: list[str] = []
|
||||
|
||||
for line in text.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped in MEMORY_SECTION_HEADERS:
|
||||
if current_header and current_lines:
|
||||
sections[current_header] = "\n".join(current_lines).strip()
|
||||
current_header = stripped
|
||||
current_lines = [stripped]
|
||||
continue
|
||||
if current_header:
|
||||
current_lines.append(line)
|
||||
|
||||
if current_header and current_lines:
|
||||
sections[current_header] = "\n".join(current_lines).strip()
|
||||
|
||||
return sections
|
||||
|
||||
|
||||
def _derive_role_memory_contexts(memory_context: str | None) -> dict[str, str | None]:
|
||||
sections = _split_memory_context_sections(memory_context)
|
||||
user_memory = sections.get("【用户记忆】")
|
||||
summaries = sections.get("【之前对话摘要】")
|
||||
knowledge = sections.get("【知识大脑】")
|
||||
|
||||
def _join_parts(*parts: str | None) -> str | None:
|
||||
values = [part for part in parts if part]
|
||||
return "\n\n".join(values) if values else None
|
||||
|
||||
return {
|
||||
"schedule_context_summary": _join_parts(user_memory, summaries),
|
||||
"knowledge_context": knowledge,
|
||||
"analysis_report": _join_parts(summaries, knowledge),
|
||||
}
|
||||
|
||||
|
||||
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
|
||||
capabilities = resolve_provider_capabilities(user_llm_config)
|
||||
error_text = str(error).lower()
|
||||
@@ -45,9 +96,8 @@ def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None
|
||||
]
|
||||
|
||||
if isinstance(error, BadRequestError):
|
||||
return (
|
||||
getattr(capabilities, "provider", None) not in {"openai", "claude"}
|
||||
and any(marker in error_text for marker in markers)
|
||||
return getattr(capabilities, "provider", None) not in {"openai", "claude"} and any(
|
||||
marker in error_text for marker in markers
|
||||
)
|
||||
|
||||
return any(marker in error_text for marker in markers)
|
||||
@@ -84,14 +134,165 @@ _CONTINUITY_SNAPSHOT_FIELDS = (
|
||||
"current_agent",
|
||||
"next_step",
|
||||
"agent_trace",
|
||||
"agent_id",
|
||||
"parent_agent_id",
|
||||
"root_agent_id",
|
||||
"collaboration_depth",
|
||||
"thread_id",
|
||||
"last_message_id",
|
||||
"message_sequence",
|
||||
"spawned_agent_ids",
|
||||
"current_sub_commander",
|
||||
"active_sub_commanders",
|
||||
"sub_commander_trace",
|
||||
"event_trace",
|
||||
"message_trace",
|
||||
"active_tasks",
|
||||
"task_results",
|
||||
"task_hierarchy",
|
||||
"verification_status",
|
||||
"verification_summary",
|
||||
"verification_evidence",
|
||||
"isolation_mode",
|
||||
"isolation_id",
|
||||
"isolation_workspace_path",
|
||||
"isolation_parent_conversation_id",
|
||||
"isolation_metadata",
|
||||
"input_tokens",
|
||||
"output_tokens",
|
||||
"estimated_cost",
|
||||
"budget_warning",
|
||||
"cost_by_agent",
|
||||
"cost_thresholds",
|
||||
"budget_state",
|
||||
"collaboration_budget_history",
|
||||
"current_phase",
|
||||
"phase_history",
|
||||
"current_checkpoint",
|
||||
"checkpoint_history",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_legacy_turn_context(turn_context: Any, current_agent: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(turn_context, dict):
|
||||
return None
|
||||
normalized = dict(turn_context)
|
||||
active_agent = normalized.pop("active_agent", None)
|
||||
active_sub_flow = normalized.pop("active_sub_flow", None)
|
||||
if isinstance(active_agent, str) and active_agent and "active_agent" not in normalized:
|
||||
normalized["active_agent"] = active_agent
|
||||
if (
|
||||
isinstance(active_sub_flow, str)
|
||||
and active_sub_flow
|
||||
and "active_sub_commander" not in normalized
|
||||
):
|
||||
normalized["active_sub_commander"] = active_sub_flow
|
||||
if not normalized.get("active_agent") and isinstance(current_agent, str) and current_agent:
|
||||
normalized["active_agent"] = current_agent
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_pending_action(pending_action: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(pending_action, dict):
|
||||
return None
|
||||
normalized = dict(pending_action)
|
||||
legacy_action_type = normalized.pop("action_type", None)
|
||||
if legacy_action_type and "type" not in normalized:
|
||||
normalized["type"] = legacy_action_type
|
||||
legacy_agent = normalized.pop("agent", None)
|
||||
legacy_sub_flow = normalized.pop("sub_flow", None)
|
||||
if legacy_agent and "owner_agent" not in normalized:
|
||||
normalized["owner_agent"] = legacy_agent
|
||||
if legacy_sub_flow and "owner_sub_commander" not in normalized:
|
||||
normalized["owner_sub_commander"] = legacy_sub_flow
|
||||
legacy_status = normalized.get("status")
|
||||
if legacy_status == "awaiting_confirmation":
|
||||
normalized["status"] = "pending"
|
||||
elif legacy_status == "awaiting_clarification":
|
||||
normalized["status"] = "blocked_on_clarification"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_clarification_context(
|
||||
clarification_context: Any,
|
||||
pending_action: dict[str, Any] | None,
|
||||
current_agent: Any,
|
||||
) -> dict[str, Any] | None:
|
||||
if not isinstance(clarification_context, dict):
|
||||
return None
|
||||
normalized = dict(clarification_context)
|
||||
active_agent = normalized.pop("active_agent", None)
|
||||
sub_flow = normalized.pop("sub_flow", None)
|
||||
awaiting_user_input = normalized.pop("awaiting_user_input", None)
|
||||
if isinstance(active_agent, str) and active_agent and "owning_agent" not in normalized:
|
||||
normalized["owning_agent"] = active_agent
|
||||
if isinstance(sub_flow, str) and sub_flow and "owning_sub_commander" not in normalized:
|
||||
normalized["owning_sub_commander"] = sub_flow
|
||||
if "target_action" not in normalized:
|
||||
target_action = None
|
||||
if pending_action:
|
||||
pending_type = pending_action.get("type")
|
||||
if isinstance(pending_type, str) and pending_type and pending_type != "clarification":
|
||||
target_action = pending_type
|
||||
if target_action is None and isinstance(sub_flow, str) and sub_flow.startswith("create_"):
|
||||
target_action = sub_flow
|
||||
if target_action:
|
||||
normalized["target_action"] = target_action
|
||||
if not normalized.get("owning_agent") and isinstance(current_agent, str) and current_agent:
|
||||
normalized["owning_agent"] = current_agent
|
||||
if awaiting_user_input is True and "status" not in normalized:
|
||||
normalized["status"] = "pending"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_continuity_state(
|
||||
continuity_state: Any,
|
||||
clarification_context: dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
if not isinstance(continuity_state, dict):
|
||||
return None
|
||||
normalized = dict(continuity_state)
|
||||
normalized.pop("active_agent", None)
|
||||
normalized.pop("active_sub_flow", None)
|
||||
legacy_status = normalized.get("status")
|
||||
if legacy_status == "awaiting_clarification":
|
||||
normalized["status"] = "fresh"
|
||||
if clarification_context and "mode" not in normalized:
|
||||
normalized["mode"] = "resume_after_clarification"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(state)
|
||||
current_agent = normalized.get("current_agent")
|
||||
pending_action = _normalize_legacy_pending_action(normalized.get("pending_action"))
|
||||
clarification_context = _normalize_legacy_clarification_context(
|
||||
normalized.get("clarification_context"),
|
||||
pending_action,
|
||||
current_agent,
|
||||
)
|
||||
continuity_state = _normalize_legacy_continuity_state(
|
||||
normalized.get("continuity_state"),
|
||||
clarification_context,
|
||||
)
|
||||
turn_context = _normalize_legacy_turn_context(normalized.get("turn_context"), current_agent)
|
||||
if pending_action is not None:
|
||||
normalized["pending_action"] = pending_action
|
||||
if clarification_context is not None:
|
||||
normalized["clarification_context"] = clarification_context
|
||||
if continuity_state is not None:
|
||||
normalized["continuity_state"] = continuity_state
|
||||
if turn_context is not None:
|
||||
normalized["turn_context"] = turn_context
|
||||
return normalized
|
||||
|
||||
|
||||
def _build_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
normalized_state = _normalize_continuity_snapshot(state)
|
||||
snapshot = {
|
||||
field: state.get(field)
|
||||
field: normalized_state.get(field)
|
||||
for field in _CONTINUITY_SNAPSHOT_FIELDS
|
||||
if state.get(field) is not None
|
||||
if normalized_state.get(field) is not None
|
||||
}
|
||||
if not snapshot:
|
||||
return None
|
||||
@@ -116,7 +317,7 @@ def _extract_continuity_snapshot(payload: Any) -> dict[str, Any] | None:
|
||||
return None
|
||||
state = payload.get("state")
|
||||
if isinstance(state, dict):
|
||||
return state
|
||||
return _normalize_continuity_snapshot(state)
|
||||
return None
|
||||
|
||||
|
||||
@@ -160,11 +361,32 @@ class AgentService:
|
||||
"【当前时间】\n"
|
||||
f"- current_time_utc: {reference['current_time_iso']}\n"
|
||||
f"- current_date_utc: {reference['current_date_iso']}\n"
|
||||
"说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。"
|
||||
"说明:解析'今天/明天/后天/本周/下周'等相对时间时,请以 current_time_utc 为准。"
|
||||
)
|
||||
return context, reference
|
||||
|
||||
async def _get_user_llm_config(self, user_id: str, model_name: str | None = None) -> dict | None:
|
||||
def build_skill_context(self, skill_names: list[str]) -> dict:
|
||||
"""构建 Skills 上下文
|
||||
|
||||
Args:
|
||||
skill_names: Skill 名称列表
|
||||
|
||||
Returns:
|
||||
包含 skills 上下文的字典
|
||||
"""
|
||||
registry = get_skill_registry()
|
||||
merged_context = registry.get_skill_context(skill_names)
|
||||
return {
|
||||
"skills_context": merged_context,
|
||||
"skills_metadata": {
|
||||
"skills": skill_names,
|
||||
"count": len(skill_names),
|
||||
},
|
||||
}
|
||||
|
||||
async def _get_user_llm_config(
|
||||
self, user_id: str, model_name: str | None = None
|
||||
) -> dict | None:
|
||||
"""获取用户的 LLM 模型配置"""
|
||||
user = await self.db.get(User, user_id)
|
||||
if not user or not user.llm_config:
|
||||
@@ -187,7 +409,7 @@ class AgentService:
|
||||
return None
|
||||
|
||||
async def _load_continuity_snapshot(self, conversation: Conversation) -> dict[str, Any] | None:
|
||||
snapshot = _extract_continuity_snapshot(conversation.agent_state)
|
||||
snapshot = _extract_continuity_snapshot(getattr(conversation, "agent_state", None))
|
||||
if snapshot:
|
||||
return snapshot
|
||||
|
||||
@@ -214,13 +436,15 @@ class AgentService:
|
||||
user_llm_config: dict | None,
|
||||
) -> dict[str, Any]:
|
||||
state = initial_state(user_id, conversation.id)
|
||||
state.update({
|
||||
"messages": [HumanMessage(content=full_message)],
|
||||
"memory_context": memory_context,
|
||||
"current_datetime_context": current_datetime_context,
|
||||
"current_datetime_reference": current_datetime_reference,
|
||||
"user_llm_config": user_llm_config,
|
||||
})
|
||||
state.update(
|
||||
{
|
||||
"messages": [HumanMessage(content=full_message)],
|
||||
"memory_context": memory_context,
|
||||
"current_datetime_context": current_datetime_context,
|
||||
"current_datetime_reference": current_datetime_reference,
|
||||
"user_llm_config": user_llm_config,
|
||||
}
|
||||
)
|
||||
previous_snapshot = await self._load_continuity_snapshot(conversation)
|
||||
if previous_snapshot:
|
||||
state.update(previous_snapshot)
|
||||
@@ -282,6 +506,7 @@ class AgentService:
|
||||
file_context = ""
|
||||
if file_ids:
|
||||
from app.services.document_service import DocumentService
|
||||
|
||||
doc_svc = DocumentService(self.db)
|
||||
for file_id in file_ids:
|
||||
content = await doc_svc.get_document_content(user_id, file_id)
|
||||
@@ -347,7 +572,9 @@ class AgentService:
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
graph = get_agent_graph()
|
||||
current_datetime_context, current_datetime_reference = self._build_current_datetime_context()
|
||||
current_datetime_context, current_datetime_reference = (
|
||||
self._build_current_datetime_context()
|
||||
)
|
||||
|
||||
state = await self._build_agent_state(
|
||||
user_id=user_id,
|
||||
@@ -358,8 +585,11 @@ class AgentService:
|
||||
current_datetime_reference=current_datetime_reference,
|
||||
user_llm_config=user_llm_config,
|
||||
)
|
||||
state.update(_derive_role_memory_contexts(memory_ctx))
|
||||
|
||||
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
|
||||
yield self._build_progress_event(
|
||||
"thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题"
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in graph.astream_events(state, version="v2"):
|
||||
@@ -368,7 +598,13 @@ class AgentService:
|
||||
metadata = event.get("metadata", {})
|
||||
data = event.get("data", {})
|
||||
|
||||
if kind == "on_chain_start" and event_name in {"master", "schedule_planner", "executor", "librarian", "analyst"}:
|
||||
if kind == "on_chain_start" and event_name in {
|
||||
"master",
|
||||
"schedule_planner",
|
||||
"executor",
|
||||
"librarian",
|
||||
"analyst",
|
||||
}:
|
||||
stage_map = {
|
||||
"master": ("thinking", "Jarvis 正在理解请求"),
|
||||
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
|
||||
@@ -376,9 +612,13 @@ class AgentService:
|
||||
"librarian": ("tool", "Jarvis 正在检索知识"),
|
||||
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
||||
}
|
||||
stage, label = stage_map.get(event_name, ("thinking", "Jarvis 正在思考"))
|
||||
yield self._build_progress_event(stage, label, agent=event_name, step=label)
|
||||
|
||||
stage, label = stage_map.get(
|
||||
event_name, ("thinking", "Jarvis 正在思考")
|
||||
)
|
||||
yield self._build_progress_event(
|
||||
stage, label, agent=event_name, step=label
|
||||
)
|
||||
|
||||
elif kind == "on_tool_start":
|
||||
yield self._build_progress_event(
|
||||
"tool",
|
||||
@@ -387,7 +627,7 @@ class AgentService:
|
||||
tool_name=event_name,
|
||||
step=f"正在执行 {event_name}",
|
||||
)
|
||||
|
||||
|
||||
elif kind == "on_tool_end":
|
||||
tool_result = data.get("output")
|
||||
step = f"已完成 {event_name}"
|
||||
@@ -400,14 +640,16 @@ class AgentService:
|
||||
tool_name=event_name,
|
||||
step=step,
|
||||
)
|
||||
|
||||
|
||||
elif kind == "on_chat_model_stream":
|
||||
chunk = data.get("chunk")
|
||||
content = _coerce_event_text(getattr(chunk, "content", "") if chunk else "")
|
||||
content = _coerce_event_text(
|
||||
getattr(chunk, "content", "") if chunk else ""
|
||||
)
|
||||
if content:
|
||||
collected += content
|
||||
yield {"type": "chunk", "content": content}
|
||||
|
||||
|
||||
elif kind == "on_chain_end":
|
||||
output = data.get("output")
|
||||
final_resp = None
|
||||
@@ -422,7 +664,9 @@ class AgentService:
|
||||
|
||||
elif kind == "on_chat_model_end":
|
||||
output = data.get("output")
|
||||
final_content = _coerce_event_text(getattr(output, "content", "") if output else "")
|
||||
final_content = _coerce_event_text(
|
||||
getattr(output, "content", "") if output else ""
|
||||
)
|
||||
if final_content:
|
||||
final_text = final_content
|
||||
if final_text != collected:
|
||||
@@ -431,12 +675,16 @@ class AgentService:
|
||||
|
||||
except Exception as e:
|
||||
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
|
||||
yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
|
||||
yield self._build_progress_event(
|
||||
"responding", "Jarvis 正在生成回复", agent="master", step="fallback"
|
||||
)
|
||||
try:
|
||||
result_state = await graph.ainvoke(state)
|
||||
if isinstance(result_state, dict):
|
||||
state.update(result_state)
|
||||
fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
fallback_content = result_state.get("final_response") or str(
|
||||
result_state.get("messages", [AIMessage(content="")])[-1].content
|
||||
)
|
||||
collected = str(fallback_content)
|
||||
yield {"type": "chunk", "content": collected}
|
||||
except Exception:
|
||||
@@ -460,11 +708,24 @@ class AgentService:
|
||||
if collected:
|
||||
assistant_msg.content = collected
|
||||
continuity_snapshot = _build_continuity_snapshot(state or {})
|
||||
assistant_msg.attachments = ([{
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}] if continuity_snapshot else None)
|
||||
conv.agent_state = continuity_snapshot
|
||||
assistant_msg.attachments = (
|
||||
[
|
||||
{
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}
|
||||
]
|
||||
if continuity_snapshot
|
||||
else None
|
||||
)
|
||||
conv.agent_state = (
|
||||
{
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}
|
||||
if continuity_snapshot
|
||||
else None
|
||||
)
|
||||
await BrainService(self.db).create_event(
|
||||
user_id,
|
||||
**_build_assistant_event_payload(collected),
|
||||
@@ -542,12 +803,16 @@ class AgentService:
|
||||
importance_signal=1.0,
|
||||
)
|
||||
|
||||
memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
graph = get_agent_graph()
|
||||
current_datetime_context, current_datetime_reference = self._build_current_datetime_context()
|
||||
current_datetime_context, current_datetime_reference = (
|
||||
self._build_current_datetime_context()
|
||||
)
|
||||
state = await self._build_agent_state(
|
||||
user_id=user_id,
|
||||
conversation=conv,
|
||||
@@ -557,9 +822,11 @@ class AgentService:
|
||||
current_datetime_reference=current_datetime_reference,
|
||||
user_llm_config=user_llm_config,
|
||||
)
|
||||
|
||||
state.update(_derive_role_memory_contexts(memory_ctx))
|
||||
result_state = await graph.ainvoke(state)
|
||||
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
response_content = result_state.get("final_response") or str(
|
||||
result_state.get("messages", [AIMessage(content="")])[-1].content
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("agent_chat_simple_failed")
|
||||
response_content = "抱歉,发生错误。"
|
||||
@@ -580,12 +847,27 @@ class AgentService:
|
||||
)
|
||||
|
||||
assistant_msg.content = response_content
|
||||
continuity_snapshot = _build_continuity_snapshot(result_state) if 'result_state' in locals() else None
|
||||
assistant_msg.attachments = ([{
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}] if continuity_snapshot else None)
|
||||
conv.agent_state = continuity_snapshot
|
||||
continuity_snapshot = (
|
||||
_build_continuity_snapshot(result_state) if "result_state" in locals() else None
|
||||
)
|
||||
assistant_msg.attachments = (
|
||||
[
|
||||
{
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}
|
||||
]
|
||||
if continuity_snapshot
|
||||
else None
|
||||
)
|
||||
conv.agent_state = (
|
||||
{
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}
|
||||
if continuity_snapshot
|
||||
else None
|
||||
)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
|
||||
@@ -4,12 +4,15 @@ Jarvis 记忆系统 (基于 Mem0)
|
||||
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Any
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.memory import UserMemory
|
||||
from app.models.user import User
|
||||
from app.services.brain_service import BrainService
|
||||
from app.config import settings as _settings
|
||||
@@ -23,6 +26,9 @@ except ImportError:
|
||||
Memory = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 embedding 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
@@ -296,6 +302,23 @@ async def extract_user_memories(
|
||||
return []
|
||||
|
||||
|
||||
def _extract_memory_query_tokens(query: str) -> list[str]:
|
||||
normalized_query = (query or "").lower()
|
||||
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
|
||||
|
||||
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
|
||||
stripped_chunk = chunk.strip()
|
||||
if len(stripped_chunk) >= 4:
|
||||
tokens.append(stripped_chunk)
|
||||
if len(stripped_chunk) > 6:
|
||||
tokens.extend(
|
||||
stripped_chunk[index:index + 4]
|
||||
for index in range(len(stripped_chunk) - 3)
|
||||
)
|
||||
|
||||
return list(dict.fromkeys(tokens))
|
||||
|
||||
|
||||
async def recall_user_memories(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
@@ -304,7 +327,7 @@ async def recall_user_memories(
|
||||
) -> list[dict]:
|
||||
"""
|
||||
根据当前输入召回相关的用户记忆。
|
||||
使用 Mem0 的语义搜索。
|
||||
使用 Mem0 的语义搜索;如果 Mem0 不可用或失败,则回退到本地 UserMemory。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
@@ -313,10 +336,56 @@ async def recall_user_memories(
|
||||
filters={"user_id": user_id},
|
||||
limit=top_k,
|
||||
)
|
||||
return results.get("results", [])
|
||||
mem0_results = results.get("results", [])
|
||||
if mem0_results:
|
||||
return mem0_results
|
||||
except Exception as e:
|
||||
print(f"Mem0 search error: {e}")
|
||||
return []
|
||||
|
||||
query_tokens = _extract_memory_query_tokens(query)
|
||||
statement = select(UserMemory).where(UserMemory.user_id == user_id)
|
||||
result = await db.execute(statement.order_by(UserMemory.importance.desc(), UserMemory.created_at.desc()))
|
||||
fallback_memories = list(result.scalars().all())
|
||||
|
||||
if _contains_hint(_normalize_query(query), MEMORY_QUERY_HINTS) or _matches_memory_query_pattern(_normalize_query(query)):
|
||||
return fallback_memories[:top_k]
|
||||
|
||||
if query_tokens:
|
||||
matched_memories = [
|
||||
memory for memory in fallback_memories
|
||||
if any(token in (memory.content or '').lower() for token in query_tokens)
|
||||
]
|
||||
return matched_memories[:top_k]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
|
||||
recalled_at = datetime.now(UTC)
|
||||
updated = False
|
||||
for memory in memories:
|
||||
memory.is_recalled = True
|
||||
memory.recall_count = (memory.recall_count or 0) + 1
|
||||
memory.last_recalled_at = recalled_at
|
||||
updated = True
|
||||
if updated:
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _run_tolerated_section(
|
||||
db: AsyncSession,
|
||||
section_name: str,
|
||||
builder,
|
||||
) -> str:
|
||||
try:
|
||||
return await builder()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[MemoryService] %s失败,继续构建剩余上下文",
|
||||
section_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||
@@ -339,6 +408,131 @@ async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
MEMORY_QUERY_HINTS = (
|
||||
"记住",
|
||||
"记下",
|
||||
"记一下",
|
||||
"记着",
|
||||
"提醒",
|
||||
"偏好",
|
||||
"习惯",
|
||||
)
|
||||
MEMORY_QUERY_PATTERNS = (
|
||||
re.compile(r"\bremember\s+(?:that\s+)?i\b"),
|
||||
)
|
||||
GROUNDING_QUERY_HINTS = (
|
||||
"根据文档",
|
||||
"严格根据",
|
||||
"只根据",
|
||||
"文档内容",
|
||||
"grounded",
|
||||
"strictly based on",
|
||||
"based on the document",
|
||||
"based on the docs",
|
||||
"document only",
|
||||
"docs only",
|
||||
"only use the document",
|
||||
"only use the docs",
|
||||
)
|
||||
AVOID_USER_MEMORY_HINTS = (
|
||||
"不要结合我的个人偏好",
|
||||
"不要结合个人偏好",
|
||||
"不要结合偏好",
|
||||
"不要结合我的记忆",
|
||||
"不要结合记忆",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_query(text: str) -> str:
|
||||
return text.strip().lower()
|
||||
|
||||
|
||||
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
|
||||
return any(hint in text for hint in hints)
|
||||
|
||||
|
||||
def _matches_memory_query_pattern(text: str) -> bool:
|
||||
return any(pattern.search(text) for pattern in MEMORY_QUERY_PATTERNS)
|
||||
|
||||
|
||||
def _should_include_user_memories(query: str) -> bool:
|
||||
normalized_query = _normalize_query(query)
|
||||
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
|
||||
return False
|
||||
if _contains_hint(normalized_query, AVOID_USER_MEMORY_HINTS):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _should_include_summaries(query: str) -> bool:
|
||||
normalized_query = _normalize_query(query)
|
||||
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
|
||||
return False
|
||||
if _contains_hint(normalized_query, MEMORY_QUERY_HINTS):
|
||||
return False
|
||||
if _matches_memory_query_pattern(normalized_query):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _build_user_memory_section(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
current_query: str,
|
||||
) -> str:
|
||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
recalled_user_memories: list[UserMemory] = []
|
||||
for memory in memories:
|
||||
if isinstance(memory, UserMemory):
|
||||
memory_text = memory.content
|
||||
memory_type = memory.memory_type
|
||||
recalled_user_memories.append(memory)
|
||||
else:
|
||||
memory_text = memory.get("memory", memory.get("text", ""))
|
||||
memory_type = memory.get("memory_type")
|
||||
|
||||
if not memory_text:
|
||||
continue
|
||||
|
||||
if memory_type:
|
||||
lines.append(f" [{memory_type}] {memory_text}")
|
||||
else:
|
||||
lines.append(f" - {memory_text}")
|
||||
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
if recalled_user_memories:
|
||||
await _mark_memories_recalled(db, recalled_user_memories)
|
||||
return "【用户记忆】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if not summaries:
|
||||
return ""
|
||||
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i + 1}] {summary.summary_text}" for i, summary in enumerate(recent)]
|
||||
return "【之前对话摘要】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def _build_brain_section(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
current_query: str,
|
||||
) -> str:
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if not brain_memories:
|
||||
return ""
|
||||
|
||||
lines = [f"- {memory.title}: {memory.content}" for memory in brain_memories]
|
||||
return "【知识大脑】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def build_memory_context(
|
||||
db: AsyncSession,
|
||||
@@ -350,30 +544,33 @@ async def build_memory_context(
|
||||
构建完整的记忆上下文字符串,
|
||||
供注入到 Agent system prompt 中使用。
|
||||
"""
|
||||
parts = []
|
||||
parts: list[str] = []
|
||||
|
||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if memories:
|
||||
lines = []
|
||||
for m in memories:
|
||||
memory_text = m.get("memory", m.get("text", ""))
|
||||
if memory_text:
|
||||
lines.append(f" - {memory_text}")
|
||||
if lines:
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
if _should_include_user_memories(current_query):
|
||||
user_memory_section = await _run_tolerated_section(
|
||||
db,
|
||||
"用户记忆召回",
|
||||
lambda: _build_user_memory_section(db, user_id, current_query),
|
||||
)
|
||||
if user_memory_section:
|
||||
parts.append(user_memory_section)
|
||||
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if summaries:
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
||||
if _should_include_summaries(current_query):
|
||||
summary_section = await _run_tolerated_section(
|
||||
db,
|
||||
"对话摘要加载",
|
||||
lambda: _build_summary_section(db, conversation_id),
|
||||
)
|
||||
if summary_section:
|
||||
parts.append(summary_section)
|
||||
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if brain_memories:
|
||||
lines = []
|
||||
for memory in brain_memories:
|
||||
lines.append(f"- {memory.title}: {memory.content}")
|
||||
parts.append("【知识大脑】\n" + "\n".join(lines))
|
||||
brain_section = await _run_tolerated_section(
|
||||
db,
|
||||
"知识大脑召回",
|
||||
lambda: _build_brain_section(db, user_id, current_query),
|
||||
)
|
||||
if brain_section:
|
||||
parts.append(brain_section)
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
|
||||
167
backend/tests/backend/app/agents/test_agent_schemas.py
Normal file
167
backend/tests/backend/app/agents/test_agent_schemas.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from app.agents.schemas.event import AgentEvent
|
||||
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult
|
||||
|
||||
|
||||
def test_agent_task_accepts_day1_fields():
|
||||
task = AgentTask(
|
||||
task_id="task-1",
|
||||
title="Verify foundation",
|
||||
status="in_progress",
|
||||
owner_agent_id="executor",
|
||||
role="verifier",
|
||||
goal="check output",
|
||||
expected_evidence=[{"type": "assertion"}],
|
||||
evidence=[{"type": "log"}],
|
||||
result_summary="running",
|
||||
)
|
||||
|
||||
assert task.task_id == "task-1"
|
||||
assert task.owner_agent_id == "executor"
|
||||
assert task.status == "in_progress"
|
||||
assert task.expected_evidence == [{"type": "assertion"}]
|
||||
assert task.evidence == [{"type": "log"}]
|
||||
assert task.result_summary == "running"
|
||||
|
||||
|
||||
def test_agent_task_accepts_day3_runtime_fields():
|
||||
task = AgentTask(
|
||||
task_id="task-2",
|
||||
title="Recover interrupted collaboration",
|
||||
owner_agent_id="executor",
|
||||
parent_task_id="task-1",
|
||||
child_task_ids=["task-2a"],
|
||||
thread_id="thread-1",
|
||||
message_id="msg-1",
|
||||
message_index=2,
|
||||
interrupt_records=[
|
||||
InterruptRecord(
|
||||
interrupt_id="interrupt-1",
|
||||
reason="manual stop",
|
||||
requested_by="coordinator",
|
||||
)
|
||||
],
|
||||
recovery_records=[
|
||||
RecoveryRecord(
|
||||
recovery_id="recovery-1",
|
||||
source_interrupt_id="interrupt-1",
|
||||
resumed_from_task_id="task-2",
|
||||
resumed_from_thread_id="thread-1",
|
||||
strategy="resume_from_checkpoint",
|
||||
)
|
||||
],
|
||||
collaboration_budget=CollaborationBudget(
|
||||
mode="collaboration",
|
||||
max_parallel_tasks=2,
|
||||
remaining_parallel_tasks=1,
|
||||
max_tool_calls=4,
|
||||
remaining_tool_calls=3,
|
||||
max_iterations=5,
|
||||
remaining_iterations=4,
|
||||
escalation_threshold=1,
|
||||
metadata={"max_spawn_depth": 2},
|
||||
),
|
||||
)
|
||||
|
||||
assert task.parent_task_id == "task-1"
|
||||
assert task.child_task_ids == ["task-2a"]
|
||||
assert task.thread_id == "thread-1"
|
||||
assert task.message_id == "msg-1"
|
||||
assert task.message_index == 2
|
||||
assert task.interrupt_records[0].interrupt_id == "interrupt-1"
|
||||
assert task.recovery_records[0].recovery_id == "recovery-1"
|
||||
assert task.collaboration_budget.mode == "collaboration"
|
||||
assert task.collaboration_budget.metadata == {"max_spawn_depth": 2}
|
||||
|
||||
|
||||
def test_agent_event_accepts_day1_fields():
|
||||
event = AgentEvent(
|
||||
event_id="evt-1",
|
||||
event_type="agent.verify.completed",
|
||||
conversation_id="conv-1",
|
||||
agent_id="executor",
|
||||
sub_commander_id="executor_tasks",
|
||||
task_id="task-1",
|
||||
payload={"status": "passed"},
|
||||
severity="info",
|
||||
)
|
||||
|
||||
assert event.event_id == "evt-1"
|
||||
assert event.event_type == "agent.verify.completed"
|
||||
assert event.conversation_id == "conv-1"
|
||||
assert event.payload == {"status": "passed"}
|
||||
assert event.severity == "info"
|
||||
|
||||
|
||||
def test_agent_event_accepts_day3_trace_fields():
|
||||
event = AgentEvent(
|
||||
event_id="evt-2",
|
||||
event_type="agent.collaboration.budget.updated",
|
||||
conversation_id="conv-1",
|
||||
agent_id="coordinator",
|
||||
task_id="task-2",
|
||||
parent_task_id="task-1",
|
||||
child_task_id="task-2a",
|
||||
thread_id="thread-1",
|
||||
message_id="msg-3",
|
||||
interrupt_id="interrupt-1",
|
||||
recovery_id="recovery-1",
|
||||
payload={"remaining_parallel_tasks": 1},
|
||||
severity="warning",
|
||||
)
|
||||
|
||||
assert event.parent_task_id == "task-1"
|
||||
assert event.child_task_id == "task-2a"
|
||||
assert event.thread_id == "thread-1"
|
||||
assert event.message_id == "msg-3"
|
||||
assert event.interrupt_id == "interrupt-1"
|
||||
assert event.recovery_id == "recovery-1"
|
||||
assert event.severity == "warning"
|
||||
|
||||
|
||||
def test_task_result_supports_collaboration_result_fields():
|
||||
result = TaskResult(
|
||||
task_id="task-1",
|
||||
status="completed",
|
||||
summary="retrieval finished",
|
||||
evidence=[{"type": "source"}],
|
||||
owner_agent_id="librarian",
|
||||
next_action="handoff_to_analyst",
|
||||
)
|
||||
|
||||
assert result.status == "completed"
|
||||
assert result.owner_agent_id == "librarian"
|
||||
assert result.next_action == "handoff_to_analyst"
|
||||
|
||||
|
||||
def test_task_result_supports_day3_thread_budget_and_recovery_fields():
|
||||
result = TaskResult(
|
||||
task_id="task-2",
|
||||
status="blocked",
|
||||
owner_agent_id="executor",
|
||||
parent_task_id="task-1",
|
||||
child_task_ids=["task-2a"],
|
||||
thread_id="thread-1",
|
||||
message_id="msg-4",
|
||||
message_index=4,
|
||||
interrupt_records=[{"interrupt_id": "interrupt-1", "reason": "budget exceeded"}],
|
||||
recovery_records=[{"recovery_id": "recovery-1", "strategy": "resume_after_budget_reset"}],
|
||||
budget_snapshot=CollaborationBudget(
|
||||
mode="collaboration",
|
||||
max_parallel_tasks=2,
|
||||
remaining_parallel_tasks=0,
|
||||
max_tool_calls=4,
|
||||
remaining_tool_calls=0,
|
||||
),
|
||||
next_action="resume_after_budget_reset",
|
||||
)
|
||||
|
||||
assert result.parent_task_id == "task-1"
|
||||
assert result.child_task_ids == ["task-2a"]
|
||||
assert result.thread_id == "thread-1"
|
||||
assert result.message_id == "msg-4"
|
||||
assert result.message_index == 4
|
||||
assert result.interrupt_records[0].interrupt_id == "interrupt-1"
|
||||
assert result.recovery_records[0].recovery_id == "recovery-1"
|
||||
assert result.budget_snapshot.mode == "collaboration"
|
||||
assert result.budget_snapshot.remaining_parallel_tasks == 0
|
||||
assert result.next_action == "resume_after_budget_reset"
|
||||
File diff suppressed because it is too large
Load Diff
317
backend/tests/backend/app/agents/test_graph_system_messages.py
Normal file
317
backend/tests/backend/app/agents/test_graph_system_messages.py
Normal file
@@ -0,0 +1,317 @@
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
sys.modules.setdefault("trafilatura", Mock())
|
||||
|
||||
from app.agents.graph import _build_system_messages, _run_sub_commander
|
||||
from app.agents.state import AgentRole
|
||||
|
||||
|
||||
def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
|
||||
return {
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"user_id": "u1",
|
||||
"conversation_id": "c1",
|
||||
"current_agent": AgentRole.MASTER,
|
||||
"active_agents": [AgentRole.MASTER],
|
||||
"current_sub_commander": None,
|
||||
"active_sub_commanders": [],
|
||||
"sub_commander_trace": [],
|
||||
"pending_tasks": [],
|
||||
"completed_tasks": [],
|
||||
"tool_calls": [],
|
||||
"last_tool_result": None,
|
||||
"action_results": [],
|
||||
"created_entities": [],
|
||||
"tool_strategy_used": None,
|
||||
"provider_capabilities": None,
|
||||
"fallback_parse_error": None,
|
||||
"knowledge_context": None,
|
||||
"graph_context": None,
|
||||
"schedule_context_summary": None,
|
||||
"plan": None,
|
||||
"plan_steps": [],
|
||||
"analysis_report": None,
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
"memory_context": "memory context",
|
||||
"current_datetime_context": "CURRENT_TIME: 2026-03-28T12:00:00+08:00",
|
||||
"current_datetime_reference": {
|
||||
"current_time_iso": "2026-03-28T12:00:00+08:00",
|
||||
"current_date_iso": "2026-03-28",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
"user_llm_config": user_llm_config,
|
||||
}
|
||||
|
||||
|
||||
class FakeTool:
|
||||
def __init__(self, name: str, result: str):
|
||||
self.name = name
|
||||
self.result = result
|
||||
self.invocations: list[dict] = []
|
||||
|
||||
def invoke(self, args: dict):
|
||||
self.invocations.append(args)
|
||||
return self.result
|
||||
|
||||
|
||||
class SingleSystemMessageLLM:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
self.system_message_counts: list[int] = []
|
||||
self._jarvis_provider_capabilities = SimpleNamespace(
|
||||
provider="minimax",
|
||||
supports_native_tools=False,
|
||||
preferred_tool_strategy="json_fallback",
|
||||
)
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.calls += 1
|
||||
self.system_message_counts.append(
|
||||
sum(1 for message in messages if getattr(message, "type", None) == "system")
|
||||
)
|
||||
if self.system_message_counts[-1] != 1:
|
||||
raise AssertionError(
|
||||
f"expected exactly one system message, got {self.system_message_counts[-1]}"
|
||||
)
|
||||
if self.calls == 1:
|
||||
return AIMessage(
|
||||
content=(
|
||||
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder",'
|
||||
'"arguments":{"title":"blanket","reminder_at":"\\u660e\\u5929 09:00"}}]}'
|
||||
)
|
||||
)
|
||||
return AIMessage(content="created reminder for blanket")
|
||||
|
||||
|
||||
def test_build_system_messages_includes_structured_continuity_summary():
|
||||
state = _base_state("创建")
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "pending",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "continue_pending_action",
|
||||
}
|
||||
state["continuity_state"] = {"status": "fresh"}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "pending_action" in system_text
|
||||
assert "schedule_creation" in system_text
|
||||
assert "continue_pending_action" in system_text
|
||||
assert "为周报安排明天下午提醒" in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_skips_structured_continuity_when_pending_action_is_not_pending():
|
||||
state = _base_state("创建")
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "completed",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "continue_pending_action",
|
||||
}
|
||||
state["continuity_state"] = {"status": "fresh"}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "structured_continuity" not in system_text
|
||||
assert "continue_pending_action" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_skips_structured_continuity_when_routing_reason_is_not_continuation():
|
||||
state = _base_state("创建")
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "pending",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "initial_schedule_detection",
|
||||
}
|
||||
state["continuity_state"] = {"status": "fresh"}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "structured_continuity" not in system_text
|
||||
assert "continue_pending_action" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_skips_structured_continuity_when_routing_decision_missing():
|
||||
state = _base_state("创建")
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
}
|
||||
state["routing_decision"] = None
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "pending_action" not in system_text
|
||||
assert "schedule_creation" not in system_text
|
||||
assert "为周报安排明天下午提醒" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_skips_stale_structured_continuity_for_unrelated_new_request():
|
||||
state = _base_state(
|
||||
"帮我搜索 Rust 异步 trait 最佳实践",
|
||||
{
|
||||
"provider": "openai",
|
||||
"model": "MiniMax-M2.7-highspeed",
|
||||
"base_url": "https://api.minimaxi.com/v1",
|
||||
},
|
||||
)
|
||||
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "pending",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "continue_pending_action",
|
||||
}
|
||||
state["continuity_state"] = {
|
||||
"status": "stale",
|
||||
"override_reason": "new_explicit_request",
|
||||
}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "structured_continuity" not in system_text
|
||||
assert "pending_action" not in system_text
|
||||
assert "continue_pending_action" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_uses_role_scoped_context_instead_of_raw_memory_blob():
|
||||
state = _base_state("帮我搜索 Rust 异步 trait 最佳实践")
|
||||
state["memory_context"] = "【用户记忆】\n- 用户喜欢燕麦拿铁。\n\n【之前对话摘要】\n[对话摘要1] 之前聊过提醒。\n\n【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||
state["schedule_context_summary"] = "【用户记忆】\n- 用户喜欢燕麦拿铁。\n\n【之前对话摘要】\n[对话摘要1] 之前聊过提醒。"
|
||||
state["knowledge_context"] = "【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||
state["analysis_report"] = "【之前对话摘要】\n[对话摘要1] 之前聊过提醒。\n\n【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.LIBRARIAN,
|
||||
"librarian_retrieval",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "角色上下文" in system_text
|
||||
assert "【知识大脑】" in system_text
|
||||
assert "Rust Async" in system_text
|
||||
assert "用户喜欢燕麦拿铁" not in system_text
|
||||
assert "记忆上下文" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_keeps_fresh_structured_continuity_for_matching_followup():
|
||||
state = _base_state(
|
||||
"创建",
|
||||
{
|
||||
"provider": "openai",
|
||||
"model": "MiniMax-M2.7-highspeed",
|
||||
"base_url": "https://api.minimaxi.com/v1",
|
||||
},
|
||||
)
|
||||
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "pending",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "continue_pending_action",
|
||||
}
|
||||
state["continuity_state"] = {
|
||||
"status": "fresh",
|
||||
}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "pending_action" in system_text
|
||||
assert "continue_pending_action" in system_text
|
||||
|
||||
|
||||
async def test_run_sub_commander_coalesces_system_messages_for_openai_compatible_provider(
|
||||
monkeypatch,
|
||||
):
|
||||
fake_llm = SingleSystemMessageLLM()
|
||||
fake_tool = FakeTool("create_reminder", "created reminder: blanket @ tomorrow 09:00")
|
||||
|
||||
monkeypatch.setattr("app.agents.graph._get_llm_for_state", lambda state: fake_llm)
|
||||
monkeypatch.setitem(
|
||||
__import__("app.agents.graph", fromlist=["SUB_COMMANDER_TOOLSETS"]).SUB_COMMANDER_TOOLSETS,
|
||||
"schedule_planning",
|
||||
[fake_tool],
|
||||
)
|
||||
|
||||
state = _base_state(
|
||||
"给我设置明天的提醒,提醒我收被子",
|
||||
{
|
||||
"provider": "openai",
|
||||
"model": "MiniMax-M2.7-highspeed",
|
||||
"base_url": "https://api.minimaxi.com/v1",
|
||||
},
|
||||
)
|
||||
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
result = await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"manager prompt",
|
||||
"给我设置明天的提醒,提醒我收被子",
|
||||
use_tools=True,
|
||||
)
|
||||
|
||||
assert fake_llm.system_message_counts == [1, 1]
|
||||
assert result["tool_strategy_used"] == "json_fallback"
|
||||
assert fake_tool.invocations == [{"title": "blanket", "reminder_at": "2026-03-29T09:00:00"}]
|
||||
assert result["final_response"] == "created reminder for blanket"
|
||||
@@ -1,4 +1,4 @@
|
||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT
|
||||
from app.agents.prompts import COORDINATOR_SYSTEM_PROMPT, MASTER_SYSTEM_PROMPT
|
||||
|
||||
|
||||
def test_master_prompt_forbids_subagent_rollcall_in_simple_greetings():
|
||||
@@ -10,3 +10,10 @@ def test_master_prompt_does_not_include_full_canned_answers_for_greetings_or_ide
|
||||
assert 'Jarvis:您好。我在。' not in MASTER_SYSTEM_PROMPT
|
||||
assert 'Jarvis:我是 Jarvis。' not in MASTER_SYSTEM_PROMPT
|
||||
assert 'Jarvis:主要做三件事。' not in MASTER_SYSTEM_PROMPT
|
||||
|
||||
|
||||
def test_coordinator_prompt_limits_collaboration_scope():
|
||||
assert "2~4 个子任务" in COORDINATOR_SYSTEM_PROMPT
|
||||
assert "禁止无限递归拆分" in COORDINATOR_SYSTEM_PROMPT
|
||||
assert "schedule_planner" in COORDINATOR_SYSTEM_PROMPT
|
||||
assert "librarian" in COORDINATOR_SYSTEM_PROMPT
|
||||
|
||||
@@ -5,11 +5,13 @@ from app.agents.prompts import (
|
||||
SUB_COMMANDER_PROMPTS_BY_KEY,
|
||||
TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY,
|
||||
)
|
||||
from app.agents.registry import build_registry_indexes, load_builtin_registry_bundle
|
||||
from app.agents.registry import build_registry_indexes, load_builtin_registry_bundle, load_builtin_registry_indexes
|
||||
from app.agents.registry.indexes import summarize_registry_indexes
|
||||
from app.agents.registry.models import (
|
||||
AgentManifest,
|
||||
CapabilityManifest,
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
SpecialistTemplateManifest,
|
||||
SubCommanderManifest,
|
||||
)
|
||||
@@ -251,17 +253,34 @@ def test_builtin_capabilities_reference_actual_runtime_tool_names() -> None:
|
||||
assert manifest_tool_names == expected_tool_names
|
||||
|
||||
|
||||
def test_builtin_sub_commander_capabilities_match_runtime_toolsets() -> None:
|
||||
capabilities_by_tool_name = {
|
||||
manifest.tool_name: manifest.capability_id for manifest in BUILTIN_CAPABILITY_MANIFESTS
|
||||
}
|
||||
def test_builtin_capability_metadata_distinguishes_read_and_write_surfaces() -> None:
|
||||
capability_by_id = {manifest.capability_id: manifest for manifest in BUILTIN_CAPABILITY_MANIFESTS}
|
||||
|
||||
for sub_commander in BUILTIN_SUB_COMMANDER_MANIFESTS:
|
||||
expected_capability_ids = {
|
||||
capabilities_by_tool_name[tool.name]
|
||||
for tool in SUB_COMMANDER_TOOLSETS[sub_commander.sub_commander_id]
|
||||
}
|
||||
assert set(sub_commander.capability_ids) == expected_capability_ids
|
||||
assert capability_by_id["get_tasks"].permission_class == PermissionClass.READ
|
||||
assert capability_by_id["get_tasks"].side_effect_scope == SideEffectScope.NONE
|
||||
assert capability_by_id["get_tasks"].supports_retry is True
|
||||
assert capability_by_id["get_tasks"].idempotent is True
|
||||
assert capability_by_id["get_tasks"].safe_for_parallel_use is True
|
||||
assert capability_by_id["get_tasks"].requires_confirmation is False
|
||||
|
||||
assert capability_by_id["create_reminder"].permission_class == PermissionClass.WRITE
|
||||
assert capability_by_id["create_reminder"].side_effect_scope == SideEffectScope.LOCAL_STATE
|
||||
assert capability_by_id["create_reminder"].supports_retry is False
|
||||
assert capability_by_id["create_reminder"].idempotent is False
|
||||
assert capability_by_id["create_reminder"].safe_for_parallel_use is False
|
||||
assert capability_by_id["create_reminder"].requires_confirmation is True
|
||||
|
||||
assert capability_by_id["web_search"].permission_class == PermissionClass.EXTERNAL
|
||||
assert capability_by_id["web_search"].side_effect_scope == SideEffectScope.NETWORK
|
||||
|
||||
|
||||
def test_load_builtin_registry_indexes_is_cached_and_matches_bundle_indexes() -> None:
|
||||
cached = load_builtin_registry_indexes()
|
||||
rebuilt = build_registry_indexes(load_builtin_registry_bundle())
|
||||
|
||||
assert cached is load_builtin_registry_indexes()
|
||||
assert cached.capability_id_by_tool_name == rebuilt.capability_id_by_tool_name
|
||||
assert cached.capability_by_id["create_reminder"].requires_confirmation is True
|
||||
|
||||
|
||||
def test_builtin_manifests_form_a_valid_registry_bundle() -> None:
|
||||
@@ -288,6 +307,7 @@ def test_build_registry_indexes_exposes_manifest_lookups_by_id() -> None:
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
assert indexes.agent_by_id
|
||||
assert indexes.agent_by_role_value
|
||||
assert indexes.sub_commander_by_id
|
||||
assert indexes.capability_by_id
|
||||
assert isinstance(indexes.specialist_template_by_id, Mapping)
|
||||
@@ -343,6 +363,14 @@ def test_build_registry_indexes_exposes_prompt_keys_skill_context_keys_and_capab
|
||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||
for sub_commander in bundle.sub_commanders
|
||||
}
|
||||
assert indexes.agent_by_role_value == {
|
||||
agent.role_value: agent for agent in bundle.agents
|
||||
}
|
||||
assert indexes.spawnable_role_values_by_agent_id == {
|
||||
agent.agent_id: tuple(agent.allowed_spawn_role_values)
|
||||
for agent in bundle.agents
|
||||
if agent.can_spawn_children and agent.allowed_spawn_role_values
|
||||
}
|
||||
|
||||
|
||||
def test_validate_registry_bundle_accepts_loaded_builtin_registry_bundle() -> None:
|
||||
|
||||
135
backend/tests/backend/app/agents/test_schema_verifier.py
Normal file
135
backend/tests/backend/app/agents/test_schema_verifier.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from app.agents.schemas import AgentEvent, AgentTask, TaskResult
|
||||
from app.agents.schemas.task import CollaborationBudget, InterruptRecord, RecoveryRecord
|
||||
from app.agents.state import initial_state
|
||||
from app.agents.verifier import apply_verification_verdict, normalize_task_result, verify_task_result
|
||||
|
||||
|
||||
def test_agent_task_supports_day3_interrupt_recovery_and_budget_fields():
|
||||
interrupt = InterruptRecord(interrupt_id="interrupt-1", reason="user_cancel")
|
||||
recovery = RecoveryRecord(recovery_id="recovery-1", source_interrupt_id="interrupt-1", resumed_from_task_id="task-1")
|
||||
budget = CollaborationBudget(
|
||||
mode="collaboration",
|
||||
max_parallel_tasks=3,
|
||||
remaining_parallel_tasks=2,
|
||||
max_tool_calls=6,
|
||||
remaining_tool_calls=4,
|
||||
metadata={"phase": "day3"},
|
||||
)
|
||||
|
||||
task = AgentTask(
|
||||
task_id="task-1",
|
||||
title="Recover interrupted collaboration task",
|
||||
owner_agent_id="analyst",
|
||||
role="analyst",
|
||||
parent_task_id="parent-1",
|
||||
child_task_ids=["child-1"],
|
||||
thread_id="thread-1",
|
||||
message_id="message-1",
|
||||
message_index=3,
|
||||
interrupt_records=[interrupt],
|
||||
recovery_records=[recovery],
|
||||
collaboration_budget=budget,
|
||||
)
|
||||
|
||||
payload = task.model_dump(mode="json")
|
||||
|
||||
assert payload["parent_task_id"] == "parent-1"
|
||||
assert payload["child_task_ids"] == ["child-1"]
|
||||
assert payload["thread_id"] == "thread-1"
|
||||
assert payload["message_id"] == "message-1"
|
||||
assert payload["message_index"] == 3
|
||||
assert payload["interrupt_records"][0]["interrupt_id"] == "interrupt-1"
|
||||
assert payload["recovery_records"][0]["recovery_id"] == "recovery-1"
|
||||
assert payload["collaboration_budget"]["mode"] == "collaboration"
|
||||
assert payload["collaboration_budget"]["remaining_tool_calls"] == 4
|
||||
|
||||
|
||||
def test_agent_event_supports_day3_thread_interrupt_and_recovery_metadata():
|
||||
event = AgentEvent(
|
||||
event_id="evt-1",
|
||||
event_type="agent.task.recovered",
|
||||
conversation_id="conv-1",
|
||||
agent_id="executor",
|
||||
task_id="task-1",
|
||||
parent_task_id="parent-1",
|
||||
child_task_id="child-1",
|
||||
thread_id="thread-1",
|
||||
message_id="message-1",
|
||||
interrupt_id="interrupt-1",
|
||||
recovery_id="recovery-1",
|
||||
severity="warning",
|
||||
payload={"status": "resumed"},
|
||||
)
|
||||
|
||||
payload = event.model_dump(mode="json")
|
||||
|
||||
assert payload["event_type"] == "agent.task.recovered"
|
||||
assert payload["parent_task_id"] == "parent-1"
|
||||
assert payload["child_task_id"] == "child-1"
|
||||
assert payload["thread_id"] == "thread-1"
|
||||
assert payload["message_id"] == "message-1"
|
||||
assert payload["interrupt_id"] == "interrupt-1"
|
||||
assert payload["recovery_id"] == "recovery-1"
|
||||
assert payload["severity"] == "warning"
|
||||
|
||||
|
||||
def test_normalize_task_result_preserves_day3_metadata_fields():
|
||||
result = normalize_task_result(
|
||||
{
|
||||
"task_id": "task-1",
|
||||
"status": "completed",
|
||||
"summary": "Recovered successfully.",
|
||||
"owner_agent_id": "executor",
|
||||
"parent_task_id": "parent-1",
|
||||
"child_task_ids": ["child-1"],
|
||||
"thread_id": "thread-1",
|
||||
"message_id": "message-1",
|
||||
"message_index": 2,
|
||||
"interrupt_records": [{"interrupt_id": "interrupt-1", "reason": "user_pause"}],
|
||||
"recovery_records": [{"recovery_id": "recovery-1", "source_interrupt_id": "interrupt-1"}],
|
||||
"budget_snapshot": {"mode": "collaboration", "max_parallel_tasks": 4},
|
||||
"next_action": "notify_user",
|
||||
"output_data": {"ok": True},
|
||||
}
|
||||
)
|
||||
|
||||
assert result.parent_task_id == "parent-1"
|
||||
assert result.child_task_ids == ["child-1"]
|
||||
assert result.thread_id == "thread-1"
|
||||
assert result.message_id == "message-1"
|
||||
assert result.message_index == 2
|
||||
assert result.interrupt_records[0].interrupt_id == "interrupt-1"
|
||||
assert result.recovery_records[0].recovery_id == "recovery-1"
|
||||
assert result.budget_snapshot.mode == "collaboration"
|
||||
assert result.budget_snapshot.max_parallel_tasks == 4
|
||||
assert result.next_action == "notify_user"
|
||||
assert result.output_data == {"ok": True}
|
||||
|
||||
|
||||
def test_apply_verification_verdict_updates_state_with_recovery_evidence():
|
||||
state = initial_state("u1", "c1")
|
||||
|
||||
verdict = verify_task_result(
|
||||
status="passed",
|
||||
summary="Interrupt and recovery chain verified.",
|
||||
evidence=[
|
||||
{
|
||||
"task_id": "task-1",
|
||||
"thread_id": "thread-1",
|
||||
"interrupt_id": "interrupt-1",
|
||||
"recovery_id": "recovery-1",
|
||||
}
|
||||
],
|
||||
)
|
||||
updated_state = apply_verification_verdict(state, verdict)
|
||||
|
||||
assert updated_state["verification_status"] == "passed"
|
||||
assert updated_state["verification_summary"] == "Interrupt and recovery chain verified."
|
||||
assert updated_state["verification_evidence"] == [
|
||||
{
|
||||
"task_id": "task-1",
|
||||
"thread_id": "thread-1",
|
||||
"interrupt_id": "interrupt-1",
|
||||
"recovery_id": "recovery-1",
|
||||
}
|
||||
]
|
||||
@@ -47,3 +47,27 @@ def test_web_search_tool_returns_stable_message_when_unavailable(monkeypatch):
|
||||
result = web_search.func('Jarvis')
|
||||
|
||||
assert result == '网页搜索不可用: 网页搜索未启用或未配置'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_tool_runs_from_active_event_loop(monkeypatch):
|
||||
class FakeService:
|
||||
async def search(self, query: str, limit: int | None = None):
|
||||
assert query == 'Jarvis 最新更新'
|
||||
assert limit == 1
|
||||
return [
|
||||
FakeResult(
|
||||
title='Jarvis release notes',
|
||||
url='https://example.com/jarvis-release',
|
||||
snippet='Latest Jarvis changes.',
|
||||
source='duckduckgo',
|
||||
published_at='2026-03-29',
|
||||
)
|
||||
]
|
||||
|
||||
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||
|
||||
result = web_search.func('Jarvis 最新更新', top_k=1)
|
||||
|
||||
assert '[1] Jarvis release notes' in result
|
||||
assert '链接: https://example.com/jarvis-release' in result
|
||||
|
||||
@@ -2,6 +2,7 @@ import pytest
|
||||
|
||||
from app.agents.tools import forum as forum_tools
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
from app.agents.tools import search as search_tools
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
|
||||
@@ -12,6 +13,7 @@ from app.agents.tools import task as task_tools
|
||||
(task_tools, "task"),
|
||||
(schedule_tools, "schedule"),
|
||||
(forum_tools, "forum"),
|
||||
(search_tools, "search"),
|
||||
],
|
||||
)
|
||||
async def test_run_async_bridge_works_inside_running_event_loop(module, label):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user