Compare commits
8 Commits
phase1-reg
...
fca7a7cf3d
| Author | SHA1 | Date | |
|---|---|---|---|
| fca7a7cf3d | |||
| d18167826e | |||
| 88955ed550 | |||
| a3fe4d24fc | |||
| e5bd492d74 | |||
| a7b6b5eb90 | |||
| aa0ef0fbea | |||
| 4972b4e6b1 |
220
backend/app/agents/background/executor.py
Normal file
220
backend/app/agents/background/executor.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""Background task executor - Phase 10.4"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .manager import (
|
||||||
|
BackgroundTask,
|
||||||
|
BackgroundTaskManager,
|
||||||
|
BackgroundTaskStatus,
|
||||||
|
get_background_task_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundExecutor:
|
||||||
|
"""Executes background tasks with error handling and result storage.
|
||||||
|
|
||||||
|
Provides methods to execute tasks synchronously or asynchronously,
|
||||||
|
with full integration into BackgroundTaskManager for tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, task_manager: BackgroundTaskManager | None = None):
|
||||||
|
"""Initialize the executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_manager: Optional BackgroundTaskManager instance.
|
||||||
|
If not provided, uses the global singleton.
|
||||||
|
"""
|
||||||
|
self._task_manager = task_manager or get_background_task_manager()
|
||||||
|
self._executors: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
async def execute_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BackgroundTask:
|
||||||
|
"""Execute a specific task by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Unique task identifier
|
||||||
|
func: Async function to execute
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The BackgroundTask with result or error
|
||||||
|
"""
|
||||||
|
# Get or create task record
|
||||||
|
task = self._task_manager.get_task_status(task_id)
|
||||||
|
if task is None:
|
||||||
|
# Create a new task record if one doesn't exist
|
||||||
|
task = BackgroundTask(
|
||||||
|
id=task_id,
|
||||||
|
name=f"executor_task_{task_id}",
|
||||||
|
status=BackgroundTaskStatus.PENDING,
|
||||||
|
created_at=datetime.now(),
|
||||||
|
)
|
||||||
|
self._task_manager._tasks[task_id] = task
|
||||||
|
|
||||||
|
# Update status to running
|
||||||
|
task.status = BackgroundTaskStatus.RUNNING
|
||||||
|
task.started_at = datetime.now()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Execute the async function
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
task.status = BackgroundTaskStatus.COMPLETED
|
||||||
|
task.result = result
|
||||||
|
except Exception as e:
|
||||||
|
task.status = BackgroundTaskStatus.FAILED
|
||||||
|
task.error = f"{type(e).__name__}: {str(e)}"
|
||||||
|
task.result = None
|
||||||
|
finally:
|
||||||
|
task.completed_at = datetime.now()
|
||||||
|
# Clean up executor reference
|
||||||
|
if task_id in self._executors:
|
||||||
|
del self._executors[task_id]
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def execute_async(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Execute a task asynchronously in the background.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Unique task identifier
|
||||||
|
func: Async function to execute
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task ID
|
||||||
|
"""
|
||||||
|
# Create task record if it doesn't exist
|
||||||
|
if self._task_manager.get_task_status(task_id) is None:
|
||||||
|
self._task_manager._tasks[task_id] = BackgroundTask(
|
||||||
|
id=task_id,
|
||||||
|
name=f"async_task_{task_id}",
|
||||||
|
status=BackgroundTaskStatus.PENDING,
|
||||||
|
created_at=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and store the asyncio task
|
||||||
|
async_task = asyncio.create_task(self.execute_task(task_id, func, *args, **kwargs))
|
||||||
|
self._executors[task_id] = async_task
|
||||||
|
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
def cancel_task(self, task_id: str) -> bool:
|
||||||
|
"""Cancel a running task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The task ID to cancel
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if cancelled, False if not found or not running
|
||||||
|
"""
|
||||||
|
if task_id not in self._executors:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._executors[task_id].cancel()
|
||||||
|
del self._executors[task_id]
|
||||||
|
|
||||||
|
# Update task status
|
||||||
|
task = self._task_manager.get_task_status(task_id)
|
||||||
|
if task:
|
||||||
|
task.status = BackgroundTaskStatus.CANCELLED
|
||||||
|
task.completed_at = datetime.now()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_task_result(self, task_id: str) -> Any:
|
||||||
|
"""Get the result of a completed task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The task ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task result or None if not found/not completed
|
||||||
|
"""
|
||||||
|
task = self._task_manager.get_task_status(task_id)
|
||||||
|
if task and task.status == BackgroundTaskStatus.COMPLETED:
|
||||||
|
return task.result
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_task_error(self, task_id: str) -> str | None:
|
||||||
|
"""Get the error of a failed task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The task ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The error message or None if not found/not failed
|
||||||
|
"""
|
||||||
|
task = self._task_manager.get_task_status(task_id)
|
||||||
|
if task and task.status == BackgroundTaskStatus.FAILED:
|
||||||
|
return task.error
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_task_running(self, task_id: str) -> bool:
|
||||||
|
"""Check if a task is currently running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The task ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if running, False otherwise
|
||||||
|
"""
|
||||||
|
return task_id in self._executors
|
||||||
|
|
||||||
|
def wait_for_task(self, task_id: str, timeout: float | None = None) -> BackgroundTask:
|
||||||
|
"""Wait for a task to complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The task ID to wait for
|
||||||
|
timeout: Optional timeout in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The completed BackgroundTask
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
asyncio.TimeoutError: If task doesn't complete within timeout
|
||||||
|
asyncio.CancelledError: If task is cancelled
|
||||||
|
"""
|
||||||
|
if task_id not in self._executors:
|
||||||
|
task = self._task_manager.get_task_status(task_id)
|
||||||
|
if task:
|
||||||
|
return task
|
||||||
|
raise ValueError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
async def wait_task() -> BackgroundTask:
|
||||||
|
await self._executors[task_id]
|
||||||
|
return self._task_manager.get_task_status(task_id)
|
||||||
|
|
||||||
|
return asyncio.run_until_complete(asyncio.wait_for(wait_task(), timeout=timeout))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_manager(self) -> BackgroundTaskManager:
|
||||||
|
"""Get the underlying task manager."""
|
||||||
|
return self._task_manager
|
||||||
|
|
||||||
|
|
||||||
|
# Global executor instance
|
||||||
|
_executor: BackgroundExecutor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_background_executor() -> BackgroundExecutor:
|
||||||
|
"""Get the global BackgroundExecutor instance."""
|
||||||
|
global _executor
|
||||||
|
if _executor is None:
|
||||||
|
_executor = BackgroundExecutor()
|
||||||
|
return _executor
|
||||||
119
backend/app/agents/background/manager.py
Normal file
119
backend/app/agents/background/manager.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""后台任务系统 - Phase 10.4"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundTaskStatus(Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BackgroundTask:
|
||||||
|
"""后台任务"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
status: BackgroundTaskStatus
|
||||||
|
created_at: datetime
|
||||||
|
started_at: datetime | None = None
|
||||||
|
completed_at: datetime | None = None
|
||||||
|
result: Any = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundTaskManager:
|
||||||
|
"""后台任务管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tasks: dict[str, BackgroundTask] = {}
|
||||||
|
self._.coroutines: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
def submit_task(self, name: str, coro: Any, *args, **kwargs) -> str:
|
||||||
|
"""提交后台任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 任务名称
|
||||||
|
coro: 协程函数
|
||||||
|
*args: 位置参数
|
||||||
|
**kwargs: 关键字参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务 ID
|
||||||
|
"""
|
||||||
|
task_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
|
# 创建任务记录
|
||||||
|
self._tasks[task_id] = BackgroundTask(
|
||||||
|
id=task_id,
|
||||||
|
name=name,
|
||||||
|
status=BackgroundTaskStatus.PENDING,
|
||||||
|
created_at=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 asyncio task
|
||||||
|
async def run_task():
|
||||||
|
self._tasks[task_id].status = BackgroundTaskStatus.RUNNING
|
||||||
|
self._tasks[task_id].started_at = datetime.now()
|
||||||
|
try:
|
||||||
|
result = await coro(*args, **kwargs)
|
||||||
|
self._tasks[task_id].status = BackgroundTaskStatus.COMPLETED
|
||||||
|
self._tasks[task_id].result = result
|
||||||
|
except Exception as e:
|
||||||
|
self._tasks[task_id].status = BackgroundTaskStatus.FAILED
|
||||||
|
self._tasks[task_id].error = str(e)
|
||||||
|
finally:
|
||||||
|
self._tasks[task_id].completed_at = datetime.now()
|
||||||
|
if task_id in self._coroutines:
|
||||||
|
del self._coroutines[task_id]
|
||||||
|
|
||||||
|
self._coroutines[task_id] = asyncio.create_task(run_task())
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
def cancel_task(self, task_id: str) -> bool:
|
||||||
|
"""取消任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功取消
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if task_id in self._coroutines:
|
||||||
|
self._coroutines[task_id].cancel()
|
||||||
|
del self._coroutines[task_id]
|
||||||
|
|
||||||
|
self._tasks[task_id].status = BackgroundTaskStatus.CANCELLED
|
||||||
|
self._tasks[task_id].completed_at = datetime.now()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_task_status(self, task_id: str) -> BackgroundTask | None:
|
||||||
|
"""获取任务状态"""
|
||||||
|
return self._tasks.get(task_id)
|
||||||
|
|
||||||
|
def list_tasks(self) -> list[BackgroundTask]:
|
||||||
|
"""列出所有任务"""
|
||||||
|
return list(self._tasks.values())
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_manager: BackgroundTaskManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_background_task_manager() -> BackgroundTaskManager:
|
||||||
|
"""获取全局后台任务管理器"""
|
||||||
|
global _manager
|
||||||
|
if _manager is None:
|
||||||
|
_manager = BackgroundTaskManager()
|
||||||
|
return _manager
|
||||||
146
backend/app/agents/background/scheduler.py
Normal file
146
backend/app/agents/background/scheduler.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Background task scheduler - Phase 10.4"""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
from apscheduler.triggers.base import BaseTrigger
|
||||||
|
|
||||||
|
from .manager import BackgroundTaskManager, get_background_task_manager
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundScheduler:
|
||||||
|
"""Background task scheduler using APScheduler.
|
||||||
|
|
||||||
|
Integrates with BackgroundTaskManager for task tracking and execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, task_manager: BackgroundTaskManager | None = None):
|
||||||
|
"""Initialize the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_manager: Optional BackgroundTaskManager instance.
|
||||||
|
If not provided, uses the global singleton.
|
||||||
|
"""
|
||||||
|
self._scheduler = AsyncIOScheduler()
|
||||||
|
self._task_manager = task_manager or get_background_task_manager()
|
||||||
|
self._job_tasks: dict[str, str] = {} # Maps APScheduler job_id to task_id
|
||||||
|
|
||||||
|
def add_job(
|
||||||
|
self,
|
||||||
|
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||||
|
trigger: BaseTrigger,
|
||||||
|
args: tuple[Any, ...] | None = None,
|
||||||
|
kwargs: dict[str, Any] | None = None,
|
||||||
|
id: str | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
**apscheduler_kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Add a job to the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Async function to execute
|
||||||
|
trigger: APScheduler trigger (date, interval, cron, etc.)
|
||||||
|
args: Positional arguments for the function
|
||||||
|
kwargs: Keyword arguments for the function
|
||||||
|
id: Unique job ID (auto-generated if not provided)
|
||||||
|
name: Job name for display purposes
|
||||||
|
**apscheduler_kwargs: Additional APScheduler options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The job ID
|
||||||
|
"""
|
||||||
|
job_id = id or f"job_{len(self._job_tasks)}"
|
||||||
|
task_name = name or f"scheduled_task_{job_id}"
|
||||||
|
|
||||||
|
# Wrap the async function to integrate with BackgroundTaskManager
|
||||||
|
async def wrapped_func() -> None:
|
||||||
|
coro = func(*(args or ()), **(kwargs or {}))
|
||||||
|
task_id = self._task_manager.submit_task(task_name, coro)
|
||||||
|
self._job_tasks[job_id] = task_id
|
||||||
|
|
||||||
|
self._scheduler.add_job(
|
||||||
|
wrapped_func,
|
||||||
|
trigger=trigger,
|
||||||
|
id=job_id,
|
||||||
|
name=task_name,
|
||||||
|
**apscheduler_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return job_id
|
||||||
|
|
||||||
|
def remove_job(self, job_id: str) -> bool:
|
||||||
|
"""Remove a job from the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: The ID of the job to remove
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if job was removed, False if job didn't exist
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._scheduler.remove_job(job_id)
|
||||||
|
# Clean up task mapping if exists
|
||||||
|
if job_id in self._job_tasks:
|
||||||
|
task_id = self._job_tasks.pop(job_id)
|
||||||
|
# Cancel the background task if still running
|
||||||
|
self._task_manager.cancel_task(task_id)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_jobs(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all scheduled jobs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of job information dictionaries
|
||||||
|
"""
|
||||||
|
jobs = self._scheduler.get_jobs()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": job.id,
|
||||||
|
"name": job.name,
|
||||||
|
"next_run_time": job.next_run_time,
|
||||||
|
"trigger": str(job.trigger),
|
||||||
|
}
|
||||||
|
for job in jobs
|
||||||
|
]
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
"""Start the scheduler."""
|
||||||
|
if not self._scheduler.running:
|
||||||
|
self._scheduler.start()
|
||||||
|
|
||||||
|
def shutdown(self, wait: bool = True) -> None:
|
||||||
|
"""Shutdown the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wait: Whether to wait for running jobs to complete
|
||||||
|
"""
|
||||||
|
if self._scheduler.running:
|
||||||
|
self._scheduler.shutdown(wait=wait)
|
||||||
|
|
||||||
|
def pause(self) -> None:
|
||||||
|
"""Pause the scheduler."""
|
||||||
|
self._scheduler.pause()
|
||||||
|
|
||||||
|
def resume(self) -> None:
|
||||||
|
"""Resume the scheduler."""
|
||||||
|
self._scheduler.resume()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_manager(self) -> BackgroundTaskManager:
|
||||||
|
"""Get the underlying task manager."""
|
||||||
|
return self._task_manager
|
||||||
|
|
||||||
|
|
||||||
|
# Global scheduler instance
|
||||||
|
_scheduler: BackgroundScheduler | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_background_scheduler() -> BackgroundScheduler:
|
||||||
|
"""Get the global BackgroundScheduler instance."""
|
||||||
|
global _scheduler
|
||||||
|
if _scheduler is None:
|
||||||
|
_scheduler = BackgroundScheduler()
|
||||||
|
return _scheduler
|
||||||
508
backend/app/agents/coordinator.py
Normal file
508
backend/app/agents/coordinator.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
"""Agent 协调整器 - Phase 10.5
|
||||||
|
|
||||||
|
统一协调所有 Agent 组件:TeamLeader, RemoteTransport, BackgroundTaskManager, SessionManager
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.background.manager import BackgroundTaskManager, get_background_task_manager
|
||||||
|
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
|
||||||
|
from app.agents.team.leader import TeamLeader
|
||||||
|
from app.agents.transport.remote import RemoteTransport
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCoordinator:
|
||||||
|
"""Agent 协调整器
|
||||||
|
|
||||||
|
统一协调所有 Agent 组件,提供单一入口处理各类 Agent 操作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
background_manager: BackgroundTaskManager | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
background_manager: 后台任务管理器,None 则使用全局单例
|
||||||
|
"""
|
||||||
|
self._team_leaders: dict[str, TeamLeader] = {}
|
||||||
|
self._remote_transport = RemoteTransport()
|
||||||
|
self._background_manager = background_manager or get_background_task_manager()
|
||||||
|
self._sessions: dict[str, AgentSession] = {}
|
||||||
|
|
||||||
|
# === Team 协作方法 ===
|
||||||
|
|
||||||
|
def create_team(self, team_id: str, members: list[str]) -> dict[str, Any]:
|
||||||
|
"""创建团队
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: 团队 ID
|
||||||
|
members: 成员 ID 列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
团队创建结果
|
||||||
|
"""
|
||||||
|
if team_id in self._team_leaders:
|
||||||
|
return {"status": "error", "message": f"Team '{team_id}' already exists"}
|
||||||
|
|
||||||
|
leader = TeamLeader(team_id=team_id, members=members)
|
||||||
|
self._team_leaders[team_id] = leader
|
||||||
|
return {
|
||||||
|
"status": "created",
|
||||||
|
"team_id": team_id,
|
||||||
|
"members": members,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_team(self, team_id: str) -> TeamLeader | None:
|
||||||
|
"""获取团队
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: 团队 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TeamLeader 或 None
|
||||||
|
"""
|
||||||
|
return self._team_leaders.get(team_id)
|
||||||
|
|
||||||
|
def assign_task(self, team_id: str, description: str, member: str) -> dict[str, Any]:
|
||||||
|
"""创建并分配任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: 团队 ID
|
||||||
|
description: 任务描述
|
||||||
|
member: 成员 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分配结果
|
||||||
|
"""
|
||||||
|
leader = self._team_leaders.get(team_id)
|
||||||
|
if not leader:
|
||||||
|
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||||
|
|
||||||
|
task_id = leader.create_task(description)
|
||||||
|
success = leader.assign_task(task_id, member)
|
||||||
|
return {
|
||||||
|
"status": "assigned" if success else "error",
|
||||||
|
"task_id": task_id,
|
||||||
|
"assignee": member,
|
||||||
|
}
|
||||||
|
|
||||||
|
def broadcast_task(self, team_id: str, description: str) -> dict[str, Any]:
|
||||||
|
"""广播任务给所有成员
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: 团队 ID
|
||||||
|
description: 任务描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
广播结果
|
||||||
|
"""
|
||||||
|
leader = self._team_leaders.get(team_id)
|
||||||
|
if not leader:
|
||||||
|
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||||
|
|
||||||
|
task_ids = leader.broadcast_task(description)
|
||||||
|
return {
|
||||||
|
"status": "broadcast",
|
||||||
|
"team_id": team_id,
|
||||||
|
"task_ids": task_ids,
|
||||||
|
"member_count": len(leader.members),
|
||||||
|
}
|
||||||
|
|
||||||
|
def collect_team_results(self, team_id: str) -> dict[str, Any]:
|
||||||
|
"""收集团队任务结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: 团队 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
收集结果
|
||||||
|
"""
|
||||||
|
leader = self._team_leaders.get(team_id)
|
||||||
|
if not leader:
|
||||||
|
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||||
|
|
||||||
|
results = leader.collect_results()
|
||||||
|
status = leader.get_team_status()
|
||||||
|
return {
|
||||||
|
"status": "collected",
|
||||||
|
"team_id": team_id,
|
||||||
|
"results": results,
|
||||||
|
"completed": status["completed"],
|
||||||
|
"failed": status["failed"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_team_status(self, team_id: str) -> dict[str, Any]:
|
||||||
|
"""获取团队状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: 团队 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
团队状态
|
||||||
|
"""
|
||||||
|
leader = self._team_leaders.get(team_id)
|
||||||
|
if not leader:
|
||||||
|
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||||
|
|
||||||
|
return leader.get_team_status()
|
||||||
|
|
||||||
|
# === 后台任务方法 ===
|
||||||
|
|
||||||
|
def submit_background_task(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
coro: Any,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""提交后台任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 任务名称
|
||||||
|
coro: 协程函数
|
||||||
|
*args: 位置参数
|
||||||
|
**kwargs: 关键字参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提交结果
|
||||||
|
"""
|
||||||
|
task_id = self._background_manager.submit_task(name, coro, *args, **kwargs)
|
||||||
|
return {
|
||||||
|
"status": "submitted",
|
||||||
|
"task_id": task_id,
|
||||||
|
"name": name,
|
||||||
|
}
|
||||||
|
|
||||||
|
def cancel_background_task(self, task_id: str) -> dict[str, Any]:
|
||||||
|
"""取消后台任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
取消结果
|
||||||
|
"""
|
||||||
|
success = self._background_manager.cancel_task(task_id)
|
||||||
|
return {
|
||||||
|
"status": "cancelled" if success else "error",
|
||||||
|
"task_id": task_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_background_task_status(self, task_id: str) -> dict[str, Any]:
|
||||||
|
"""获取后台任务状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务状态
|
||||||
|
"""
|
||||||
|
task = self._background_manager.get_task_status(task_id)
|
||||||
|
if not task:
|
||||||
|
return {"status": "error", "message": f"Task '{task_id}' not found"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "found",
|
||||||
|
"task_id": task.id,
|
||||||
|
"name": task.name,
|
||||||
|
"task_status": task.status.value,
|
||||||
|
"result": task.result,
|
||||||
|
"error": task.error,
|
||||||
|
}
|
||||||
|
|
||||||
|
def list_background_tasks(self) -> dict[str, Any]:
|
||||||
|
"""列出所有后台任务
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务列表
|
||||||
|
"""
|
||||||
|
tasks = self._background_manager.list_tasks()
|
||||||
|
return {
|
||||||
|
"status": "list",
|
||||||
|
"count": len(tasks),
|
||||||
|
"tasks": [
|
||||||
|
{
|
||||||
|
"id": t.id,
|
||||||
|
"name": t.name,
|
||||||
|
"status": t.status.value,
|
||||||
|
}
|
||||||
|
for t in tasks
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# === 会话方法 ===
|
||||||
|
|
||||||
|
def create_session(
|
||||||
|
self,
|
||||||
|
user_id: str | None = None,
|
||||||
|
parent_session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""创建会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户 ID
|
||||||
|
parent_session_id: 父会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建结果
|
||||||
|
"""
|
||||||
|
session = create_agent_session(
|
||||||
|
user_id=user_id,
|
||||||
|
parent_session_id=parent_session_id,
|
||||||
|
)
|
||||||
|
self._sessions[session.session_id] = session
|
||||||
|
return {
|
||||||
|
"status": "created",
|
||||||
|
"session_id": session.session_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"parent_session_id": parent_session_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_session(self, session_id: str) -> AgentSession | None:
|
||||||
|
"""获取会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentSession 或 None
|
||||||
|
"""
|
||||||
|
return self._sessions.get(session_id) or get_agent_session(session_id)
|
||||||
|
|
||||||
|
async def process_session_message(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
message: str,
|
||||||
|
response: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""处理会话消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
message: 用户消息
|
||||||
|
response: 助手响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
session = self.get_session(session_id)
|
||||||
|
if not session:
|
||||||
|
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||||
|
|
||||||
|
await session.process_message(message, response)
|
||||||
|
return {
|
||||||
|
"status": "processed",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message_count": session.context.message_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def spawn_child_session(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""创建子会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 父会话 ID
|
||||||
|
user_id: 用户 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建结果
|
||||||
|
"""
|
||||||
|
session = self.get_session(session_id)
|
||||||
|
if not session:
|
||||||
|
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||||
|
|
||||||
|
child = await session.spawn_child_session(user_id=user_id)
|
||||||
|
self._sessions[child.session_id] = child
|
||||||
|
return {
|
||||||
|
"status": "spawned",
|
||||||
|
"parent_session_id": session_id,
|
||||||
|
"child_session_id": child.session_id,
|
||||||
|
"depth": child.context.depth,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_session_summary(self, session_id: str) -> dict[str, Any]:
|
||||||
|
"""获取会话摘要
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
会话摘要
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
session = self.get_session(session_id)
|
||||||
|
if not session:
|
||||||
|
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||||
|
|
||||||
|
# get_session_summary is async, so we need to run it
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
# Create a future
|
||||||
|
future = asyncio.ensure_future(session.get_session_summary())
|
||||||
|
return {"status": "found", "summary": future}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"status": "found",
|
||||||
|
"summary": loop.run_until_complete(session.get_session_summary()),
|
||||||
|
}
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop, create one
|
||||||
|
return {"status": "found", "summary": asyncio.run(session.get_session_summary())}
|
||||||
|
|
||||||
|
# === 远程传输方法 ===
|
||||||
|
|
||||||
|
def register_remote_handler(self, event_type: str, handler: Any) -> None:
|
||||||
|
"""注册远程消息处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: 事件类型
|
||||||
|
handler: 处理函数
|
||||||
|
"""
|
||||||
|
self._remote_transport.register_handler(event_type, handler)
|
||||||
|
|
||||||
|
async def send_remote_response(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
response: dict[str, Any],
|
||||||
|
) -> bool:
|
||||||
|
"""发送远程响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
response: 响应数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否发送成功
|
||||||
|
"""
|
||||||
|
return await self._remote_transport.send_response(session_id, response)
|
||||||
|
|
||||||
|
async def send_remote_event(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
event: dict[str, Any],
|
||||||
|
) -> bool:
|
||||||
|
"""发送远程事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
event: 事件数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否发送成功
|
||||||
|
"""
|
||||||
|
return await self._remote_transport.send_event(session_id, event)
|
||||||
|
|
||||||
|
async def send_remote_tool_call(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
tool_call: dict[str, Any],
|
||||||
|
) -> bool:
|
||||||
|
"""发送远程工具调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
tool_call: 工具调用数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否发送成功
|
||||||
|
"""
|
||||||
|
return await self._remote_transport.send_tool_call(session_id, tool_call)
|
||||||
|
|
||||||
|
# === 统一协调入口 ===
|
||||||
|
|
||||||
|
async def coordinate(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""统一协调入口
|
||||||
|
|
||||||
|
根据请求类型协调各类 Agent 操作。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 请求数据,包含:
|
||||||
|
- action: 操作类型 (team_create, team_assign, task_submit, session_create, etc.)
|
||||||
|
- 其他参数根据 action 不同而不同
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
协调结果
|
||||||
|
"""
|
||||||
|
action = request.get("action")
|
||||||
|
|
||||||
|
if action == "team_create":
|
||||||
|
return self.create_team(
|
||||||
|
team_id=request["team_id"],
|
||||||
|
members=request["members"],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "team_assign":
|
||||||
|
return self.assign_task(
|
||||||
|
team_id=request["team_id"],
|
||||||
|
description=request["description"],
|
||||||
|
member=request["member"],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "team_broadcast":
|
||||||
|
return self.broadcast_task(
|
||||||
|
team_id=request["team_id"],
|
||||||
|
description=request["description"],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "team_collect":
|
||||||
|
return self.collect_team_results(team_id=request["team_id"])
|
||||||
|
|
||||||
|
elif action == "team_status":
|
||||||
|
return self.get_team_status(team_id=request["team_id"])
|
||||||
|
|
||||||
|
elif action == "task_submit":
|
||||||
|
return self.submit_background_task(
|
||||||
|
name=request["name"],
|
||||||
|
coro=request["coro"],
|
||||||
|
*request.get("args", []),
|
||||||
|
**request.get("kwargs", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "task_cancel":
|
||||||
|
return self.cancel_background_task(task_id=request["task_id"])
|
||||||
|
|
||||||
|
elif action == "task_status":
|
||||||
|
return self.get_background_task_status(task_id=request["task_id"])
|
||||||
|
|
||||||
|
elif action == "session_create":
|
||||||
|
return self.create_session(
|
||||||
|
user_id=request.get("user_id"),
|
||||||
|
parent_session_id=request.get("parent_session_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "session_message":
|
||||||
|
return await self.process_session_message(
|
||||||
|
session_id=request["session_id"],
|
||||||
|
message=request["message"],
|
||||||
|
response=request["response"],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "session_spawn":
|
||||||
|
return await self.spawn_child_session(
|
||||||
|
session_id=request["session_id"],
|
||||||
|
user_id=request.get("user_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "session_summary":
|
||||||
|
return self.get_session_summary(session_id=request["session_id"])
|
||||||
|
|
||||||
|
else:
|
||||||
|
return {"status": "error", "message": f"Unknown action: {action}"}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_coordinator: AgentCoordinator | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_coordinator() -> AgentCoordinator:
|
||||||
|
"""获取全局 Agent 协调整器"""
|
||||||
|
global _coordinator
|
||||||
|
if _coordinator is None:
|
||||||
|
_coordinator = AgentCoordinator()
|
||||||
|
return _coordinator
|
||||||
File diff suppressed because it is too large
Load Diff
14
backend/app/agents/isolation/__init__.py
Normal file
14
backend/app/agents/isolation/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from app.agents.isolation.session_isolation import prepare_session_isolation
|
||||||
|
from app.agents.isolation.strategy_selector import IsolationDecision, select_isolation_strategy
|
||||||
|
from app.agents.isolation.worktree_isolation import (
|
||||||
|
WorktreeIsolationError,
|
||||||
|
prepare_worktree_isolation,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"IsolationDecision",
|
||||||
|
"WorktreeIsolationError",
|
||||||
|
"prepare_session_isolation",
|
||||||
|
"prepare_worktree_isolation",
|
||||||
|
"select_isolation_strategy",
|
||||||
|
]
|
||||||
31
backend/app/agents/isolation/session_isolation.py
Normal file
31
backend/app/agents/isolation/session_isolation.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from app.agents.isolation.strategy_selector import IsolationDecision
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_session_isolation(
|
||||||
|
*,
|
||||||
|
state: dict[str, Any],
|
||||||
|
decision: IsolationDecision,
|
||||||
|
role_value: str,
|
||||||
|
sub_commander: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
isolation_id = f"session-{uuid4().hex[:8]}"
|
||||||
|
return {
|
||||||
|
"mode": "session",
|
||||||
|
"isolation_id": isolation_id,
|
||||||
|
"workspace_path": None,
|
||||||
|
"parent_conversation_id": str(state.get("conversation_id") or "") or None,
|
||||||
|
"metadata": {
|
||||||
|
**dict(decision.metadata or {}),
|
||||||
|
"reason": decision.reason,
|
||||||
|
"role": role_value,
|
||||||
|
"sub_commander": sub_commander,
|
||||||
|
"tool_names": list(decision.tool_names),
|
||||||
|
"capability_ids": list(decision.capability_ids),
|
||||||
|
"status": "active",
|
||||||
|
},
|
||||||
|
}
|
||||||
147
backend/app/agents/isolation/strategy_selector.py
Normal file
147
backend/app/agents/isolation/strategy_selector.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.agents.registry import load_builtin_registry_indexes
|
||||||
|
from app.agents.registry.models import CapabilityManifest, PermissionClass, SideEffectScope
|
||||||
|
|
||||||
|
|
||||||
|
IsolationMode = Literal["none", "session", "worktree"]
|
||||||
|
|
||||||
|
_WORKTREE_QUERY_MARKERS = (
|
||||||
|
"code",
|
||||||
|
"repo",
|
||||||
|
"repository",
|
||||||
|
"git",
|
||||||
|
"worktree",
|
||||||
|
"branch",
|
||||||
|
"patch",
|
||||||
|
"diff",
|
||||||
|
"refactor",
|
||||||
|
"build",
|
||||||
|
"test",
|
||||||
|
"fix",
|
||||||
|
"file",
|
||||||
|
"files",
|
||||||
|
"python",
|
||||||
|
"typescript",
|
||||||
|
"javascript",
|
||||||
|
"代码",
|
||||||
|
"仓库",
|
||||||
|
"分支",
|
||||||
|
"补丁",
|
||||||
|
"重构",
|
||||||
|
"构建",
|
||||||
|
"测试",
|
||||||
|
"修复",
|
||||||
|
"文件",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class IsolationDecision:
|
||||||
|
mode: IsolationMode
|
||||||
|
reason: str
|
||||||
|
tool_names: tuple[str, ...] = ()
|
||||||
|
capability_ids: tuple[str, ...] = ()
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _capability_metadata(capability: CapabilityManifest | None) -> dict[str, Any]:
|
||||||
|
if capability is None:
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
"capability_id": capability.capability_id,
|
||||||
|
"tool_name": capability.tool_name,
|
||||||
|
"permission_class": capability.permission_class.value,
|
||||||
|
"side_effect_scope": capability.side_effect_scope.value,
|
||||||
|
"supports_retry": capability.supports_retry,
|
||||||
|
"idempotent": capability.idempotent,
|
||||||
|
"safe_for_parallel_use": capability.safe_for_parallel_use,
|
||||||
|
"requires_confirmation": capability.requires_confirmation,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def select_isolation_strategy(
|
||||||
|
*,
|
||||||
|
user_query: str,
|
||||||
|
tool_names: list[str] | tuple[str, ...],
|
||||||
|
role_value: str,
|
||||||
|
execution_mode: str | None,
|
||||||
|
) -> IsolationDecision:
|
||||||
|
indexes = load_builtin_registry_indexes()
|
||||||
|
capabilities: list[CapabilityManifest] = []
|
||||||
|
capability_ids: list[str] = []
|
||||||
|
|
||||||
|
for tool_name in tool_names:
|
||||||
|
capability_id = indexes.capability_id_by_tool_name.get(tool_name)
|
||||||
|
capability = indexes.capability_by_id.get(capability_id) if capability_id else None
|
||||||
|
if capability is not None:
|
||||||
|
capabilities.append(capability)
|
||||||
|
capability_ids.append(capability.capability_id)
|
||||||
|
|
||||||
|
normalized_query = (user_query or "").strip().lower()
|
||||||
|
has_worktree_query_signal = any(marker in normalized_query for marker in _WORKTREE_QUERY_MARKERS)
|
||||||
|
has_write_capability = any(cap.permission_class == PermissionClass.WRITE for cap in capabilities)
|
||||||
|
has_external_capability = any(cap.permission_class == PermissionClass.EXTERNAL for cap in capabilities)
|
||||||
|
has_non_parallel_capability = any(not cap.safe_for_parallel_use for cap in capabilities)
|
||||||
|
has_stateful_side_effect = any(
|
||||||
|
cap.side_effect_scope in {SideEffectScope.LOCAL_STATE, SideEffectScope.DB_WRITE}
|
||||||
|
for cap in capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"role": role_value,
|
||||||
|
"execution_mode": execution_mode,
|
||||||
|
"capabilities": [_capability_metadata(capability) for capability in capabilities],
|
||||||
|
"workspace_strategy": "inline",
|
||||||
|
"risk_level": "low",
|
||||||
|
}
|
||||||
|
|
||||||
|
if has_worktree_query_signal:
|
||||||
|
return IsolationDecision(
|
||||||
|
mode="worktree",
|
||||||
|
reason="workspace_mutation_signals_detected",
|
||||||
|
tool_names=tuple(tool_names),
|
||||||
|
capability_ids=tuple(capability_ids),
|
||||||
|
metadata={
|
||||||
|
**metadata,
|
||||||
|
"workspace_strategy": "ephemeral_worktree",
|
||||||
|
"risk_level": "high",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_write_capability or has_stateful_side_effect or has_non_parallel_capability:
|
||||||
|
return IsolationDecision(
|
||||||
|
mode="session",
|
||||||
|
reason="stateful_or_non_parallel_tooling",
|
||||||
|
tool_names=tuple(tool_names),
|
||||||
|
capability_ids=tuple(capability_ids),
|
||||||
|
metadata={
|
||||||
|
**metadata,
|
||||||
|
"workspace_strategy": "isolated_session",
|
||||||
|
"risk_level": "medium",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if execution_mode == "collaboration" or role_value in {"analyst", "librarian"} or has_external_capability:
|
||||||
|
return IsolationDecision(
|
||||||
|
mode="session",
|
||||||
|
reason="context_heavy_or_external_retrieval",
|
||||||
|
tool_names=tuple(tool_names),
|
||||||
|
capability_ids=tuple(capability_ids),
|
||||||
|
metadata={
|
||||||
|
**metadata,
|
||||||
|
"workspace_strategy": "isolated_session",
|
||||||
|
"risk_level": "medium",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return IsolationDecision(
|
||||||
|
mode="none",
|
||||||
|
reason="inline_execution_is_sufficient",
|
||||||
|
tool_names=tuple(tool_names),
|
||||||
|
capability_ids=tuple(capability_ids),
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
83
backend/app/agents/isolation/worktree_isolation.py
Normal file
83
backend/app/agents/isolation/worktree_isolation.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from app.agents.isolation.strategy_selector import IsolationDecision
|
||||||
|
|
||||||
|
|
||||||
|
class WorktreeIsolationError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _slugify(value: str, *, fallback: str) -> str:
|
||||||
|
slug = re.sub(r"[^a-zA-Z0-9._-]+", "-", (value or "").strip()).strip("-").lower()
|
||||||
|
return slug or fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_git_root() -> Path:
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["git", "rev-parse", "--show-toplevel"],
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError as exc:
|
||||||
|
raise WorktreeIsolationError(exc.stderr.strip() or exc.stdout.strip() or "git_root_unavailable") from exc
|
||||||
|
git_root = Path(result.stdout.strip())
|
||||||
|
if not git_root.exists():
|
||||||
|
raise WorktreeIsolationError("git_root_not_found")
|
||||||
|
return git_root
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_worktree_isolation(
|
||||||
|
*,
|
||||||
|
state: dict[str, Any],
|
||||||
|
decision: IsolationDecision,
|
||||||
|
role_value: str,
|
||||||
|
sub_commander: str,
|
||||||
|
create_workspace: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
isolation_id = f"worktree-{uuid4().hex[:8]}"
|
||||||
|
conversation_slug = _slugify(str(state.get("conversation_id") or "conversation"), fallback="conversation")
|
||||||
|
role_slug = _slugify(role_value, fallback="agent")
|
||||||
|
git_root = _resolve_git_root()
|
||||||
|
workspace_root = git_root / ".worktrees" / "jarvis" / conversation_slug
|
||||||
|
workspace_path = workspace_root / f"{role_slug}-{isolation_id}"
|
||||||
|
branch = f"jarvis/{conversation_slug}/{role_slug}-{isolation_id}"
|
||||||
|
|
||||||
|
if create_workspace and not workspace_path.exists():
|
||||||
|
workspace_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["git", "-C", str(git_root), "worktree", "add", "-b", branch, str(workspace_path), "HEAD"],
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError as exc:
|
||||||
|
raise WorktreeIsolationError(exc.stderr.strip() or exc.stdout.strip() or "worktree_add_failed") from exc
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mode": "worktree",
|
||||||
|
"isolation_id": isolation_id,
|
||||||
|
"workspace_path": str(workspace_path),
|
||||||
|
"parent_conversation_id": str(state.get("conversation_id") or "") or None,
|
||||||
|
"metadata": {
|
||||||
|
**dict(decision.metadata or {}),
|
||||||
|
"reason": decision.reason,
|
||||||
|
"role": role_value,
|
||||||
|
"sub_commander": sub_commander,
|
||||||
|
"tool_names": list(decision.tool_names),
|
||||||
|
"capability_ids": list(decision.capability_ids),
|
||||||
|
"repo_root": str(git_root),
|
||||||
|
"branch": branch,
|
||||||
|
"workspace_strategy": "ephemeral_worktree",
|
||||||
|
"cleanup_status": "pending",
|
||||||
|
"materialized": workspace_path.exists(),
|
||||||
|
},
|
||||||
|
}
|
||||||
20
backend/app/agents/orchestration/__init__.py
Normal file
20
backend/app/agents/orchestration/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""高级编排系统 - Phase 10"""
|
||||||
|
|
||||||
|
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
|
||||||
|
from app.agents.transport.remote import RemoteTransport, StructuredMessage
|
||||||
|
from app.agents.background.manager import (
|
||||||
|
BackgroundTaskManager,
|
||||||
|
BackgroundTask,
|
||||||
|
get_background_task_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TeamLeader",
|
||||||
|
"TeamTask",
|
||||||
|
"TaskStatus",
|
||||||
|
"RemoteTransport",
|
||||||
|
"StructuredMessage",
|
||||||
|
"BackgroundTaskManager",
|
||||||
|
"BackgroundTask",
|
||||||
|
"get_background_task_manager",
|
||||||
|
]
|
||||||
12
backend/app/agents/plugins/__init__.py
Normal file
12
backend/app/agents/plugins/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""插件系统 - Phase 8"""
|
||||||
|
|
||||||
|
from app.agents.plugins.manager import PluginManager, get_plugin_manager
|
||||||
|
from app.agents.plugins.manifest import PluginManifest
|
||||||
|
from app.agents.plugins.sandbox import PluginSandbox
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PluginManager",
|
||||||
|
"PluginManifest",
|
||||||
|
"PluginSandbox",
|
||||||
|
"get_plugin_manager",
|
||||||
|
]
|
||||||
19
backend/app/agents/plugins/builtins/code_helper/__init__.py
Normal file
19
backend/app/agents/plugins/builtins/code_helper/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Code Helper Plugin - Linting, formatting, and code explanation tools"""
|
||||||
|
|
||||||
|
|
||||||
|
def lint_file(file_path: str) -> dict:
|
||||||
|
"""Lint a source file and return issues found."""
|
||||||
|
return {"status": "ok", "tool": "lint_file", "result": f"Linting {file_path}"}
|
||||||
|
|
||||||
|
|
||||||
|
def format_file(file_path: str) -> dict:
|
||||||
|
"""Format a source file and return the result."""
|
||||||
|
return {"status": "ok", "tool": "format_file", "result": f"Formatting {file_path}"}
|
||||||
|
|
||||||
|
|
||||||
|
def explain_code(code_snippet: str) -> dict:
|
||||||
|
"""Explain a code snippet and return the explanation."""
|
||||||
|
return {"status": "ok", "tool": "explain_code", "result": f"Explaining code snippet"}
|
||||||
|
|
||||||
|
|
||||||
|
tools = [lint_file, format_file, explain_code]
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "code_helper",
|
||||||
|
"name": "Code Helper",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "Code linting, formatting, and explanation tools",
|
||||||
|
"author": "",
|
||||||
|
"homepage": "",
|
||||||
|
"license": "MIT",
|
||||||
|
"plugin_type": "tool",
|
||||||
|
"main": "__init__.py",
|
||||||
|
"hooks": [],
|
||||||
|
"tools": ["lint_file", "format_file", "explain_code"],
|
||||||
|
"skills": [],
|
||||||
|
"dependencies": {},
|
||||||
|
"peer_dependencies": {},
|
||||||
|
"permissions": [],
|
||||||
|
"allowed_paths": [],
|
||||||
|
"denied_paths": [],
|
||||||
|
"network_allowed": false,
|
||||||
|
"allowed_hosts": [],
|
||||||
|
"config_schema": {}
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""File Organizer Plugin - File organization and duplicate detection tools"""
|
||||||
|
|
||||||
|
|
||||||
|
def organize_by_type(directory: str) -> dict:
|
||||||
|
"""Organize files in a directory by file type."""
|
||||||
|
return {"status": "ok", "tool": "organize_by_type", "result": f"Organizing {directory} by type"}
|
||||||
|
|
||||||
|
|
||||||
|
def find_duplicates(directory: str) -> dict:
|
||||||
|
"""Find duplicate files in a directory."""
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"tool": "find_duplicates",
|
||||||
|
"result": f"Finding duplicates in {directory}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
tools = [organize_by_type, find_duplicates]
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "file_organizer",
|
||||||
|
"name": "File Organizer",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "File organization and duplicate detection tools",
|
||||||
|
"author": "",
|
||||||
|
"homepage": "",
|
||||||
|
"license": "MIT",
|
||||||
|
"plugin_type": "tool",
|
||||||
|
"main": "__init__.py",
|
||||||
|
"hooks": [],
|
||||||
|
"tools": ["organize_by_type", "find_duplicates"],
|
||||||
|
"skills": [],
|
||||||
|
"dependencies": {},
|
||||||
|
"peer_dependencies": {},
|
||||||
|
"permissions": [],
|
||||||
|
"allowed_paths": [],
|
||||||
|
"denied_paths": [],
|
||||||
|
"network_allowed": false,
|
||||||
|
"allowed_hosts": [],
|
||||||
|
"config_schema": {}
|
||||||
|
}
|
||||||
23
backend/app/agents/plugins/builtins/git_helper/__init__.py
Normal file
23
backend/app/agents/plugins/builtins/git_helper/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Git Helper Plugin - Git status, log, and diff summary tools"""
|
||||||
|
|
||||||
|
|
||||||
|
def git_status_summary() -> dict:
|
||||||
|
"""Get a summary of git status."""
|
||||||
|
return {"status": "ok", "tool": "git_status_summary", "result": "Git status summary"}
|
||||||
|
|
||||||
|
|
||||||
|
def git_log_summary(limit: int = 10) -> dict:
|
||||||
|
"""Get a summary of recent git commits."""
|
||||||
|
return {"status": "ok", "tool": "git_log_summary", "result": f"Git log summary (limit={limit})"}
|
||||||
|
|
||||||
|
|
||||||
|
def git_diff_summary(ref1: str = "HEAD", ref2: str = "HEAD~1") -> dict:
|
||||||
|
"""Get a summary of changes between two refs."""
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"tool": "git_diff_summary",
|
||||||
|
"result": f"Git diff summary ({ref1}..{ref2})",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
tools = [git_status_summary, git_log_summary, git_diff_summary]
|
||||||
22
backend/app/agents/plugins/builtins/git_helper/manifest.json
Normal file
22
backend/app/agents/plugins/builtins/git_helper/manifest.json
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "git_helper",
|
||||||
|
"name": "Git Helper",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "Git status, log, and diff summary tools",
|
||||||
|
"author": "",
|
||||||
|
"homepage": "",
|
||||||
|
"license": "MIT",
|
||||||
|
"plugin_type": "tool",
|
||||||
|
"main": "__init__.py",
|
||||||
|
"hooks": [],
|
||||||
|
"tools": ["git_status_summary", "git_log_summary", "git_diff_summary"],
|
||||||
|
"skills": [],
|
||||||
|
"dependencies": {},
|
||||||
|
"peer_dependencies": {},
|
||||||
|
"permissions": [],
|
||||||
|
"allowed_paths": [],
|
||||||
|
"denied_paths": [],
|
||||||
|
"network_allowed": false,
|
||||||
|
"allowed_hosts": [],
|
||||||
|
"config_schema": {}
|
||||||
|
}
|
||||||
14
backend/app/agents/plugins/builtins/web_helper/__init__.py
Normal file
14
backend/app/agents/plugins/builtins/web_helper/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""Web Helper Plugin - Web fetching and HTML parsing tools"""
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_url_content(url: str) -> dict:
|
||||||
|
"""Fetch content from a URL."""
|
||||||
|
return {"status": "ok", "tool": "fetch_url_content", "result": f"Fetching {url}"}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_html_links(html_content: str) -> dict:
|
||||||
|
"""Parse HTML content and extract links."""
|
||||||
|
return {"status": "ok", "tool": "parse_html_links", "result": "Extracted links from HTML"}
|
||||||
|
|
||||||
|
|
||||||
|
tools = [fetch_url_content, parse_html_links]
|
||||||
22
backend/app/agents/plugins/builtins/web_helper/manifest.json
Normal file
22
backend/app/agents/plugins/builtins/web_helper/manifest.json
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "web_helper",
|
||||||
|
"name": "Web Helper",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "Web fetching and HTML parsing tools",
|
||||||
|
"author": "",
|
||||||
|
"homepage": "",
|
||||||
|
"license": "MIT",
|
||||||
|
"plugin_type": "tool",
|
||||||
|
"main": "__init__.py",
|
||||||
|
"hooks": [],
|
||||||
|
"tools": ["fetch_url_content", "parse_html_links"],
|
||||||
|
"skills": [],
|
||||||
|
"dependencies": {},
|
||||||
|
"peer_dependencies": {},
|
||||||
|
"permissions": [],
|
||||||
|
"allowed_paths": [],
|
||||||
|
"denied_paths": [],
|
||||||
|
"network_allowed": true,
|
||||||
|
"allowed_hosts": [],
|
||||||
|
"config_schema": {}
|
||||||
|
}
|
||||||
207
backend/app/agents/plugins/manager.py
Normal file
207
backend/app/agents/plugins/manager.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""插件管理器 - Phase 8.2"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.plugins.manifest import PluginManifest
|
||||||
|
from app.agents.plugins.sandbox import PluginSandbox
|
||||||
|
|
||||||
|
|
||||||
|
class PluginManager:
|
||||||
|
"""插件管理器
|
||||||
|
|
||||||
|
负责插件的安装、卸载、启用、禁用和生命周期管理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, plugins_dir: str | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
plugins_dir: 插件目录,None 则使用默认目录
|
||||||
|
"""
|
||||||
|
if plugins_dir is None:
|
||||||
|
plugins_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "plugins")
|
||||||
|
self.plugins_dir = plugins_dir
|
||||||
|
self._plugins: dict[str, PluginManifest] = {}
|
||||||
|
self._enabled: dict[str, bool] = {}
|
||||||
|
self._modules: dict[str, Any] = {}
|
||||||
|
self._sandbox = PluginSandbox()
|
||||||
|
|
||||||
|
def install(self, plugin_path: str) -> bool:
|
||||||
|
"""安装插件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_path: 插件目录路径或 manifest.json 所在目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否安装成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
manifest_path = os.path.join(plugin_path, "manifest.json")
|
||||||
|
|
||||||
|
if not os.path.exists(manifest_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
manifest = PluginManifest.from_dict(data)
|
||||||
|
|
||||||
|
# 验证 manifest
|
||||||
|
if not self._validate_manifest(manifest, plugin_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 复制插件到 plugins_dir
|
||||||
|
target_dir = os.path.join(self.plugins_dir, manifest.id)
|
||||||
|
os.makedirs(os.path.dirname(target_dir), exist_ok=True)
|
||||||
|
|
||||||
|
# 保存 manifest
|
||||||
|
with open(os.path.join(target_dir, "manifest.json"), "w", encoding="utf-8") as f:
|
||||||
|
json.dump(manifest.to_dict(), f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 注册插件
|
||||||
|
self._plugins[manifest.id] = manifest
|
||||||
|
self._enabled[manifest.id] = True
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def uninstall(self, plugin_id: str) -> bool:
|
||||||
|
"""卸载插件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 插件 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否卸载成功
|
||||||
|
"""
|
||||||
|
if plugin_id not in self._plugins:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 禁用插件
|
||||||
|
self.disable(plugin_id)
|
||||||
|
|
||||||
|
# 移除模块
|
||||||
|
if plugin_id in self._modules:
|
||||||
|
del self._modules[plugin_id]
|
||||||
|
|
||||||
|
# 移除插件
|
||||||
|
del self._plugins[plugin_id]
|
||||||
|
del self._enabled[plugin_id]
|
||||||
|
|
||||||
|
# 删除目录
|
||||||
|
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
|
||||||
|
if os.path.exists(plugin_dir):
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
shutil.rmtree(plugin_dir)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def enable(self, plugin_id: str) -> bool:
|
||||||
|
"""启用插件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 插件 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否启用成功
|
||||||
|
"""
|
||||||
|
if plugin_id not in self._plugins:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._enabled[plugin_id] = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
def disable(self, plugin_id: str) -> bool:
|
||||||
|
"""禁用插件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 插件 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否禁用成功
|
||||||
|
"""
|
||||||
|
if plugin_id not in self._plugins:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._enabled[plugin_id] = False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def reload(self, plugin_id: str) -> bool:
|
||||||
|
"""重新加载插件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 插件 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否重新加载成功
|
||||||
|
"""
|
||||||
|
if plugin_id not in self._plugins:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 卸载模块
|
||||||
|
if plugin_id in self._modules:
|
||||||
|
del self._modules[plugin_id]
|
||||||
|
|
||||||
|
# 重新加载
|
||||||
|
return self._load_plugin_module(plugin_id)
|
||||||
|
|
||||||
|
def list_plugins(self) -> list[PluginManifest]:
|
||||||
|
"""列出所有插件"""
|
||||||
|
return list(self._plugins.values())
|
||||||
|
|
||||||
|
def get_plugin(self, plugin_id: str) -> PluginManifest | None:
|
||||||
|
"""获取插件清单"""
|
||||||
|
return self._plugins.get(plugin_id)
|
||||||
|
|
||||||
|
def is_enabled(self, plugin_id: str) -> bool:
|
||||||
|
"""检查插件是否启用"""
|
||||||
|
return self._enabled.get(plugin_id, False)
|
||||||
|
|
||||||
|
def _validate_manifest(self, manifest: PluginManifest, plugin_path: str) -> bool:
|
||||||
|
"""验证 manifest"""
|
||||||
|
# 检查主入口文件是否存在
|
||||||
|
main_path = os.path.join(plugin_path, manifest.main)
|
||||||
|
if not os.path.exists(main_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _load_plugin_module(self, plugin_id: str) -> bool:
|
||||||
|
"""加载插件模块"""
|
||||||
|
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
|
||||||
|
manifest = self._plugins.get(plugin_id)
|
||||||
|
if not manifest:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
main_path = os.path.join(plugin_dir, manifest.main)
|
||||||
|
spec = importlib.util.spec_from_file_location(plugin_id, main_path)
|
||||||
|
if spec and spec.loader:
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[plugin_id] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
self._modules[plugin_id] = module
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_manager: PluginManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_plugin_manager() -> PluginManager:
|
||||||
|
"""获取全局插件管理器"""
|
||||||
|
global _manager
|
||||||
|
if _manager is None:
|
||||||
|
_manager = PluginManager()
|
||||||
|
return _manager
|
||||||
73
backend/app/agents/plugins/manifest.py
Normal file
73
backend/app/agents/plugins/manifest.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""插件清单定义 - Phase 8.1"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PluginManifest:
|
||||||
|
"""插件清单
|
||||||
|
|
||||||
|
定义插件的元数据和接口。
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str # 唯一标识
|
||||||
|
name: str # 显示名称
|
||||||
|
version: str # 版本号
|
||||||
|
description: str # 描述
|
||||||
|
author: str = "" # 作者
|
||||||
|
homepage: str = "" # 主页
|
||||||
|
license: str = "MIT" # 许可证
|
||||||
|
|
||||||
|
# 插件类型
|
||||||
|
plugin_type: str = "tool" # tool, hook, skill, all
|
||||||
|
|
||||||
|
# 入口点
|
||||||
|
main: str = "index.py" # 主入口文件
|
||||||
|
hooks: list[str] = field(default_factory=list) # 提供的 Hook 列表
|
||||||
|
tools: list[str] = field(default_factory=list) # 提供的工具列表
|
||||||
|
skills: list[str] = field(default_factory=list) # 提供的 Skills 列表
|
||||||
|
|
||||||
|
# 依赖
|
||||||
|
dependencies: dict[str, str] = field(default_factory=dict) # pip 依赖
|
||||||
|
peer_dependencies: dict[str, str] = field(default_factory=dict) # 对等依赖
|
||||||
|
|
||||||
|
# 权限要求
|
||||||
|
permissions: list[str] = field(default_factory=list) # 需要的权限
|
||||||
|
allowed_paths: list[str] = field(default_factory=list) # 允许访问的路径
|
||||||
|
denied_paths: list[str] = field(default_factory=list) # 禁止访问的路径
|
||||||
|
|
||||||
|
# 网络权限
|
||||||
|
network_allowed: bool = False # 是否允许网络访问
|
||||||
|
allowed_hosts: list[str] = field(default_factory=list) # 允许访问的 host
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
config_schema: dict[str, Any] = field(default_factory=dict) # 配置 schema
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"version": self.version,
|
||||||
|
"description": self.description,
|
||||||
|
"author": self.author,
|
||||||
|
"homepage": self.homepage,
|
||||||
|
"license": self.license,
|
||||||
|
"plugin_type": self.plugin_type,
|
||||||
|
"main": self.main,
|
||||||
|
"hooks": self.hooks,
|
||||||
|
"tools": self.tools,
|
||||||
|
"skills": self.skills,
|
||||||
|
"dependencies": self.dependencies,
|
||||||
|
"peer_dependencies": self.peer_dependencies,
|
||||||
|
"permissions": self.permissions,
|
||||||
|
"allowed_paths": self.allowed_paths,
|
||||||
|
"denied_paths": self.denied_paths,
|
||||||
|
"network_allowed": self.network_allowed,
|
||||||
|
"allowed_hosts": self.allowed_hosts,
|
||||||
|
"config_schema": self.config_schema,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "PluginManifest":
|
||||||
|
return cls(**data)
|
||||||
111
backend/app/agents/plugins/sandbox.py
Normal file
111
backend/app/agents/plugins/sandbox.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""插件沙箱隔离 - Phase 8.3"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class PluginSandbox:
|
||||||
|
"""插件沙箱
|
||||||
|
|
||||||
|
提供插件执行隔离环境。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._allowed_paths: set[str] = set()
|
||||||
|
self._denied_paths: set[str] = set()
|
||||||
|
self._network_allowed: bool = False
|
||||||
|
self._allowed_hosts: set[str] = set()
|
||||||
|
|
||||||
|
def set_file_permissions(
|
||||||
|
self,
|
||||||
|
allowed_paths: list[str] | None = None,
|
||||||
|
denied_paths: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""设置文件访问权限
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_paths: 允许访问的路径列表
|
||||||
|
denied_paths: 禁止访问的路径列表
|
||||||
|
"""
|
||||||
|
self._allowed_paths = set(allowed_paths or [])
|
||||||
|
self._denied_paths = set(denied_paths or [])
|
||||||
|
|
||||||
|
def set_network_permissions(
|
||||||
|
self, allowed: bool, allowed_hosts: list[str] | None = None
|
||||||
|
) -> None:
|
||||||
|
"""设置网络访问权限
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed: 是否允许网络访问
|
||||||
|
allowed_hosts: 允许访问的 host 列表
|
||||||
|
"""
|
||||||
|
self._network_allowed = allowed
|
||||||
|
self._allowed_hosts = set(allowed_hosts or [])
|
||||||
|
|
||||||
|
def check_file_access(self, path: str) -> bool:
|
||||||
|
"""检查文件访问权限
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否允许访问
|
||||||
|
"""
|
||||||
|
# 如果有允许列表,只允许访问列表中的路径
|
||||||
|
if self._allowed_paths:
|
||||||
|
return path in self._allowed_paths or any(
|
||||||
|
path.startswith(allowed) for allowed in self._allowed_paths
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果有禁止列表,禁止访问列表中的路径
|
||||||
|
if self._denied_paths:
|
||||||
|
return not any(path.startswith(denied) for denied in self._denied_paths)
|
||||||
|
|
||||||
|
# 没有限制
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_network_access(self, host: str) -> bool:
|
||||||
|
"""检查网络访问权限
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: 主机地址
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否允许访问
|
||||||
|
"""
|
||||||
|
if not self._network_allowed:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self._allowed_hosts:
|
||||||
|
return host in self._allowed_hosts or any(
|
||||||
|
host.endswith(allowed) for allowed in self._allowed_hosts
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def execute_in_sandbox(self, func: Any, *args, **kwargs) -> Any:
|
||||||
|
"""在沙箱中执行函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: 要执行的函数
|
||||||
|
*args: 位置参数
|
||||||
|
**kwargs: 关键字参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
函数返回值
|
||||||
|
"""
|
||||||
|
# 保存当前状态
|
||||||
|
old_allowed_paths = self._allowed_paths.copy()
|
||||||
|
old_denied_paths = self._denied_paths.copy()
|
||||||
|
old_network_allowed = self._network_allowed
|
||||||
|
old_allowed_hosts = self._allowed_hosts.copy()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
# 恢复状态
|
||||||
|
self._allowed_paths = old_allowed_paths
|
||||||
|
self._denied_paths = old_denied_paths
|
||||||
|
self._network_allowed = old_network_allowed
|
||||||
|
self._allowed_hosts = old_allowed_hosts
|
||||||
@@ -324,6 +324,38 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
COORDINATOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||||
|
|
||||||
|
你是 Jarvis 的协作协调官,负责把复杂请求收束成最小受控协作,而不是放任系统进入自由 swarm。
|
||||||
|
|
||||||
|
## 你的职责:
|
||||||
|
- 先判断当前请求是否真的需要拆解;不需要时应明确建议继续走 direct
|
||||||
|
- 只有在明显多步骤、跨领域、需要多角色配合时,才拆成 2~4 个子任务
|
||||||
|
- 每个子任务必须清晰写出 `title`、`role`、`goal`、`expected_evidence`
|
||||||
|
- 角色建议只能来自现有 top-level agent:`schedule_planner`、`librarian`、`analyst`、`executor`
|
||||||
|
- 汇总时基于子任务结果回收,不依赖单点硬编码拼接
|
||||||
|
|
||||||
|
## 边界:
|
||||||
|
- 禁止无限递归拆分
|
||||||
|
- 禁止创建新的 runtime agent / worker
|
||||||
|
- 禁止把一个简单请求硬拆成多个空泛步骤
|
||||||
|
- 如果证据不足、子任务未闭环,必须把风险明确暴露出来
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
VERIFIER_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||||
|
|
||||||
|
你是 Jarvis 的验证官,负责对执行结果做最小但明确的核验。
|
||||||
|
|
||||||
|
## 你的职责:
|
||||||
|
- 只输出 passed、failed、skipped 三种验证结论之一
|
||||||
|
- 用一句话总结验证判断
|
||||||
|
- 如有证据,保留关键证据点
|
||||||
|
- 当信息不足以证明成功或失败时,优先判定为 skipped
|
||||||
|
- 不重写执行方案,不扩展无关建议
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
|
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
|
||||||
|
|
||||||
你的输出必须满足以下规则:
|
你的输出必须满足以下规则:
|
||||||
|
|||||||
@@ -1,11 +1,19 @@
|
|||||||
"""Registry manifest models and validation helpers."""
|
"""Registry manifest models and validation helpers."""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
from app.agents.registry.indexes import RegistryIndexes, build_registry_indexes
|
from app.agents.registry.indexes import RegistryIndexes, build_registry_indexes
|
||||||
from app.agents.registry.loader import RegistryBundle, load_builtin_registry_bundle
|
from app.agents.registry.loader import RegistryBundle, load_builtin_registry_bundle
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def load_builtin_registry_indexes() -> RegistryIndexes:
|
||||||
|
return build_registry_indexes(load_builtin_registry_bundle())
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RegistryBundle",
|
"RegistryBundle",
|
||||||
"RegistryIndexes",
|
"RegistryIndexes",
|
||||||
"build_registry_indexes",
|
"build_registry_indexes",
|
||||||
"load_builtin_registry_bundle",
|
"load_builtin_registry_bundle",
|
||||||
|
"load_builtin_registry_indexes",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from app.agents.prompts import SUB_COMMANDER_PROMPTS_BY_KEY
|
|||||||
from app.agents.registry.models import (
|
from app.agents.registry.models import (
|
||||||
AgentManifest,
|
AgentManifest,
|
||||||
CapabilityManifest,
|
CapabilityManifest,
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
SpecialistTemplateManifest,
|
SpecialistTemplateManifest,
|
||||||
SubCommanderManifest,
|
SubCommanderManifest,
|
||||||
)
|
)
|
||||||
@@ -55,6 +57,19 @@ TOP_LEVEL_AGENT_ROUTING_HINTS: dict[str, tuple[str, ...]] = {
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES: dict[str, tuple[str, ...]] = {
|
||||||
|
AgentRole.MASTER.value: (
|
||||||
|
AgentRole.SCHEDULE_PLANNER.value,
|
||||||
|
AgentRole.EXECUTOR.value,
|
||||||
|
AgentRole.LIBRARIAN.value,
|
||||||
|
AgentRole.ANALYST.value,
|
||||||
|
),
|
||||||
|
AgentRole.SCHEDULE_PLANNER.value: (AgentRole.SCHEDULE_PLANNER.value,),
|
||||||
|
AgentRole.EXECUTOR.value: (AgentRole.EXECUTOR.value,),
|
||||||
|
AgentRole.LIBRARIAN.value: (AgentRole.LIBRARIAN.value,),
|
||||||
|
AgentRole.ANALYST.value: (AgentRole.ANALYST.value,),
|
||||||
|
}
|
||||||
|
|
||||||
SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = {
|
SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = {
|
||||||
"schedule_analysis": AgentRole.SCHEDULE_PLANNER.value,
|
"schedule_analysis": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
"schedule_planning": AgentRole.SCHEDULE_PLANNER.value,
|
"schedule_planning": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
@@ -75,6 +90,8 @@ BUILTIN_AGENT_MANIFESTS: tuple[AgentManifest, ...] = tuple(
|
|||||||
system_prompt_key=role.value,
|
system_prompt_key=role.value,
|
||||||
routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]),
|
routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]),
|
||||||
default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]),
|
default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]),
|
||||||
|
can_spawn_children=bool(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
|
||||||
|
allowed_spawn_role_values=list(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
|
||||||
skill_context_key=role.value.replace("agent_", ""),
|
skill_context_key=role.value.replace("agent_", ""),
|
||||||
)
|
)
|
||||||
for role in AgentRole
|
for role in AgentRole
|
||||||
@@ -89,10 +106,150 @@ _capability_tool_names = tuple(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_CAPABILITY_METADATA_BY_TOOL_NAME: dict[str, dict[str, object]] = {
|
||||||
|
"get_tasks": {
|
||||||
|
"permission_class": PermissionClass.READ,
|
||||||
|
"side_effect_scope": SideEffectScope.NONE,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"get_schedule_day": {
|
||||||
|
"permission_class": PermissionClass.READ,
|
||||||
|
"side_effect_scope": SideEffectScope.NONE,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"resolve_time_expression": {
|
||||||
|
"permission_class": PermissionClass.READ,
|
||||||
|
"side_effect_scope": SideEffectScope.NONE,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"search_knowledge": {
|
||||||
|
"permission_class": PermissionClass.READ,
|
||||||
|
"side_effect_scope": SideEffectScope.NONE,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"hybrid_search": {
|
||||||
|
"permission_class": PermissionClass.READ,
|
||||||
|
"side_effect_scope": SideEffectScope.NONE,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"get_knowledge_graph_context": {
|
||||||
|
"permission_class": PermissionClass.READ,
|
||||||
|
"side_effect_scope": SideEffectScope.NONE,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"get_forum_posts": {
|
||||||
|
"permission_class": PermissionClass.READ,
|
||||||
|
"side_effect_scope": SideEffectScope.NONE,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"scan_forum_for_instructions": {
|
||||||
|
"permission_class": PermissionClass.READ,
|
||||||
|
"side_effect_scope": SideEffectScope.NONE,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"web_search": {
|
||||||
|
"permission_class": PermissionClass.EXTERNAL,
|
||||||
|
"side_effect_scope": SideEffectScope.NETWORK,
|
||||||
|
"supports_retry": True,
|
||||||
|
"idempotent": True,
|
||||||
|
"safe_for_parallel_use": True,
|
||||||
|
"requires_confirmation": False,
|
||||||
|
},
|
||||||
|
"create_task": {
|
||||||
|
"permission_class": PermissionClass.WRITE,
|
||||||
|
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||||
|
"supports_retry": False,
|
||||||
|
"idempotent": False,
|
||||||
|
"safe_for_parallel_use": False,
|
||||||
|
"requires_confirmation": True,
|
||||||
|
},
|
||||||
|
"update_task_status": {
|
||||||
|
"permission_class": PermissionClass.WRITE,
|
||||||
|
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||||
|
"supports_retry": False,
|
||||||
|
"idempotent": False,
|
||||||
|
"safe_for_parallel_use": False,
|
||||||
|
"requires_confirmation": True,
|
||||||
|
},
|
||||||
|
"create_todo": {
|
||||||
|
"permission_class": PermissionClass.WRITE,
|
||||||
|
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||||
|
"supports_retry": False,
|
||||||
|
"idempotent": False,
|
||||||
|
"safe_for_parallel_use": False,
|
||||||
|
"requires_confirmation": True,
|
||||||
|
},
|
||||||
|
"create_schedule_task": {
|
||||||
|
"permission_class": PermissionClass.WRITE,
|
||||||
|
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||||
|
"supports_retry": False,
|
||||||
|
"idempotent": False,
|
||||||
|
"safe_for_parallel_use": False,
|
||||||
|
"requires_confirmation": True,
|
||||||
|
},
|
||||||
|
"create_reminder": {
|
||||||
|
"permission_class": PermissionClass.WRITE,
|
||||||
|
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||||
|
"supports_retry": False,
|
||||||
|
"idempotent": False,
|
||||||
|
"safe_for_parallel_use": False,
|
||||||
|
"requires_confirmation": True,
|
||||||
|
},
|
||||||
|
"create_goal": {
|
||||||
|
"permission_class": PermissionClass.WRITE,
|
||||||
|
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||||
|
"supports_retry": False,
|
||||||
|
"idempotent": False,
|
||||||
|
"safe_for_parallel_use": False,
|
||||||
|
"requires_confirmation": True,
|
||||||
|
},
|
||||||
|
"create_forum_post": {
|
||||||
|
"permission_class": PermissionClass.WRITE,
|
||||||
|
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||||
|
"supports_retry": False,
|
||||||
|
"idempotent": False,
|
||||||
|
"safe_for_parallel_use": False,
|
||||||
|
"requires_confirmation": True,
|
||||||
|
},
|
||||||
|
"build_knowledge_graph": {
|
||||||
|
"permission_class": PermissionClass.WRITE,
|
||||||
|
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||||
|
"supports_retry": False,
|
||||||
|
"idempotent": False,
|
||||||
|
"safe_for_parallel_use": False,
|
||||||
|
"requires_confirmation": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
BUILTIN_CAPABILITY_MANIFESTS: tuple[CapabilityManifest, ...] = tuple(
|
BUILTIN_CAPABILITY_MANIFESTS: tuple[CapabilityManifest, ...] = tuple(
|
||||||
CapabilityManifest(
|
CapabilityManifest(
|
||||||
capability_id=tool_name,
|
capability_id=tool_name,
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
|
**dict(_CAPABILITY_METADATA_BY_TOOL_NAME.get(tool_name, {})),
|
||||||
)
|
)
|
||||||
for tool_name in _capability_tool_names
|
for tool_name in _capability_tool_names
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from app.agents.registry.models import (
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class RegistryIndexes:
|
class RegistryIndexes:
|
||||||
agent_by_id: Mapping[str, AgentManifest]
|
agent_by_id: Mapping[str, AgentManifest]
|
||||||
|
agent_by_role_value: Mapping[str, AgentManifest]
|
||||||
sub_commander_by_id: Mapping[str, SubCommanderManifest]
|
sub_commander_by_id: Mapping[str, SubCommanderManifest]
|
||||||
capability_by_id: Mapping[str, CapabilityManifest]
|
capability_by_id: Mapping[str, CapabilityManifest]
|
||||||
specialist_template_by_id: Mapping[str, SpecialistTemplateManifest]
|
specialist_template_by_id: Mapping[str, SpecialistTemplateManifest]
|
||||||
@@ -24,6 +25,7 @@ class RegistryIndexes:
|
|||||||
skill_context_key_by_agent_id: Mapping[str, str]
|
skill_context_key_by_agent_id: Mapping[str, str]
|
||||||
capability_id_by_tool_name: Mapping[str, str]
|
capability_id_by_tool_name: Mapping[str, str]
|
||||||
capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]]
|
capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]]
|
||||||
|
spawnable_role_values_by_agent_id: Mapping[str, tuple[str, ...]]
|
||||||
|
|
||||||
|
|
||||||
def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]:
|
def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]:
|
||||||
@@ -50,6 +52,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
|||||||
|
|
||||||
return RegistryIndexes(
|
return RegistryIndexes(
|
||||||
agent_by_id=MappingProxyType(agent_by_id),
|
agent_by_id=MappingProxyType(agent_by_id),
|
||||||
|
agent_by_role_value=MappingProxyType({
|
||||||
|
agent.role_value: agent for agent in bundle.agents
|
||||||
|
}),
|
||||||
sub_commander_by_id=MappingProxyType(sub_commander_by_id),
|
sub_commander_by_id=MappingProxyType(sub_commander_by_id),
|
||||||
capability_by_id=MappingProxyType(capability_by_id),
|
capability_by_id=MappingProxyType(capability_by_id),
|
||||||
specialist_template_by_id=MappingProxyType(specialist_template_by_id),
|
specialist_template_by_id=MappingProxyType(specialist_template_by_id),
|
||||||
@@ -73,4 +78,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
|||||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||||
for sub_commander in bundle.sub_commanders
|
for sub_commander in bundle.sub_commanders
|
||||||
}),
|
}),
|
||||||
|
spawnable_role_values_by_agent_id=MappingProxyType({
|
||||||
|
agent.agent_id: tuple(agent.allowed_spawn_role_values)
|
||||||
|
for agent in bundle.agents
|
||||||
|
if agent.can_spawn_children and agent.allowed_spawn_role_values
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,19 @@
|
|||||||
from pydantic import BaseModel
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionClass(str, Enum):
|
||||||
|
READ = "read"
|
||||||
|
WRITE = "write"
|
||||||
|
EXTERNAL = "external"
|
||||||
|
|
||||||
|
|
||||||
|
class SideEffectScope(str, Enum):
|
||||||
|
NONE = "none"
|
||||||
|
LOCAL_STATE = "local_state"
|
||||||
|
DB_WRITE = "db_write"
|
||||||
|
NETWORK = "network"
|
||||||
|
|
||||||
|
|
||||||
class AgentManifest(BaseModel):
|
class AgentManifest(BaseModel):
|
||||||
@@ -8,6 +23,8 @@ class AgentManifest(BaseModel):
|
|||||||
system_prompt_key: str
|
system_prompt_key: str
|
||||||
routing_hints: list[str]
|
routing_hints: list[str]
|
||||||
default_sub_commanders: list[str]
|
default_sub_commanders: list[str]
|
||||||
|
can_spawn_children: bool = False
|
||||||
|
allowed_spawn_role_values: list[str] = Field(default_factory=list)
|
||||||
skill_context_key: str | None = None
|
skill_context_key: str | None = None
|
||||||
continuity_policy: str | None = None
|
continuity_policy: str | None = None
|
||||||
clarification_policy: str | None = None
|
clarification_policy: str | None = None
|
||||||
@@ -23,6 +40,12 @@ class SubCommanderManifest(BaseModel):
|
|||||||
class CapabilityManifest(BaseModel):
|
class CapabilityManifest(BaseModel):
|
||||||
capability_id: str
|
capability_id: str
|
||||||
tool_name: str
|
tool_name: str
|
||||||
|
permission_class: PermissionClass = PermissionClass.READ
|
||||||
|
side_effect_scope: SideEffectScope = SideEffectScope.NONE
|
||||||
|
supports_retry: bool = False
|
||||||
|
idempotent: bool = False
|
||||||
|
safe_for_parallel_use: bool = False
|
||||||
|
requires_confirmation: bool = False
|
||||||
|
|
||||||
|
|
||||||
class SpecialistTemplateManifest(BaseModel):
|
class SpecialistTemplateManifest(BaseModel):
|
||||||
|
|||||||
86
backend/app/agents/runtime_metrics.py
Normal file
86
backend/app/agents/runtime_metrics.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
INPUT_TOKEN_USD_RATE = 0.000003
|
||||||
|
OUTPUT_TOKEN_USD_RATE = 0.000015
|
||||||
|
DEFAULT_COST_THRESHOLDS = {
|
||||||
|
"total_tokens": 4000,
|
||||||
|
"estimated_cost": 0.02,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_token_cost(input_tokens: int, output_tokens: int) -> float | None:
|
||||||
|
total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
|
||||||
|
if total_tokens <= 0:
|
||||||
|
return None
|
||||||
|
return round(
|
||||||
|
(max(input_tokens, 0) * INPUT_TOKEN_USD_RATE)
|
||||||
|
+ (max(output_tokens, 0) * OUTPUT_TOKEN_USD_RATE),
|
||||||
|
6,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_token_usage(response: Any) -> tuple[int, int]:
|
||||||
|
usage_metadata = getattr(response, "usage_metadata", None) or {}
|
||||||
|
if isinstance(usage_metadata, dict):
|
||||||
|
input_tokens = int(
|
||||||
|
usage_metadata.get("input_tokens")
|
||||||
|
or usage_metadata.get("prompt_tokens")
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
output_tokens = int(
|
||||||
|
usage_metadata.get("output_tokens")
|
||||||
|
or usage_metadata.get("completion_tokens")
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
if input_tokens or output_tokens:
|
||||||
|
return input_tokens, output_tokens
|
||||||
|
|
||||||
|
response_metadata = getattr(response, "response_metadata", None) or {}
|
||||||
|
token_usage = {}
|
||||||
|
if isinstance(response_metadata, dict):
|
||||||
|
token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") or {}
|
||||||
|
if isinstance(token_usage, dict):
|
||||||
|
input_tokens = int(
|
||||||
|
token_usage.get("prompt_tokens")
|
||||||
|
or token_usage.get("input_tokens")
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
output_tokens = int(
|
||||||
|
token_usage.get("completion_tokens")
|
||||||
|
or token_usage.get("output_tokens")
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
if input_tokens or output_tokens:
|
||||||
|
return input_tokens, output_tokens
|
||||||
|
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_cost_thresholds(raw_thresholds: Any) -> dict[str, float]:
|
||||||
|
thresholds: dict[str, float] = dict(DEFAULT_COST_THRESHOLDS)
|
||||||
|
if not isinstance(raw_thresholds, dict):
|
||||||
|
return thresholds
|
||||||
|
for key in DEFAULT_COST_THRESHOLDS:
|
||||||
|
value = raw_thresholds.get(key)
|
||||||
|
if isinstance(value, (int, float)) and value > 0:
|
||||||
|
thresholds[key] = float(value)
|
||||||
|
return thresholds
|
||||||
|
|
||||||
|
|
||||||
|
def is_cost_budget_warning(
|
||||||
|
input_tokens: int,
|
||||||
|
output_tokens: int,
|
||||||
|
estimated_cost: float | None,
|
||||||
|
thresholds: dict[str, float] | None = None,
|
||||||
|
) -> bool:
|
||||||
|
effective_thresholds = thresholds or DEFAULT_COST_THRESHOLDS
|
||||||
|
total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
|
||||||
|
token_threshold = float(effective_thresholds.get("total_tokens") or 0)
|
||||||
|
cost_threshold = float(effective_thresholds.get("estimated_cost") or 0)
|
||||||
|
return (
|
||||||
|
(token_threshold > 0 and total_tokens >= token_threshold)
|
||||||
|
or (cost_threshold > 0 and estimated_cost is not None and estimated_cost >= cost_threshold)
|
||||||
|
)
|
||||||
25
backend/app/agents/schemas/__init__.py
Normal file
25
backend/app/agents/schemas/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from app.agents.schemas.event import AgentEvent
|
||||||
|
from app.agents.schemas.message import AgentMessage
|
||||||
|
from app.agents.schemas.task import (
|
||||||
|
AgentTask,
|
||||||
|
CollaborationBudget,
|
||||||
|
InterruptRecord,
|
||||||
|
RecoveryRecord,
|
||||||
|
TaskLifecycleStatus,
|
||||||
|
TaskResult,
|
||||||
|
TaskResultStatus,
|
||||||
|
VerificationStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentEvent",
|
||||||
|
"AgentMessage",
|
||||||
|
"AgentTask",
|
||||||
|
"CollaborationBudget",
|
||||||
|
"InterruptRecord",
|
||||||
|
"RecoveryRecord",
|
||||||
|
"TaskLifecycleStatus",
|
||||||
|
"TaskResult",
|
||||||
|
"TaskResultStatus",
|
||||||
|
"VerificationStatus",
|
||||||
|
]
|
||||||
52
backend/app/agents/schemas/event.py
Normal file
52
backend/app/agents/schemas/event.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
AgentEventType = Literal[
|
||||||
|
"agent.tool.start",
|
||||||
|
"agent.tool.result",
|
||||||
|
"agent.verify.started",
|
||||||
|
"agent.verify.completed",
|
||||||
|
"agent.created",
|
||||||
|
"agent.spawn.blocked",
|
||||||
|
"agent.message.sent",
|
||||||
|
"agent.message.received",
|
||||||
|
"agent.interrupt.requested",
|
||||||
|
"agent.interrupt.completed",
|
||||||
|
"agent.recovery.started",
|
||||||
|
"agent.recovery.completed",
|
||||||
|
"agent.task.interrupted",
|
||||||
|
"agent.task.recovered",
|
||||||
|
"agent.task.reassigned",
|
||||||
|
"agent.collaboration.budget.updated",
|
||||||
|
"agent.isolation.selected",
|
||||||
|
"agent.isolation.fallback",
|
||||||
|
"agent.cost.updated",
|
||||||
|
"agent.cost.warning",
|
||||||
|
"agent.phase.changed",
|
||||||
|
"agent.checkpoint.recorded",
|
||||||
|
"agent.error",
|
||||||
|
]
|
||||||
|
AgentEventSeverity = Literal["info", "warning", "error"]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentEvent(BaseModel):
|
||||||
|
event_id: str
|
||||||
|
event_type: AgentEventType
|
||||||
|
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
conversation_id: str | None = None
|
||||||
|
agent_id: str | None = None
|
||||||
|
sub_commander_id: str | None = None
|
||||||
|
task_id: str | None = None
|
||||||
|
parent_task_id: str | None = None
|
||||||
|
child_task_id: str | None = None
|
||||||
|
thread_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
interrupt_id: str | None = None
|
||||||
|
recovery_id: str | None = None
|
||||||
|
payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
severity: AgentEventSeverity = "info"
|
||||||
29
backend/app/agents/schemas/message.py
Normal file
29
backend/app/agents/schemas/message.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
AgentMessageType = Literal[
|
||||||
|
"task_request",
|
||||||
|
"task_update",
|
||||||
|
"handoff",
|
||||||
|
"verification_request",
|
||||||
|
"verification_feedback",
|
||||||
|
"interrupt_notice",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMessage(BaseModel):
|
||||||
|
message_id: str
|
||||||
|
thread_id: str
|
||||||
|
from_agent_id: str
|
||||||
|
to_agent_id: str
|
||||||
|
task_id: str | None = None
|
||||||
|
reply_to_message_id: str | None = None
|
||||||
|
message_type: AgentMessageType = "task_update"
|
||||||
|
content_summary: str
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
85
backend/app/agents/schemas/task.py
Normal file
85
backend/app/agents/schemas/task.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
TaskLifecycleStatus = Literal["pending", "in_progress", "completed", "failed", "blocked"]
|
||||||
|
VerificationStatus = Literal["passed", "failed", "skipped"]
|
||||||
|
TaskResultStatus = Literal["completed", "failed", "blocked", "passed", "skipped"]
|
||||||
|
InterruptStatus = Literal["requested", "acknowledged", "resolved"]
|
||||||
|
BudgetMode = Literal["direct", "collaboration"]
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptRecord(BaseModel):
|
||||||
|
interrupt_id: str
|
||||||
|
reason: str
|
||||||
|
status: InterruptStatus = "requested"
|
||||||
|
requested_by: str | None = None
|
||||||
|
source_event_id: str | None = None
|
||||||
|
requested_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class RecoveryRecord(BaseModel):
|
||||||
|
recovery_id: str
|
||||||
|
source_interrupt_id: str | None = None
|
||||||
|
strategy: str | None = None
|
||||||
|
resumed_from_task_id: str | None = None
|
||||||
|
resumed_from_thread_id: str | None = None
|
||||||
|
recovered_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CollaborationBudget(BaseModel):
|
||||||
|
mode: BudgetMode = "direct"
|
||||||
|
max_parallel_tasks: int | None = None
|
||||||
|
remaining_parallel_tasks: int | None = None
|
||||||
|
max_tool_calls: int | None = None
|
||||||
|
remaining_tool_calls: int | None = None
|
||||||
|
max_iterations: int | None = None
|
||||||
|
remaining_iterations: int | None = None
|
||||||
|
escalation_threshold: int | None = None
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTask(BaseModel):
|
||||||
|
task_id: str
|
||||||
|
title: str
|
||||||
|
status: TaskLifecycleStatus = "pending"
|
||||||
|
owner_agent_id: str | None = None
|
||||||
|
role: str | None = None
|
||||||
|
goal: str | None = None
|
||||||
|
parent_task_id: str | None = None
|
||||||
|
child_task_ids: list[str] = Field(default_factory=list)
|
||||||
|
thread_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
message_index: int | None = None
|
||||||
|
expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
|
||||||
|
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
|
||||||
|
collaboration_budget: CollaborationBudget | dict[str, Any] | None = None
|
||||||
|
result_summary: str | None = None
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
class TaskResult(BaseModel):
|
||||||
|
task_id: str
|
||||||
|
status: TaskResultStatus
|
||||||
|
summary: str | None = None
|
||||||
|
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
owner_agent_id: str | None = None
|
||||||
|
parent_task_id: str | None = None
|
||||||
|
child_task_ids: list[str] = Field(default_factory=list)
|
||||||
|
thread_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
message_index: int | None = None
|
||||||
|
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
|
||||||
|
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
|
||||||
|
budget_snapshot: CollaborationBudget | dict[str, Any] | None = None
|
||||||
|
next_action: str | None = None
|
||||||
|
output_data: dict[str, Any] | None = None
|
||||||
17
backend/app/agents/session/__init__.py
Normal file
17
backend/app/agents/session/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Agent Session Management - Phase 10.3"""
|
||||||
|
|
||||||
|
from app.agents.session.manager import (
|
||||||
|
AgentSession,
|
||||||
|
SessionContext,
|
||||||
|
SessionPersistence,
|
||||||
|
create_agent_session,
|
||||||
|
get_agent_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentSession",
|
||||||
|
"SessionContext",
|
||||||
|
"SessionPersistence",
|
||||||
|
"create_agent_session",
|
||||||
|
"get_agent_session",
|
||||||
|
]
|
||||||
238
backend/app/agents/session/manager.py
Normal file
238
backend/app/agents/session/manager.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""Agent Session 管理 - Phase 10.3
|
||||||
|
|
||||||
|
支持会话层级管理和子会话创建。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionContext:
|
||||||
|
"""会话上下文"""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
parent_session_id: str | None = None
|
||||||
|
root_session_id: str | None = None
|
||||||
|
depth: int = 0
|
||||||
|
user_id: str | None = None
|
||||||
|
created_at: str | None = None
|
||||||
|
last_active: str | None = None
|
||||||
|
message_count: int = 0
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.created_at is None:
|
||||||
|
self.created_at = datetime.now().isoformat()
|
||||||
|
if self.last_active is None:
|
||||||
|
self.last_active = self.created_at
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionPersistence:
|
||||||
|
"""会话持久化"""
|
||||||
|
|
||||||
|
def __init__(self, persistence_dir: str | None = None):
|
||||||
|
if persistence_dir is None:
|
||||||
|
persistence_dir = os.path.join(
|
||||||
|
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
|
||||||
|
)
|
||||||
|
self.persistence_dir = persistence_dir
|
||||||
|
|
||||||
|
def _get_session_path(self, session_id: str) -> str:
|
||||||
|
return os.path.join(self.persistence_dir, f"{session_id}.json")
|
||||||
|
|
||||||
|
def save(self, session: "AgentSession") -> bool:
|
||||||
|
"""保存会话"""
|
||||||
|
try:
|
||||||
|
os.makedirs(self.persistence_dir, exist_ok=True)
|
||||||
|
path = self._get_session_path(session.session_id)
|
||||||
|
data = {
|
||||||
|
"session_id": session.session_id,
|
||||||
|
"parent_session_id": session.context.parent_session_id,
|
||||||
|
"root_session_id": session.context.root_session_id,
|
||||||
|
"depth": session.context.depth,
|
||||||
|
"user_id": session.context.user_id,
|
||||||
|
"created_at": session.context.created_at,
|
||||||
|
"last_active": session.context.last_active,
|
||||||
|
"message_count": session.context.message_count,
|
||||||
|
"metadata": session.context.metadata,
|
||||||
|
"history": session._history,
|
||||||
|
}
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def load(self, session_id: str) -> dict[str, Any] | None:
|
||||||
|
"""加载会话"""
|
||||||
|
try:
|
||||||
|
path = self._get_session_path(session_id)
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return None
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete(self, session_id: str) -> bool:
|
||||||
|
"""删除会话"""
|
||||||
|
try:
|
||||||
|
path = self._get_session_path(session_id)
|
||||||
|
if os.path.exists(path):
|
||||||
|
os.remove(path)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
||||||
|
"""列出所有会话"""
|
||||||
|
sessions = []
|
||||||
|
try:
|
||||||
|
os.makedirs(self.persistence_dir, exist_ok=True)
|
||||||
|
for filename in os.listdir(self.persistence_dir):
|
||||||
|
if filename.endswith(".json"):
|
||||||
|
path = os.path.join(self.persistence_dir, filename)
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
if user_id is None or data.get("user_id") == user_id:
|
||||||
|
sessions.append(data)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSession:
|
||||||
|
"""Agent 会话管理器
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- 会话层级(parent/root/depth)
|
||||||
|
- 子会话创建
|
||||||
|
- 会话摘要
|
||||||
|
- 持久化
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
parent_session_id: str | None = None,
|
||||||
|
):
|
||||||
|
self.session_id = session_id or str(uuid.uuid4())[:8]
|
||||||
|
self.context = SessionContext(
|
||||||
|
session_id=self.session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
parent_session_id=parent_session_id,
|
||||||
|
depth=0 if parent_session_id is None else 1,
|
||||||
|
)
|
||||||
|
self._history: list[dict[str, Any]] = []
|
||||||
|
self._persistence = SessionPersistence()
|
||||||
|
|
||||||
|
# 如果有父会话,设置 root_session_id
|
||||||
|
if parent_session_id:
|
||||||
|
parent_data = self._persistence.load(parent_session_id)
|
||||||
|
if parent_data:
|
||||||
|
self.context.root_session_id = (
|
||||||
|
parent_data.get("root_session_id") or parent_session_id
|
||||||
|
)
|
||||||
|
self.context.depth = parent_data.get("depth", 0) + 1
|
||||||
|
|
||||||
|
async def initialize(self) -> dict[str, Any]:
|
||||||
|
"""初始化会话"""
|
||||||
|
self.context.last_active = datetime.now().isoformat()
|
||||||
|
self._persistence.save(self)
|
||||||
|
return {
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"depth": self.context.depth,
|
||||||
|
"parent_session_id": self.context.parent_session_id,
|
||||||
|
"root_session_id": self.context.root_session_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def process_message(self, message: str, response: str) -> None:
|
||||||
|
"""处理消息并记录到历史"""
|
||||||
|
self.context.message_count += 1
|
||||||
|
self.context.last_active = datetime.now().isoformat()
|
||||||
|
self._history.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": message,
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._history.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": response,
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._persistence.save(self)
|
||||||
|
|
||||||
|
async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
|
||||||
|
"""创建子会话"""
|
||||||
|
child = AgentSession(
|
||||||
|
user_id=user_id or self.context.user_id,
|
||||||
|
parent_session_id=self.session_id,
|
||||||
|
)
|
||||||
|
child.context.root_session_id = self.context.root_session_id or self.session_id
|
||||||
|
await child.initialize()
|
||||||
|
return child
|
||||||
|
|
||||||
|
async def get_session_summary(self) -> dict[str, Any]:
|
||||||
|
"""获取会话摘要"""
|
||||||
|
return {
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"parent_session_id": self.context.parent_session_id,
|
||||||
|
"root_session_id": self.context.root_session_id,
|
||||||
|
"depth": self.context.depth,
|
||||||
|
"user_id": self.context.user_id,
|
||||||
|
"created_at": self.context.created_at,
|
||||||
|
"last_active": self.context.last_active,
|
||||||
|
"message_count": self.context.message_count,
|
||||||
|
"history_length": len(self._history),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def persist(self) -> bool:
|
||||||
|
"""持久化会话"""
|
||||||
|
return self._persistence.save(self)
|
||||||
|
|
||||||
|
def get_history(self) -> list[dict[str, Any]]:
|
||||||
|
"""获取会话历史"""
|
||||||
|
return self._history.copy()
|
||||||
|
|
||||||
|
def add_metadata(self, key: str, value: Any) -> None:
|
||||||
|
"""添加会话元数据"""
|
||||||
|
self.context.metadata[key] = value
|
||||||
|
|
||||||
|
def get_metadata(self, key: str) -> Any:
|
||||||
|
"""获取会话元数据"""
|
||||||
|
return self.context.metadata.get(key)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局会话存储(内存中)
|
||||||
|
_sessions: dict[str, AgentSession] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_session(session_id: str) -> AgentSession | None:
|
||||||
|
"""获取会话"""
|
||||||
|
return _sessions.get(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def create_agent_session(
|
||||||
|
session_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
parent_session_id: str | None = None,
|
||||||
|
) -> AgentSession:
|
||||||
|
"""创建新会话"""
|
||||||
|
session = AgentSession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
parent_session_id=parent_session_id,
|
||||||
|
)
|
||||||
|
_sessions[session.session_id] = session
|
||||||
|
return session
|
||||||
16
backend/app/agents/skills/__init__.py
Normal file
16
backend/app/agents/skills/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Skills 注册表 - Phase 9"""
|
||||||
|
|
||||||
|
from app.agents.skills.registry import SkillRegistry, get_skill_registry
|
||||||
|
from app.agents.skills.metadata import SkillMetadata
|
||||||
|
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||||
|
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
|
||||||
|
from app.agents.skills.mcp_builder import MCPSkillBuilder
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SkillRegistry",
|
||||||
|
"SkillMetadata",
|
||||||
|
"LocalSkillLoader",
|
||||||
|
"PluginSkillLoader",
|
||||||
|
"MCPSkillBuilder",
|
||||||
|
"get_skill_registry",
|
||||||
|
]
|
||||||
72
backend/app/agents/skills/bundled.py
Normal file
72
backend/app/agents/skills/bundled.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""Built-in Skills - Phase 9.4
|
||||||
|
|
||||||
|
This module contains bundled skills that are always available
|
||||||
|
without requiring external skill loaders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
# SkillMetadata-compatible structure for bundled skills
|
||||||
|
BUNDLED_SKILLS: list[dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
"id": "code-analysis",
|
||||||
|
"name": "Code Analysis",
|
||||||
|
"description": "Analyze code structure, patterns, and quality. Helps understand codebase architecture, find issues, and suggest improvements.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"prompts": [
|
||||||
|
"Analyze the code structure and identify key components, their relationships, and responsibilities.",
|
||||||
|
"Review the code for potential issues like bugs, security vulnerabilities, or performance problems.",
|
||||||
|
"Explain how the code works and what it does in simple terms.",
|
||||||
|
],
|
||||||
|
"tools": ["grep", "read", "glob", "lsp_symbols", "lsp_find_references"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "git-helper",
|
||||||
|
"name": "Git Helper",
|
||||||
|
"description": "Assists with Git operations including commit, branch management, merge conflicts, and repository exploration.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"prompts": [
|
||||||
|
"Show me the current git status and any uncommitted changes.",
|
||||||
|
"Help me create a meaningful commit message for these changes.",
|
||||||
|
"Explain the git history and branch structure of this repository.",
|
||||||
|
],
|
||||||
|
"tools": ["bash"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "web-research",
|
||||||
|
"name": "Web Research",
|
||||||
|
"description": "Search the web for information, documentation, and resources. Helps find answers and learn about technologies.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"prompts": [
|
||||||
|
"Search the web for information about {topic} and summarize the key findings.",
|
||||||
|
"Find official documentation or reliable resources about {topic}.",
|
||||||
|
"Look up the latest news or developments in {topic}.",
|
||||||
|
],
|
||||||
|
"tools": ["search_brave_web_search", "websearch_web_search_exa", "webfetch"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "file-management",
|
||||||
|
"name": "File Management",
|
||||||
|
"description": "Helps with file operations like creating, editing, organizing, and managing project files and directories.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"prompts": [
|
||||||
|
"Create a new file at {path} with the following content: {content}",
|
||||||
|
"Organize the files in the project structure and suggest improvements.",
|
||||||
|
"Find all files related to {topic} or matching {pattern}.",
|
||||||
|
],
|
||||||
|
"tools": ["read", "write", "glob", "bash"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "task-planning",
|
||||||
|
"name": "Task Planning",
|
||||||
|
"description": "Helps break down complex tasks into smaller steps, create implementation plans, and track progress.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"prompts": [
|
||||||
|
"Break down this task into smaller, manageable steps: {task}",
|
||||||
|
"Create an implementation plan for building {feature} with clear phases.",
|
||||||
|
"Review the current progress and suggest next steps for completing {goal}.",
|
||||||
|
],
|
||||||
|
"tools": ["todowrite", "read", "write"],
|
||||||
|
},
|
||||||
|
]
|
||||||
12
backend/app/agents/skills/loaders/__init__.py
Normal file
12
backend/app/agents/skills/loaders/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""Skills 加载器包"""
|
||||||
|
|
||||||
|
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||||
|
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
|
||||||
|
from app.agents.skills.loaders.mcp_loader import MCPSkillLoader, get_mcp_skill_loader
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LocalSkillLoader",
|
||||||
|
"PluginSkillLoader",
|
||||||
|
"MCPSkillLoader",
|
||||||
|
"get_mcp_skill_loader",
|
||||||
|
]
|
||||||
100
backend/app/agents/skills/loaders/local_loader.py
Normal file
100
backend/app/agents/skills/loaders/local_loader.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""本地 Skills 加载器 - Phase 9.2"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.skills.metadata import SkillMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class LocalSkillLoader:
|
||||||
|
"""本地 Skills 加载器
|
||||||
|
|
||||||
|
从 skills_dir 目录加载 SKILL.md 文件。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, skills_dir: str):
|
||||||
|
self.skills_dir = skills_dir
|
||||||
|
|
||||||
|
def load_all(self) -> list[SkillMetadata]:
|
||||||
|
"""加载所有本地 Skills
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 元数据列表
|
||||||
|
"""
|
||||||
|
skills = []
|
||||||
|
|
||||||
|
if not os.path.exists(self.skills_dir):
|
||||||
|
return skills
|
||||||
|
|
||||||
|
for root, dirs, files in os.walk(self.skills_dir):
|
||||||
|
# 跳过隐藏目录
|
||||||
|
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
||||||
|
|
||||||
|
if "SKILL.md" in files:
|
||||||
|
skill = self._load_skill_from_dir(root)
|
||||||
|
if skill:
|
||||||
|
skills.append(skill)
|
||||||
|
|
||||||
|
return skills
|
||||||
|
|
||||||
|
def _load_skill_from_dir(self, skill_dir: str) -> SkillMetadata | None:
|
||||||
|
"""从目录加载 Skill
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: Skill 目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 元数据
|
||||||
|
"""
|
||||||
|
skill_path = os.path.join(skill_dir, "SKILL.md")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(skill_path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# 解析 frontmatter
|
||||||
|
metadata = self._parse_frontmatter(content)
|
||||||
|
|
||||||
|
# 获取 Skill 名称(目录名)
|
||||||
|
name = os.path.basename(skill_dir)
|
||||||
|
|
||||||
|
return SkillMetadata(
|
||||||
|
name=metadata.get("name", name),
|
||||||
|
description=metadata.get("description", ""),
|
||||||
|
version=metadata.get("version", "1.0.0"),
|
||||||
|
author=metadata.get("author", ""),
|
||||||
|
tags=metadata.get("tags", []),
|
||||||
|
triggers=metadata.get("triggers", []),
|
||||||
|
content=content,
|
||||||
|
source="local",
|
||||||
|
source_id=skill_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_frontmatter(self, content: str) -> dict[str, Any]:
|
||||||
|
"""解析 frontmatter"""
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
# 匹配 --- 包裹的 frontmatter
|
||||||
|
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
frontmatter = match.group(1)
|
||||||
|
|
||||||
|
for line in frontmatter.split("\n"):
|
||||||
|
if ":" in line:
|
||||||
|
key, value = line.split(":", 1)
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
# 处理列表
|
||||||
|
if value.startswith("[") and value.endswith("]"):
|
||||||
|
value = [v.strip().strip('"').strip("'") for v in value[1:-1].split(",")]
|
||||||
|
elif value.lower() in ("true", "false"):
|
||||||
|
value = value.lower() == "true"
|
||||||
|
|
||||||
|
metadata[key] = value
|
||||||
|
|
||||||
|
return metadata
|
||||||
169
backend/app/agents/skills/loaders/mcp_loader.py
Normal file
169
backend/app/agents/skills/loaders/mcp_loader.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""MCP Skill 加载器 - Phase 9.2
|
||||||
|
|
||||||
|
从 MCP (Model Context Protocol) 服务器发现和加载 Skills。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.skills.metadata import SkillMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSkillLoader:
|
||||||
|
"""MCP Skill 加载器
|
||||||
|
|
||||||
|
从 MCP 服务器发现可用的 Skills。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mcp_servers: list[dict[str, Any]] | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
mcp_servers: MCP 服务器列表,每项包含 name, command, env 等
|
||||||
|
"""
|
||||||
|
self.mcp_servers = mcp_servers or []
|
||||||
|
self._discovered_skills: dict[str, SkillMetadata] = {}
|
||||||
|
|
||||||
|
def discover_skills(self) -> list[SkillMetadata]:
|
||||||
|
"""从所有配置的 MCP 服务器发现 Skills
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
发现的 Skill 列表
|
||||||
|
"""
|
||||||
|
skills = []
|
||||||
|
|
||||||
|
for server in self.mcp_servers:
|
||||||
|
server_skills = self._discover_from_server(server)
|
||||||
|
skills.extend(server_skills)
|
||||||
|
|
||||||
|
return skills
|
||||||
|
|
||||||
|
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
|
||||||
|
"""从单个 MCP 服务器发现 Skills
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: 服务器配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 列表
|
||||||
|
"""
|
||||||
|
skills = []
|
||||||
|
server_name = server.get("name", "unknown")
|
||||||
|
|
||||||
|
# 模拟从 MCP 服务器获取工具列表
|
||||||
|
# 实际实现时,这里会调用 MCP 服务器的 list_tools 接口
|
||||||
|
try:
|
||||||
|
tools = self._call_mcp_list_tools(server)
|
||||||
|
for tool in tools:
|
||||||
|
skill = self._tool_to_skill(tool, server_name)
|
||||||
|
if skill:
|
||||||
|
skills.append(skill)
|
||||||
|
self._discovered_skills[skill.name] = skill
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return skills
|
||||||
|
|
||||||
|
def _call_mcp_list_tools(self, server: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
"""调用 MCP 服务器的 list_tools 接口
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: 服务器配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具列表
|
||||||
|
"""
|
||||||
|
# TODO: 实现实际的 MCP 协议调用
|
||||||
|
# 目前返回空列表,实际使用时需要实现 MCP 客户端
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _tool_to_skill(self, tool: dict[str, Any], server: str) -> SkillMetadata | None:
|
||||||
|
"""将 MCP 工具转换为 Skill
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool: MCP 工具定义
|
||||||
|
server: 服务器名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 元数据或 None
|
||||||
|
"""
|
||||||
|
tool_name = tool.get("name")
|
||||||
|
if not tool_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return SkillMetadata(
|
||||||
|
id=f"mcp_{server}_{tool_name}",
|
||||||
|
name=f"{server}:{tool_name}",
|
||||||
|
description=tool.get("description", f"MCP tool: {tool_name}"),
|
||||||
|
version="1.0.0",
|
||||||
|
content=self._generate_skill_content(tool),
|
||||||
|
triggers=[f"@{server}", f"/{tool_name}"],
|
||||||
|
tools=[tool_name],
|
||||||
|
tags=["mcp", server],
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_skill_content(self, tool: dict[str, Any]) -> str:
|
||||||
|
"""生成 Skill 内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool: MCP 工具定义
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 内容字符串
|
||||||
|
"""
|
||||||
|
name = tool.get("name", "unknown")
|
||||||
|
description = tool.get("description", "No description")
|
||||||
|
input_schema = tool.get("inputSchema", {})
|
||||||
|
|
||||||
|
content = f"""# MCP Tool: {name}
|
||||||
|
|
||||||
|
**Description**: {description}
|
||||||
|
|
||||||
|
**Server**: {tool.get("server", "unknown")}
|
||||||
|
|
||||||
|
**Input Schema**:
|
||||||
|
```json
|
||||||
|
{input_schema}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Usage**:
|
||||||
|
Use the `/{name}` command or `@{tool.get("server", "server")}` to invoke this tool.
|
||||||
|
|
||||||
|
**Examples**:
|
||||||
|
```
|
||||||
|
/{name} arg1=value1 arg2=value2
|
||||||
|
@{tool.get("server", "server")} {name} --arg1 value1
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return content
|
||||||
|
|
||||||
|
def get_skill(self, name: str) -> SkillMetadata | None:
|
||||||
|
"""获取已发现的 Skill
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Skill 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 元数据或 None
|
||||||
|
"""
|
||||||
|
return self._discovered_skills.get(name)
|
||||||
|
|
||||||
|
def list_skills(self) -> list[SkillMetadata]:
|
||||||
|
"""列出所有已发现的 Skills
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 列表
|
||||||
|
"""
|
||||||
|
return list(self._discovered_skills.values())
|
||||||
|
|
||||||
|
|
||||||
|
# 全局加载器
|
||||||
|
_loader: MCPSkillLoader | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_mcp_skill_loader() -> MCPSkillLoader:
|
||||||
|
"""获取全局 MCP Skill 加载器"""
|
||||||
|
global _loader
|
||||||
|
if _loader is None:
|
||||||
|
_loader = MCPSkillLoader()
|
||||||
|
return _loader
|
||||||
53
backend/app/agents/skills/loaders/plugin_loader.py
Normal file
53
backend/app/agents/skills/loaders/plugin_loader.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""插件 Skills 加载器 - Phase 9.2"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.skills.metadata import SkillMetadata
|
||||||
|
from app.agents.plugins.manager import get_plugin_manager
|
||||||
|
|
||||||
|
|
||||||
|
class PluginSkillLoader:
|
||||||
|
"""插件 Skills 加载器
|
||||||
|
|
||||||
|
从已安装的插件中加载 Skills。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.plugin_manager = get_plugin_manager()
|
||||||
|
|
||||||
|
def load_all(self) -> list[SkillMetadata]:
|
||||||
|
"""从所有已启用的插件加载 Skills
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 元数据列表
|
||||||
|
"""
|
||||||
|
skills = []
|
||||||
|
|
||||||
|
for plugin in self.plugin_manager.list_plugins():
|
||||||
|
if not self.plugin_manager.is_enabled(plugin.id):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 从插件加载 Skills
|
||||||
|
plugin_skills = self._load_from_plugin(plugin)
|
||||||
|
skills.extend(plugin_skills)
|
||||||
|
|
||||||
|
return skills
|
||||||
|
|
||||||
|
def _load_from_plugin(self, plugin: Any) -> list[SkillMetadata]:
|
||||||
|
"""从单个插件加载 Skills"""
|
||||||
|
skills = []
|
||||||
|
|
||||||
|
for skill_name in plugin.skills:
|
||||||
|
skill = SkillMetadata(
|
||||||
|
name=f"{plugin.id}/{skill_name}",
|
||||||
|
description=f"Skill from plugin: {plugin.name}",
|
||||||
|
version=plugin.version,
|
||||||
|
author=plugin.author,
|
||||||
|
tags=["plugin", plugin.id],
|
||||||
|
content=f"# {skill_name}\n\nFrom plugin: {plugin.name}",
|
||||||
|
source="plugin",
|
||||||
|
source_id=plugin.id,
|
||||||
|
)
|
||||||
|
skills.append(skill)
|
||||||
|
|
||||||
|
return skills
|
||||||
100
backend/app/agents/skills/mcp_builder.py
Normal file
100
backend/app/agents/skills/mcp_builder.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""MCP Skill Builder - Phase 9.3"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.skills.metadata import SkillMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSkillBuilder:
|
||||||
|
"""MCP Skill Builder
|
||||||
|
|
||||||
|
从 MCP 服务器发现和构建 Skills。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._skills: dict[str, SkillMetadata] = {}
|
||||||
|
|
||||||
|
def discover_skills_from_mcp(self, mcp_servers: list[dict[str, Any]]) -> list[SkillMetadata]:
|
||||||
|
"""从 MCP 服务器发现 Skills
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_servers: MCP 服务器配置列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
发现的 Skill 元数据列表
|
||||||
|
"""
|
||||||
|
skills = []
|
||||||
|
|
||||||
|
for server in mcp_servers:
|
||||||
|
server_skills = self._discover_from_server(server)
|
||||||
|
skills.extend(server_skills)
|
||||||
|
|
||||||
|
return skills
|
||||||
|
|
||||||
|
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
|
||||||
|
"""从单个 MCP 服务器发现 Skills"""
|
||||||
|
skills = []
|
||||||
|
server_name = server.get("name", "unknown")
|
||||||
|
tools = server.get("tools", [])
|
||||||
|
|
||||||
|
# 按工具分组
|
||||||
|
tool_groups: dict[str, list[str]] = {}
|
||||||
|
for tool in tools:
|
||||||
|
group = tool.get("group", "default")
|
||||||
|
if group not in tool_groups:
|
||||||
|
tool_groups[group] = []
|
||||||
|
tool_groups[group].append(tool)
|
||||||
|
|
||||||
|
# 为每个组创建一个 Skill
|
||||||
|
for group_name, group_tools in tool_groups.items():
|
||||||
|
skill = self._tool_to_skill(group_name, group_tools, server_name)
|
||||||
|
skills.append(skill)
|
||||||
|
|
||||||
|
return skills
|
||||||
|
|
||||||
|
def _tool_to_skill(self, group: str, tools: list[dict[str, Any]], server: str) -> SkillMetadata:
|
||||||
|
"""将 MCP 工具转换为 Skill"""
|
||||||
|
tool_summaries = []
|
||||||
|
for tool in tools:
|
||||||
|
name = tool.get("name", "unknown")
|
||||||
|
description = tool.get("description", "")
|
||||||
|
input_schema = tool.get("inputSchema", {})
|
||||||
|
|
||||||
|
tool_summaries.append(f"### {name}\n{description}\n\nInput: {input_schema}")
|
||||||
|
|
||||||
|
content = f"""# MCP Skill: {group}
|
||||||
|
|
||||||
|
来自 MCP 服务器: {server}
|
||||||
|
|
||||||
|
## 工具列表
|
||||||
|
|
||||||
|
{chr(10).join(tool_summaries)}
|
||||||
|
|
||||||
|
## 使用说明
|
||||||
|
|
||||||
|
使用这些工具前请确保理解每个工具的输入输出格式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
return SkillMetadata(
|
||||||
|
name=f"mcp-{server}-{group}",
|
||||||
|
description=f"MCP skill from {server}: {group}",
|
||||||
|
version="1.0.0",
|
||||||
|
tags=["mcp", server, group],
|
||||||
|
triggers=[group, server],
|
||||||
|
content=content,
|
||||||
|
source="mcp",
|
||||||
|
source_id=f"{server}:{group}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _group_to_skill(self, group: str, tools: list[str], server: str) -> SkillMetadata:
|
||||||
|
"""将 MCP 工具组转换为 Skill"""
|
||||||
|
return SkillMetadata(
|
||||||
|
name=f"mcp-{server}-{group}",
|
||||||
|
description=f"MCP skill from {server}: {group}",
|
||||||
|
version="1.0.0",
|
||||||
|
tags=["mcp", server, group],
|
||||||
|
triggers=[group, server],
|
||||||
|
content=f"# {group}\n\nTools: {', '.join(tools)}",
|
||||||
|
source="mcp",
|
||||||
|
source_id=f"{server}:{group}",
|
||||||
|
)
|
||||||
42
backend/app/agents/skills/metadata.py
Normal file
42
backend/app/agents/skills/metadata.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""Skill 元数据定义 - Phase 9.1"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SkillMetadata:
|
||||||
|
"""Skill 元数据"""
|
||||||
|
|
||||||
|
id: str = "" # Skill ID
|
||||||
|
name: str = "" # Skill 名称
|
||||||
|
description: str = "" # 描述
|
||||||
|
version: str = "1.0.0" # 版本
|
||||||
|
author: str = "" # 作者
|
||||||
|
tags: list[str] = field(default_factory=list) # 标签
|
||||||
|
triggers: list[str] = field(default_factory=list) # 触发关键词
|
||||||
|
content: str = "" # Skill 内容(markdown)
|
||||||
|
source: str = "local" # 来源:local, plugin, mcp, bundled
|
||||||
|
source_id: str = "" # 来源 ID
|
||||||
|
enabled: bool = True # 是否启用
|
||||||
|
tools: list[str] = field(default_factory=list) # 关联的工具
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"version": self.version,
|
||||||
|
"author": self.author,
|
||||||
|
"tags": self.tags,
|
||||||
|
"triggers": self.triggers,
|
||||||
|
"content": self.content,
|
||||||
|
"source": self.source,
|
||||||
|
"source_id": self.source_id,
|
||||||
|
"enabled": self.enabled,
|
||||||
|
"tools": self.tools,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "SkillMetadata":
|
||||||
|
return cls(**data)
|
||||||
133
backend/app/agents/skills/registry.py
Normal file
133
backend/app/agents/skills/registry.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""Skills 注册表 - Phase 9.1"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.skills.metadata import SkillMetadata
|
||||||
|
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||||
|
|
||||||
|
|
||||||
|
class SkillRegistry:
|
||||||
|
"""Skills 注册表
|
||||||
|
|
||||||
|
管理所有 Skills 的注册、发现和加载。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._skills: dict[str, SkillMetadata] = {}
|
||||||
|
self._loaders: list[Any] = []
|
||||||
|
|
||||||
|
def load_all(self, skills_dir: str | None = None) -> int:
|
||||||
|
"""加载所有 Skills
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skills_dir: Skills 目录,None 则使用默认目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加载的 Skill 数量
|
||||||
|
"""
|
||||||
|
if skills_dir is None:
|
||||||
|
skills_dir = os.path.join(
|
||||||
|
os.path.dirname(__file__), "..", "..", "..", ".claude", "skills"
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# 本地加载器
|
||||||
|
local_loader = LocalSkillLoader(skills_dir)
|
||||||
|
local_skills = local_loader.load_all()
|
||||||
|
for skill in local_skills:
|
||||||
|
self.register(skill)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# 插件加载器
|
||||||
|
for loader in self._loaders:
|
||||||
|
try:
|
||||||
|
external_skills = loader.load_all()
|
||||||
|
for skill in external_skills:
|
||||||
|
self.register(skill)
|
||||||
|
count += 1
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def register(self, skill: SkillMetadata) -> None:
|
||||||
|
"""注册 Skill"""
|
||||||
|
self._skills[skill.name] = skill
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> bool:
|
||||||
|
"""注销 Skill"""
|
||||||
|
if name in self._skills:
|
||||||
|
del self._skills[name]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_skill(self, name: str) -> SkillMetadata | None:
|
||||||
|
"""获取 Skill"""
|
||||||
|
return self._skills.get(name)
|
||||||
|
|
||||||
|
def search(self, query: str) -> list[SkillMetadata]:
|
||||||
|
"""搜索 Skills
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
匹配的 Skills 列表
|
||||||
|
"""
|
||||||
|
query_lower = query.lower()
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for skill in self._skills.values():
|
||||||
|
if not skill.enabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 匹配名称、描述、标签
|
||||||
|
if (
|
||||||
|
query_lower in skill.name.lower()
|
||||||
|
or query_lower in skill.description.lower()
|
||||||
|
or any(query_lower in tag.lower() for tag in skill.tags)
|
||||||
|
or any(query_lower in trigger.lower() for trigger in skill.triggers)
|
||||||
|
):
|
||||||
|
results.append(skill)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_skill_context(self, names: list[str]) -> str:
|
||||||
|
"""获取 Skill 上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
names: Skill 名称列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
拼接的 Skill 内容
|
||||||
|
"""
|
||||||
|
contexts = []
|
||||||
|
|
||||||
|
for name in names:
|
||||||
|
skill = self._skills.get(name)
|
||||||
|
if skill and skill.enabled:
|
||||||
|
contexts.append(f"# {skill.name}\n\n{skill.content}")
|
||||||
|
|
||||||
|
return "\n\n---\n\n".join(contexts)
|
||||||
|
|
||||||
|
def add_loader(self, loader: Any) -> None:
|
||||||
|
"""添加加载器"""
|
||||||
|
self._loaders.append(loader)
|
||||||
|
|
||||||
|
def list_all(self) -> list[SkillMetadata]:
|
||||||
|
"""列出所有 Skills"""
|
||||||
|
return list(self._skills.values())
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_registry: SkillRegistry | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_skill_registry() -> SkillRegistry:
|
||||||
|
"""获取全局 Skills 注册表"""
|
||||||
|
global _registry
|
||||||
|
if _registry is None:
|
||||||
|
_registry = SkillRegistry()
|
||||||
|
return _registry
|
||||||
140
backend/app/agents/skills/trigger.py
Normal file
140
backend/app/agents/skills/trigger.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""Skill 触发检测器 - Phase 9.5
|
||||||
|
|
||||||
|
检测消息中的 Skill 触发条件。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.skills.metadata import SkillMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class SkillTriggerDetector:
|
||||||
|
"""Skill 触发检测器
|
||||||
|
|
||||||
|
检测用户消息中是否触发了某个 Skill。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._skills: dict[str, SkillMetadata] = {}
|
||||||
|
|
||||||
|
def register_skill(self, skill: SkillMetadata) -> None:
|
||||||
|
"""注册 Skill
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill: Skill 元数据
|
||||||
|
"""
|
||||||
|
self._skills[skill.name] = skill
|
||||||
|
|
||||||
|
def unregister_skill(self, name: str) -> bool:
|
||||||
|
"""注销 Skill
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Skill 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
if name in self._skills:
|
||||||
|
del self._skills[name]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def detect_triggered_skills(self, message: str) -> list[str]:
|
||||||
|
"""检测触发的 Skills
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 用户消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
触发的 Skill 名称列表
|
||||||
|
"""
|
||||||
|
triggered = []
|
||||||
|
message_lower = message.lower()
|
||||||
|
|
||||||
|
for skill in self._skills.values():
|
||||||
|
if not skill.enabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self._matches_triggers(message, message_lower, skill):
|
||||||
|
triggered.append(skill.name)
|
||||||
|
|
||||||
|
return triggered
|
||||||
|
|
||||||
|
def _matches_triggers(self, message: str, message_lower: str, skill: SkillMetadata) -> bool:
|
||||||
|
"""检查消息是否匹配 Skill 触发条件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 原始消息
|
||||||
|
message_lower: 小写消息
|
||||||
|
skill: Skill 元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否匹配
|
||||||
|
"""
|
||||||
|
for trigger in skill.triggers:
|
||||||
|
trigger_lower = trigger.lower()
|
||||||
|
|
||||||
|
# 前缀匹配,如 "/code" 或 "@git"
|
||||||
|
if trigger_lower.startswith("/") or trigger_lower.startswith("@"):
|
||||||
|
if message_lower.startswith(trigger_lower):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 命令格式,如 "//analyze"
|
||||||
|
if trigger_lower.startswith("//"):
|
||||||
|
pattern = trigger_lower[2:]
|
||||||
|
if re.search(rf"\b{re.escape(pattern)}\b", message_lower):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 关键词匹配
|
||||||
|
if trigger_lower in message_lower:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_skill_prompt(self, skill_name: str) -> str | None:
|
||||||
|
"""获取 Skill 的提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Skill 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Skill 内容或 None
|
||||||
|
"""
|
||||||
|
skill = self._skills.get(skill_name)
|
||||||
|
if skill:
|
||||||
|
return skill.content
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_triggered_skill_context(self, message: str) -> str:
|
||||||
|
"""获取触发的 Skills 上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 用户消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
拼接的 Skill 上下文
|
||||||
|
"""
|
||||||
|
triggered = self.detect_triggered_skills(message)
|
||||||
|
if not triggered:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
contexts = []
|
||||||
|
for skill_name in triggered:
|
||||||
|
skill = self._skills.get(skill_name)
|
||||||
|
if skill:
|
||||||
|
contexts.append(f"# {skill.name}\n\n{skill.content}")
|
||||||
|
|
||||||
|
return "\n\n---\n\n".join(contexts)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局检测器
|
||||||
|
_detector: SkillTriggerDetector | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_skill_trigger_detector() -> SkillTriggerDetector:
|
||||||
|
"""获取全局 Skill 触发检测器"""
|
||||||
|
global _detector
|
||||||
|
if _detector is None:
|
||||||
|
_detector = SkillTriggerDetector()
|
||||||
|
return _detector
|
||||||
@@ -1,10 +1,21 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import TypedDict, Annotated, Sequence
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Annotated, Any, Literal, TypedDict
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage
|
from app.agents.schemas.event import AgentEvent
|
||||||
|
from app.agents.schemas.message import AgentMessage
|
||||||
|
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult, VerificationStatus
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from langgraph.graph.message import add_messages
|
from langgraph.graph.message import add_messages
|
||||||
|
|
||||||
|
AgentPhase = Literal[
|
||||||
|
"phase_0_bootstrap",
|
||||||
|
"phase_1_routing",
|
||||||
|
"phase_2_controlled_collaboration",
|
||||||
|
"phase_3_dynamic_collaboration",
|
||||||
|
"phase_4_visibility_and_verification",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AgentRole(str, Enum):
|
class AgentRole(str, Enum):
|
||||||
MASTER = "master"
|
MASTER = "master"
|
||||||
@@ -22,41 +33,113 @@ class ConversationTurn:
|
|||||||
model: str | None = None
|
model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def turn_to_message(turn: ConversationTurn) -> BaseMessage:
|
||||||
|
if turn.role == "user":
|
||||||
|
return HumanMessage(content=turn.content)
|
||||||
|
return AIMessage(content=turn.content)
|
||||||
|
|
||||||
|
|
||||||
class AgentState(TypedDict):
|
class AgentState(TypedDict):
|
||||||
# Core message history with add_messages reducer
|
|
||||||
messages: Annotated[list[BaseMessage], add_messages]
|
messages: Annotated[list[BaseMessage], add_messages]
|
||||||
|
|
||||||
# Session identifiers
|
|
||||||
user_id: str
|
user_id: str
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
|
parent_conversation_id: str | None
|
||||||
|
thread_id: str | None
|
||||||
|
last_message_id: str | None
|
||||||
|
message_sequence: int
|
||||||
|
agent_id: str | None
|
||||||
|
parent_agent_id: str | None
|
||||||
|
root_agent_id: str | None
|
||||||
|
collaboration_depth: int
|
||||||
|
spawned_agent_ids: list[str]
|
||||||
|
|
||||||
# Agent routing state
|
execution_mode: Literal["direct", "collaboration", "delegated", "verified"]
|
||||||
current_agent: str | None
|
current_agent: str | None
|
||||||
next_step: str | None # For explicit graph routing
|
next_step: str | None
|
||||||
|
active_agents: list[AgentRole]
|
||||||
# Traceability
|
current_sub_commander: str | None
|
||||||
|
active_sub_commanders: list[str]
|
||||||
|
sub_commander_trace: list[dict[str, Any]]
|
||||||
agent_trace: list[str]
|
agent_trace: list[str]
|
||||||
|
event_trace: list[AgentEvent | dict[str, Any]]
|
||||||
# Task & Entity Tracking (Business Logic)
|
message_trace: list[AgentMessage | dict[str, Any]]
|
||||||
pending_tasks: list[dict]
|
|
||||||
completed_tasks: list[dict]
|
pending_tasks: list[dict[str, Any]]
|
||||||
created_entities: list[dict]
|
completed_tasks: list[dict[str, Any]]
|
||||||
|
active_tasks: list[AgentTask | dict[str, Any]]
|
||||||
|
task_results: list[TaskResult | dict[str, Any]]
|
||||||
|
task_hierarchy: dict[str, list[str]]
|
||||||
|
interrupted_tasks: list[InterruptRecord | dict[str, Any]]
|
||||||
|
recovery_trace: list[RecoveryRecord | dict[str, Any]]
|
||||||
|
recovery_points: list[dict[str, Any]]
|
||||||
|
tool_calls: list[dict[str, Any]]
|
||||||
|
last_tool_result: str | None
|
||||||
|
action_results: list[dict[str, Any]]
|
||||||
|
created_entities: list[dict[str, Any]]
|
||||||
|
tool_outcomes: list[dict[str, Any]]
|
||||||
|
task_result_summary: dict[str, Any] | None
|
||||||
|
verifier_hints: dict[str, Any] | None
|
||||||
|
|
||||||
|
verification_status: VerificationStatus | None
|
||||||
|
verification_summary: str | None
|
||||||
|
verification_evidence: list[dict[str, Any]]
|
||||||
|
isolation_mode: str
|
||||||
|
isolation_id: str | None
|
||||||
|
isolation_workspace_path: str | None
|
||||||
|
isolation_parent_conversation_id: str | None
|
||||||
|
isolation_metadata: dict[str, Any]
|
||||||
|
input_tokens: int
|
||||||
|
output_tokens: int
|
||||||
|
estimated_cost: float | None
|
||||||
|
budget_warning: bool
|
||||||
|
cost_by_agent: dict[str, dict[str, Any]]
|
||||||
|
cost_thresholds: dict[str, Any]
|
||||||
|
budget_state: CollaborationBudget | dict[str, Any] | None
|
||||||
|
collaboration_budget_history: list[CollaborationBudget | dict[str, Any]]
|
||||||
|
current_phase: AgentPhase
|
||||||
|
phase_history: list[dict[str, Any]]
|
||||||
|
current_checkpoint: str | None
|
||||||
|
checkpoint_history: list[dict[str, Any]]
|
||||||
|
|
||||||
|
tool_strategy_used: str | None
|
||||||
|
tool_round_count: int
|
||||||
|
max_tool_rounds: int
|
||||||
|
retry_count: int
|
||||||
|
max_retries: int
|
||||||
|
iteration_count: int
|
||||||
|
max_iterations: int
|
||||||
|
routing_hops: int
|
||||||
|
max_routing_hops: int
|
||||||
|
terminated_due_to_loop_guard: bool
|
||||||
|
retrieval_trace: list[dict[str, Any]]
|
||||||
|
stop_reason: str | None
|
||||||
|
|
||||||
|
clarification_needed: bool
|
||||||
|
clarification_question: str | None
|
||||||
|
fallback_parse_error: str | None
|
||||||
|
should_respond: bool
|
||||||
|
|
||||||
# Context summaries (for long-term or cross-agent context)
|
|
||||||
knowledge_context: str | None
|
knowledge_context: str | None
|
||||||
|
graph_context: str | None
|
||||||
schedule_context_summary: str | None
|
schedule_context_summary: str | None
|
||||||
|
plan: str | None
|
||||||
|
plan_steps: list[dict[str, Any]]
|
||||||
analysis_report: str | None
|
analysis_report: str | None
|
||||||
|
|
||||||
# Output control
|
|
||||||
final_response: str | None
|
final_response: str | None
|
||||||
|
|
||||||
# Memory & Environment
|
|
||||||
memory_context: str | None
|
memory_context: str | None
|
||||||
current_datetime_context: str | None
|
current_datetime_context: str | None
|
||||||
|
current_datetime_reference: dict[str, str] | None
|
||||||
# Configuration
|
|
||||||
user_llm_config: dict | None
|
turn_context: dict[str, Any] | None
|
||||||
provider_capabilities: dict | None
|
routing_decision: dict[str, Any] | None
|
||||||
|
continuity_state: dict[str, Any] | None
|
||||||
|
pending_action: dict[str, Any] | None
|
||||||
|
last_completed_action: dict[str, Any] | None
|
||||||
|
clarification_context: dict[str, Any] | None
|
||||||
|
|
||||||
|
user_llm_config: dict[str, Any] | None
|
||||||
|
provider_capabilities: dict[str, Any] | None
|
||||||
|
|
||||||
|
|
||||||
def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||||
@@ -64,18 +147,103 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
|||||||
messages=[],
|
messages=[],
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
parent_conversation_id=None,
|
||||||
|
thread_id=None,
|
||||||
|
last_message_id=None,
|
||||||
|
message_sequence=0,
|
||||||
|
agent_id=AgentRole.MASTER.value,
|
||||||
|
parent_agent_id=None,
|
||||||
|
root_agent_id=AgentRole.MASTER.value,
|
||||||
|
collaboration_depth=0,
|
||||||
|
spawned_agent_ids=[],
|
||||||
|
execution_mode="direct",
|
||||||
current_agent=AgentRole.MASTER.value,
|
current_agent=AgentRole.MASTER.value,
|
||||||
next_step=None,
|
next_step=None,
|
||||||
|
active_agents=[AgentRole.MASTER],
|
||||||
|
current_sub_commander=None,
|
||||||
|
active_sub_commanders=[],
|
||||||
|
sub_commander_trace=[],
|
||||||
agent_trace=[AgentRole.MASTER.value],
|
agent_trace=[AgentRole.MASTER.value],
|
||||||
|
event_trace=[],
|
||||||
|
message_trace=[],
|
||||||
pending_tasks=[],
|
pending_tasks=[],
|
||||||
completed_tasks=[],
|
completed_tasks=[],
|
||||||
|
active_tasks=[],
|
||||||
|
task_results=[],
|
||||||
|
task_hierarchy={},
|
||||||
|
interrupted_tasks=[],
|
||||||
|
recovery_trace=[],
|
||||||
|
recovery_points=[],
|
||||||
|
tool_calls=[],
|
||||||
|
last_tool_result=None,
|
||||||
|
action_results=[],
|
||||||
created_entities=[],
|
created_entities=[],
|
||||||
|
tool_outcomes=[],
|
||||||
|
task_result_summary=None,
|
||||||
|
verifier_hints=None,
|
||||||
|
verification_status=None,
|
||||||
|
verification_summary=None,
|
||||||
|
verification_evidence=[],
|
||||||
|
isolation_mode="none",
|
||||||
|
isolation_id=None,
|
||||||
|
isolation_workspace_path=None,
|
||||||
|
isolation_parent_conversation_id=None,
|
||||||
|
isolation_metadata={},
|
||||||
|
input_tokens=0,
|
||||||
|
output_tokens=0,
|
||||||
|
estimated_cost=None,
|
||||||
|
budget_warning=False,
|
||||||
|
cost_by_agent={},
|
||||||
|
cost_thresholds={},
|
||||||
|
budget_state=None,
|
||||||
|
collaboration_budget_history=[],
|
||||||
|
current_phase="phase_0_bootstrap",
|
||||||
|
phase_history=[
|
||||||
|
{
|
||||||
|
"phase": "phase_0_bootstrap",
|
||||||
|
"reason": "initial_state_created",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
current_checkpoint="bootstrap.initialized",
|
||||||
|
checkpoint_history=[
|
||||||
|
{
|
||||||
|
"checkpoint": "bootstrap.initialized",
|
||||||
|
"phase": "phase_0_bootstrap",
|
||||||
|
"reason": "initial_state_created",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_strategy_used=None,
|
||||||
|
tool_round_count=0,
|
||||||
|
max_tool_rounds=2,
|
||||||
|
retry_count=0,
|
||||||
|
max_retries=1,
|
||||||
|
iteration_count=0,
|
||||||
|
max_iterations=3,
|
||||||
|
routing_hops=0,
|
||||||
|
max_routing_hops=2,
|
||||||
|
terminated_due_to_loop_guard=False,
|
||||||
|
retrieval_trace=[],
|
||||||
|
stop_reason=None,
|
||||||
|
clarification_needed=False,
|
||||||
|
clarification_question=None,
|
||||||
|
fallback_parse_error=None,
|
||||||
|
should_respond=True,
|
||||||
knowledge_context=None,
|
knowledge_context=None,
|
||||||
|
graph_context=None,
|
||||||
schedule_context_summary=None,
|
schedule_context_summary=None,
|
||||||
|
plan=None,
|
||||||
|
plan_steps=[],
|
||||||
analysis_report=None,
|
analysis_report=None,
|
||||||
final_response=None,
|
final_response=None,
|
||||||
memory_context=None,
|
memory_context=None,
|
||||||
current_datetime_context=None,
|
current_datetime_context=None,
|
||||||
|
current_datetime_reference=None,
|
||||||
|
turn_context=None,
|
||||||
|
routing_decision=None,
|
||||||
|
continuity_state=None,
|
||||||
|
pending_action=None,
|
||||||
|
last_completed_action=None,
|
||||||
|
clarification_context=None,
|
||||||
user_llm_config=None,
|
user_llm_config=None,
|
||||||
provider_capabilities=None,
|
provider_capabilities=None,
|
||||||
)
|
)
|
||||||
|
|||||||
13
backend/app/agents/team/__init__.py
Normal file
13
backend/app/agents/team/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""Team 多 Agent 协作"""
|
||||||
|
|
||||||
|
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
|
||||||
|
from app.agents.team.member import TeamMember, MemberStatus, MemberTask
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TeamLeader",
|
||||||
|
"TeamTask",
|
||||||
|
"TaskStatus",
|
||||||
|
"TeamMember",
|
||||||
|
"MemberStatus",
|
||||||
|
"MemberTask",
|
||||||
|
]
|
||||||
121
backend/app/agents/team/leader.py
Normal file
121
backend/app/agents/team/leader.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Team 多 Agent 协作 - Phase 10.1"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TeamTask:
|
||||||
|
"""团队任务"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
description: str
|
||||||
|
assignee: str | None = None
|
||||||
|
status: TaskStatus = TaskStatus.PENDING
|
||||||
|
result: Any = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TeamLeader:
|
||||||
|
"""团队领导者
|
||||||
|
|
||||||
|
协调多个 Agent 成员执行任务。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, team_id: str, members: list[str]):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
team_id: 团队 ID
|
||||||
|
members: 成员 ID 列表
|
||||||
|
"""
|
||||||
|
self.team_id = team_id
|
||||||
|
self.members = members
|
||||||
|
self._tasks: dict[str, TeamTask] = {}
|
||||||
|
|
||||||
|
def create_task(self, description: str) -> str:
|
||||||
|
"""创建任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description: 任务描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务 ID
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
task_id = str(uuid.uuid4())[:8]
|
||||||
|
self._tasks[task_id] = TeamTask(
|
||||||
|
id=task_id,
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
def assign_task(self, task_id: str, member: str) -> bool:
|
||||||
|
"""分配任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务 ID
|
||||||
|
member: 成员 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if member not in self.members:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._tasks[task_id].assignee = member
|
||||||
|
self._tasks[task_id].status = TaskStatus.IN_PROGRESS
|
||||||
|
return True
|
||||||
|
|
||||||
|
def broadcast_task(self, description: str) -> list[str]:
|
||||||
|
"""广播任务给所有成员
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description: 任务描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的任务 ID 列表
|
||||||
|
"""
|
||||||
|
task_ids = []
|
||||||
|
for member in self.members:
|
||||||
|
task_id = self.create_task(description)
|
||||||
|
self.assign_task(task_id, member)
|
||||||
|
task_ids.append(task_id)
|
||||||
|
return task_ids
|
||||||
|
|
||||||
|
def collect_results(self) -> dict[str, Any]:
|
||||||
|
"""收集所有任务结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务 ID -> 结果的映射
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
task_id: task.result
|
||||||
|
for task_id, task in self._tasks.items()
|
||||||
|
if task.status == TaskStatus.COMPLETED
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_team_status(self) -> dict[str, Any]:
|
||||||
|
"""获取团队状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
团队状态摘要
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"team_id": self.team_id,
|
||||||
|
"members": self.members,
|
||||||
|
"task_count": len(self._tasks),
|
||||||
|
"completed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.COMPLETED),
|
||||||
|
"failed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.FAILED),
|
||||||
|
}
|
||||||
166
backend/app/agents/team/member.py
Normal file
166
backend/app/agents/team/member.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""TeamMember 实现 - Phase 10.1
|
||||||
|
|
||||||
|
团队成员实现,负责执行分配的任务。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class MemberStatus(Enum):
|
||||||
|
"""成员状态"""
|
||||||
|
|
||||||
|
IDLE = "idle"
|
||||||
|
BUSY = "busy"
|
||||||
|
OFFLINE = "offline"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MemberTask:
|
||||||
|
"""成员任务"""
|
||||||
|
|
||||||
|
task_id: str
|
||||||
|
description: str
|
||||||
|
status: str = "pending" # pending, in_progress, completed, failed
|
||||||
|
result: Any = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TeamMember:
|
||||||
|
"""团队成员
|
||||||
|
|
||||||
|
代表团队中的一个 Agent 成员,负责执行分配的任务。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, member_id: str, name: str, capabilities: list[str] | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
member_id: 成员 ID
|
||||||
|
name: 成员名称
|
||||||
|
capabilities: 成员能力列表
|
||||||
|
"""
|
||||||
|
self.member_id = member_id
|
||||||
|
self.name = name
|
||||||
|
self.capabilities = capabilities or []
|
||||||
|
self.status = MemberStatus.IDLE
|
||||||
|
self._tasks: dict[str, MemberTask] = {}
|
||||||
|
self._metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def assign_task(self, task_id: str, description: str) -> MemberTask:
|
||||||
|
"""接收任务分配
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务 ID
|
||||||
|
description: 任务描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的任务对象
|
||||||
|
"""
|
||||||
|
task = MemberTask(task_id=task_id, description=description)
|
||||||
|
self._tasks[task_id] = task
|
||||||
|
self.status = MemberStatus.BUSY
|
||||||
|
return task
|
||||||
|
|
||||||
|
def update_task_status(
|
||||||
|
self, task_id: str, status: str, result: Any = None, error: str | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""更新任务状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务 ID
|
||||||
|
status: 新状态
|
||||||
|
result: 任务结果
|
||||||
|
error: 错误信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否更新成功
|
||||||
|
"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
return False
|
||||||
|
|
||||||
|
task = self._tasks[task_id]
|
||||||
|
task.status = status
|
||||||
|
if result is not None:
|
||||||
|
task.result = result
|
||||||
|
if error is not None:
|
||||||
|
task.error = error
|
||||||
|
|
||||||
|
if status in ("completed", "failed"):
|
||||||
|
self.status = MemberStatus.IDLE
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_task(self, task_id: str) -> MemberTask | None:
|
||||||
|
"""获取任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务对象或 None
|
||||||
|
"""
|
||||||
|
return self._tasks.get(task_id)
|
||||||
|
|
||||||
|
def get_pending_tasks(self) -> list[MemberTask]:
|
||||||
|
"""获取待处理任务
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
待处理任务列表
|
||||||
|
"""
|
||||||
|
return [t for t in self._tasks.values() if t.status == "pending"]
|
||||||
|
|
||||||
|
def get_active_task(self) -> MemberTask | None:
|
||||||
|
"""获取当前执行中的任务
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
当前任务或 None
|
||||||
|
"""
|
||||||
|
for task in self._tasks.values():
|
||||||
|
if task.status == "in_progress":
|
||||||
|
return task
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_completed_tasks(self) -> list[MemberTask]:
|
||||||
|
"""获取已完成任务
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
已完成任务列表
|
||||||
|
"""
|
||||||
|
return [t for t in self._tasks.values() if t.status == "completed"]
|
||||||
|
|
||||||
|
def set_metadata(self, key: str, value: Any) -> None:
|
||||||
|
"""设置元数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 元数据键
|
||||||
|
value: 元数据值
|
||||||
|
"""
|
||||||
|
self._metadata[key] = value
|
||||||
|
|
||||||
|
def get_metadata(self, key: str) -> Any:
|
||||||
|
"""获取元数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 元数据键
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
元数据值或 None
|
||||||
|
"""
|
||||||
|
return self._metadata.get(key)
|
||||||
|
|
||||||
|
def get_status(self) -> dict[str, Any]:
|
||||||
|
"""获取成员状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
状态字典
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"member_id": self.member_id,
|
||||||
|
"name": self.name,
|
||||||
|
"status": self.status.value,
|
||||||
|
"capabilities": self.capabilities,
|
||||||
|
"task_count": len(self._tasks),
|
||||||
|
"pending_count": len(self.get_pending_tasks()),
|
||||||
|
"active_task": self.get_active_task().__dict__ if self.get_active_task() else None,
|
||||||
|
}
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
from app.agents.tools.search import (
|
from app.agents.tools.search import (
|
||||||
search_knowledge, get_knowledge_graph_context,
|
search_knowledge,
|
||||||
build_knowledge_graph, hybrid_search, web_search,
|
get_knowledge_graph_context,
|
||||||
|
build_knowledge_graph,
|
||||||
|
hybrid_search,
|
||||||
|
web_search,
|
||||||
)
|
)
|
||||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||||
@@ -13,6 +16,58 @@ from app.agents.tools.schedule import (
|
|||||||
)
|
)
|
||||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||||
|
|
||||||
|
# Phase 6.1: Tool Registry exports
|
||||||
|
from app.agents.tools.registry import (
|
||||||
|
ToolRegistry,
|
||||||
|
get_tool_registry,
|
||||||
|
reset_tool_registry,
|
||||||
|
)
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
HookConfig,
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
ToolManifest,
|
||||||
|
)
|
||||||
|
from app.agents.tools.migration import (
|
||||||
|
migrate_tool,
|
||||||
|
migrate_all_tools,
|
||||||
|
get_tool_executor,
|
||||||
|
BackwardCompatTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 6.2: Hook System exports
|
||||||
|
from app.agents.tools.hooks import (
|
||||||
|
HookManager,
|
||||||
|
HookExecutor,
|
||||||
|
HookType,
|
||||||
|
HookDefinition,
|
||||||
|
HookResult,
|
||||||
|
ExecutionContext,
|
||||||
|
get_hook_manager,
|
||||||
|
get_hook_executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 6.3: Streaming Executor exports
|
||||||
|
from app.agents.tools.streaming import (
|
||||||
|
StreamingToolExecutor,
|
||||||
|
get_streaming_executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 6.4: Builtin Tools exports
|
||||||
|
from app.agents.tools.builtins import (
|
||||||
|
GlobTool,
|
||||||
|
GrepTool,
|
||||||
|
ReadFileTool,
|
||||||
|
WriteFileTool,
|
||||||
|
BashTool,
|
||||||
|
PowerShellTool,
|
||||||
|
LSPTools,
|
||||||
|
GitTool,
|
||||||
|
TeamAgentTool,
|
||||||
|
TaskBroadcastTool,
|
||||||
|
)
|
||||||
|
|
||||||
TASK_TOOLS = [
|
TASK_TOOLS = [
|
||||||
get_tasks,
|
get_tasks,
|
||||||
create_task,
|
create_task,
|
||||||
|
|||||||
18
backend/app/agents/tools/async_bridge.py
Normal file
18
backend/app/agents/tools/async_bridge.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_executor = ThreadPoolExecutor(max_workers=4)
|
||||||
|
|
||||||
|
|
||||||
|
def run_async(coro: Any, timeout: int = 30):
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(coro)
|
||||||
|
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["run_async"]
|
||||||
161
backend/app/agents/tools/base.py
Normal file
161
backend/app/agents/tools/base.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""工具基类 - 工具系统重构 Phase 6.1"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
ToolManifest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool(ABC, Generic[T]):
|
||||||
|
"""工具基类
|
||||||
|
|
||||||
|
提供工具的标准接口和默认实现。
|
||||||
|
所有自定义工具应继承此类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
category: ToolCategory,
|
||||||
|
permission_class: PermissionClass,
|
||||||
|
side_effect_scope: SideEffectScope = SideEffectScope.NONE,
|
||||||
|
requires_confirmation: bool = False,
|
||||||
|
is_streaming: bool = False,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.category = category
|
||||||
|
self.permission_class = permission_class
|
||||||
|
self.side_effect_scope = side_effect_scope
|
||||||
|
self.requires_confirmation = requires_confirmation
|
||||||
|
self.is_streaming = is_streaming
|
||||||
|
self.tags = tags or []
|
||||||
|
|
||||||
|
def get_manifest(self) -> ToolManifest:
|
||||||
|
"""获取工具元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具元数据
|
||||||
|
"""
|
||||||
|
return ToolManifest(
|
||||||
|
name=self.name,
|
||||||
|
description=self.description,
|
||||||
|
category=self.category,
|
||||||
|
parameters=self.get_parameters(),
|
||||||
|
return_schema=self.get_return_schema(),
|
||||||
|
permission_class=self.permission_class,
|
||||||
|
side_effect_scope=self.side_effect_scope,
|
||||||
|
requires_confirmation=self.requires_confirmation,
|
||||||
|
is_streaming=self.is_streaming,
|
||||||
|
tags=self.tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
"""获取参数 Schema(JSON Schema 格式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
参数 schema
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
"""获取返回值 Schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
返回值 schema
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self, **kwargs) -> T:
|
||||||
|
"""执行工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_safe(self, **kwargs) -> dict[str, Any]:
|
||||||
|
"""安全执行工具,捕获异常
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 success 和 result/error 的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.execute(**kwargs)
|
||||||
|
return {"success": True, "result": result}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__}(name={self.name!r})>"
|
||||||
|
|
||||||
|
|
||||||
|
class ReadTool(BaseTool):
|
||||||
|
"""只读工具基类"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.READ)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.READ)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.NONE)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteTool(BaseTool):
|
||||||
|
"""写入工具基类"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.WRITE)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.LOCAL_STATE)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DBWriteTool(BaseTool):
|
||||||
|
"""数据库写入工具基类"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.DB_WRITE)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.DB_WRITE)
|
||||||
|
kwargs.setdefault("requires_confirmation", True)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalTool(BaseTool):
|
||||||
|
"""外部工具基类(执行外部命令等)"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.EXTERNAL)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||||
|
kwargs.setdefault("requires_confirmation", True)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkTool(BaseTool):
|
||||||
|
"""网络工具基类"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
kwargs.setdefault("category", ToolCategory.NETWORK)
|
||||||
|
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||||
|
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||||
|
super().__init__(**kwargs)
|
||||||
43
backend/app/agents/tools/builtins/__init__.py
Normal file
43
backend/app/agents/tools/builtins/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""内置工具集 - Phase 6.4
|
||||||
|
|
||||||
|
新的内置工具,使用 BaseTool 基类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.agents.tools.builtins.file_tools import (
|
||||||
|
GlobTool,
|
||||||
|
GrepTool,
|
||||||
|
ReadFileTool,
|
||||||
|
WriteFileTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.tools.builtins.system_tools import (
|
||||||
|
BashTool,
|
||||||
|
PowerShellTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.tools.builtins.dev_tools import (
|
||||||
|
LSPTools,
|
||||||
|
GitTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.tools.builtins.collaboration_tools import (
|
||||||
|
TeamAgentTool,
|
||||||
|
TaskBroadcastTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# File tools
|
||||||
|
"GlobTool",
|
||||||
|
"GrepTool",
|
||||||
|
"ReadFileTool",
|
||||||
|
"WriteFileTool",
|
||||||
|
# System tools
|
||||||
|
"BashTool",
|
||||||
|
"PowerShellTool",
|
||||||
|
# Dev tools
|
||||||
|
"LSPTools",
|
||||||
|
"GitTool",
|
||||||
|
# Collaboration tools
|
||||||
|
"TeamAgentTool",
|
||||||
|
"TaskBroadcastTool",
|
||||||
|
]
|
||||||
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""协作工具 - Phase 6.4"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.base import WriteTool
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TeamAgentTool(WriteTool):
|
||||||
|
"""团队 Agent 通信工具
|
||||||
|
|
||||||
|
用于与其他 Agent 进行消息传递和协作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="team_agent",
|
||||||
|
description="向团队 Agent 发送消息或请求协作",
|
||||||
|
permission_class=PermissionClass.WRITE,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
tags=["collaboration", "team", "agent"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"agent_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标 Agent 名称",
|
||||||
|
},
|
||||||
|
"message": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要发送的消息",
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["send", "request", "delegate"],
|
||||||
|
"description": "操作类型",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["agent_name", "message"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"success": {"type": "boolean"},
|
||||||
|
"response": {"type": "string"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, agent_name: str, message: str, action: str = "send") -> dict[str, Any]:
|
||||||
|
# 注意:实际实现需要通过 Agent 通信协议
|
||||||
|
# 这里只是一个框架实现
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"response": f"Message '{action}' to agent '{agent_name}': {message}",
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"action": action,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskBroadcastTool(WriteTool):
|
||||||
|
"""任务广播工具
|
||||||
|
|
||||||
|
向多个 Agent 广播任务。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="task_broadcast",
|
||||||
|
description="向多个 Agent 广播任务",
|
||||||
|
permission_class=PermissionClass.WRITE,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
tags=["collaboration", "broadcast", "task"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"agent_names": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "目标 Agent 列表",
|
||||||
|
},
|
||||||
|
"task": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要广播的任务描述",
|
||||||
|
},
|
||||||
|
"priority": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["low", "normal", "high", "urgent"],
|
||||||
|
"description": "任务优先级",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["agent_names", "task"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"success": {"type": "boolean"},
|
||||||
|
"broadcast_to": {"type": "array", "items": {"type": "string"}},
|
||||||
|
"responses": {"type": "array"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
agent_names: list[str],
|
||||||
|
task: str,
|
||||||
|
priority: str = "normal",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
# 注意:实际实现需要通过 Agent 通信协议
|
||||||
|
# 这里只是一个框架实现
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"broadcast_to": agent_names,
|
||||||
|
"task": task,
|
||||||
|
"priority": priority,
|
||||||
|
"responses": [f"Acknowledged by {agent}" for agent in agent_names],
|
||||||
|
}
|
||||||
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""开发工具 - Phase 6.4"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.base import ReadTool, WriteTool
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LSPTools(ReadTool):
|
||||||
|
"""语言服务器协议工具集
|
||||||
|
|
||||||
|
提供代码导航、查找引用等 LSP 功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="lsp_tools",
|
||||||
|
description="LSP 代码导航和查找引用",
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
tags=["development", "lsp", "code"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"action": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["goto_definition", "find_references", "document_symbols"],
|
||||||
|
"description": "LSP 操作类型",
|
||||||
|
},
|
||||||
|
"file": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件路径",
|
||||||
|
},
|
||||||
|
"line": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "行号(1-based)",
|
||||||
|
},
|
||||||
|
"character": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "列号(0-based)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["action", "file"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"success": {"type": "boolean"},
|
||||||
|
"results": {"type": "array"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
file: str,
|
||||||
|
line: int = 1,
|
||||||
|
character: int = 0,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
# 注意:实际 LSP 调用需要通过 lsp-utils 或类似库
|
||||||
|
# 这里只是一个框架实现
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"LSP action '{action}' not fully implemented - requires LSP server integration",
|
||||||
|
"action": action,
|
||||||
|
"file": file,
|
||||||
|
"position": {"line": line, "character": character},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GitTool(ReadTool):
|
||||||
|
"""Git 操作工具
|
||||||
|
|
||||||
|
提供常用的 Git 操作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, repo_path: str = "."):
|
||||||
|
super().__init__(
|
||||||
|
name="git",
|
||||||
|
description="执行 Git 命令",
|
||||||
|
permission_class=PermissionClass.EXTERNAL,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
tags=["development", "git", "version-control"],
|
||||||
|
)
|
||||||
|
self.repo_path = repo_path
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Git 子命令和参数,如 'status' 或 'log --oneline -10'",
|
||||||
|
},
|
||||||
|
"repo_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "仓库路径(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stdout": {"type": "string"},
|
||||||
|
"stderr": {"type": "string"},
|
||||||
|
"returncode": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, command: str, repo_path: str | None = None) -> dict[str, Any]:
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
|
||||||
|
repo = repo_path or self.repo_path
|
||||||
|
|
||||||
|
# 构建完整的 git 命令
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
full_command = f'git -C "{repo}" {command}'
|
||||||
|
else:
|
||||||
|
full_command = f"git -C '{repo}' {command}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = await asyncio.create_subprocess_shell(
|
||||||
|
full_command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
stdout, stderr = await process.communicate()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||||
|
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||||
|
"returncode": process.returncode,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": str(e),
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""文件操作工具 - Phase 6.4"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.base import ExternalTool, ReadTool, WriteTool
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GlobTool(ReadTool):
|
||||||
|
"""文件路径匹配工具
|
||||||
|
|
||||||
|
使用 glob 模式查找文件。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, root_dir: str = "."):
|
||||||
|
super().__init__(
|
||||||
|
name="glob",
|
||||||
|
description="使用 glob 模式查找文件路径",
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
tags=["file", "search", "glob"],
|
||||||
|
)
|
||||||
|
self.root_dir = root_dir
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"pattern": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Glob 模式,如 **/*.py",
|
||||||
|
},
|
||||||
|
"root_dir": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "搜索根目录(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["pattern"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, pattern: str, root_dir: str | None = None) -> list[str]:
|
||||||
|
import glob as glob_module
|
||||||
|
|
||||||
|
root = root_dir or self.root_dir
|
||||||
|
return glob_module.glob(pattern, root_dir=root, recursive=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GrepTool(ReadTool):
|
||||||
|
"""文件内容搜索工具
|
||||||
|
|
||||||
|
在文件中搜索匹配的行。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="grep",
|
||||||
|
description="在文件中搜索匹配的文本行",
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
tags=["file", "search", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"pattern": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "正则表达式模式",
|
||||||
|
},
|
||||||
|
"paths": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "要搜索的文件路径列表",
|
||||||
|
},
|
||||||
|
"case_sensitive": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否区分大小写",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["pattern", "paths"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file": {"type": "string"},
|
||||||
|
"line": {"type": "integer"},
|
||||||
|
"content": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, pattern: str, paths: list[str], case_sensitive: bool = True
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
import re
|
||||||
|
|
||||||
|
flags = 0 if case_sensitive else re.IGNORECASE
|
||||||
|
regex = re.compile(pattern, flags)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
for line_num, line in enumerate(f, 1):
|
||||||
|
if regex.search(line):
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"file": path,
|
||||||
|
"line": line_num,
|
||||||
|
"content": line.rstrip(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except (UnicodeDecodeError, PermissionError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileTool(ReadTool):
|
||||||
|
"""文件读取工具"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="read_file",
|
||||||
|
description="读取文件内容",
|
||||||
|
permission_class=PermissionClass.READ,
|
||||||
|
side_effect_scope=SideEffectScope.NONE,
|
||||||
|
tags=["file", "read"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件路径",
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "最大行数",
|
||||||
|
},
|
||||||
|
"offset": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "起始行号",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content": {"type": "string"},
|
||||||
|
"lines": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, path: str, limit: int | None = None, offset: int = 0) -> dict[str, Any]:
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
raise FileNotFoundError(f"File not found: {path}")
|
||||||
|
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
total_lines = len(lines)
|
||||||
|
start = max(0, offset)
|
||||||
|
end = len(lines) if limit is None else min(start + limit, len(lines))
|
||||||
|
|
||||||
|
content = "".join(lines[start:end])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": content,
|
||||||
|
"lines": total_lines,
|
||||||
|
"truncated": limit is not None and end < len(lines),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFileTool(WriteTool):
|
||||||
|
"""文件写入工具"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="write_file",
|
||||||
|
description="写入文件内容",
|
||||||
|
permission_class=PermissionClass.WRITE,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
tags=["file", "write"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件路径",
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件内容",
|
||||||
|
},
|
||||||
|
"append": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否追加模式",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["path", "content"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"success": {"type": "boolean"},
|
||||||
|
"bytes_written": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, path: str, content: str, append: bool = False) -> dict[str, Any]:
|
||||||
|
mode = "a" if append else "w"
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
directory = os.path.dirname(path)
|
||||||
|
if directory and not os.path.exists(directory):
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
|
||||||
|
with open(path, mode, encoding="utf-8") as f:
|
||||||
|
bytes_written = f.write(content)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"bytes_written": bytes_written,
|
||||||
|
}
|
||||||
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""系统工具 - Phase 6.4"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import shlex
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.base import ExternalTool
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BashTool(ExternalTool):
|
||||||
|
"""Bash 命令执行工具"""
|
||||||
|
|
||||||
|
def __init__(self, working_dir: str = "."):
|
||||||
|
super().__init__(
|
||||||
|
name="bash",
|
||||||
|
description="执行 Bash 命令",
|
||||||
|
permission_class=PermissionClass.EXTERNAL,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
tags=["system", "bash", "shell"],
|
||||||
|
)
|
||||||
|
self.working_dir = working_dir
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要执行的 Bash 命令",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "超时时间(秒)",
|
||||||
|
},
|
||||||
|
"working_dir": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "工作目录(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stdout": {"type": "string"},
|
||||||
|
"stderr": {"type": "string"},
|
||||||
|
"returncode": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
import os
|
||||||
|
|
||||||
|
cwd = working_dir or self.working_dir
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = await asyncio.create_subprocess_shell(
|
||||||
|
command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
process.kill()
|
||||||
|
await process.wait()
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": f"Command timed out after {timeout} seconds",
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||||
|
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||||
|
"returncode": process.returncode,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": str(e),
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PowerShellTool(ExternalTool):
|
||||||
|
"""PowerShell 命令执行工具"""
|
||||||
|
|
||||||
|
def __init__(self, working_dir: str = "."):
|
||||||
|
super().__init__(
|
||||||
|
name="powershell",
|
||||||
|
description="执行 PowerShell 命令",
|
||||||
|
permission_class=PermissionClass.EXTERNAL,
|
||||||
|
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||||
|
requires_confirmation=True,
|
||||||
|
tags=["system", "powershell", "shell"],
|
||||||
|
)
|
||||||
|
self.working_dir = working_dir
|
||||||
|
|
||||||
|
def get_parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要执行的 PowerShell 命令",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "超时时间(秒)",
|
||||||
|
},
|
||||||
|
"working_dir": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "工作目录(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_return_schema(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stdout": {"type": "string"},
|
||||||
|
"stderr": {"type": "string"},
|
||||||
|
"returncode": {"type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
import platform
|
||||||
|
|
||||||
|
# 检测是否是 Windows 平台
|
||||||
|
is_windows = platform.system() == "Windows"
|
||||||
|
|
||||||
|
if not is_windows:
|
||||||
|
# 非 Windows 平台,可能没有 PowerShell
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "PowerShell is not available on this platform",
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
cwd = working_dir or self.working_dir
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
"powershell.exe",
|
||||||
|
"-NoProfile",
|
||||||
|
"-Command",
|
||||||
|
command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
process.kill()
|
||||||
|
await process.wait()
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": f"Command timed out after {timeout} seconds",
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||||
|
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||||
|
"returncode": process.returncode,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": str(e),
|
||||||
|
"returncode": -1,
|
||||||
|
}
|
||||||
@@ -4,19 +4,12 @@ from langchain_core.tools import tool
|
|||||||
from app.database import async_session
|
from app.database import async_session
|
||||||
from app.models.forum import ForumPost, ForumReply
|
from app.models.forum import ForumPost, ForumReply
|
||||||
from app.agents.context import get_current_user
|
from app.agents.context import get_current_user
|
||||||
|
from app.agents.tools.async_bridge import run_async
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
_executor = ThreadPoolExecutor(max_workers=4)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_async(coro, timeout: int = 30):
|
def _run_async(coro, timeout: int = 30):
|
||||||
try:
|
return run_async(coro, timeout=timeout)
|
||||||
asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
return asyncio.run(coro)
|
|
||||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|||||||
46
backend/app/agents/tools/hooks/__init__.py
Normal file
46
backend/app/agents/tools/hooks/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Hook 系统 - Phase 6.2"""
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
HookDefinition,
|
||||||
|
HookResult,
|
||||||
|
HookStage,
|
||||||
|
HookTrigger,
|
||||||
|
HookType,
|
||||||
|
ExecutionContext,
|
||||||
|
HookHandler,
|
||||||
|
PreToolHook,
|
||||||
|
PostToolHook,
|
||||||
|
ErrorToolHook,
|
||||||
|
SkipToolHook,
|
||||||
|
)
|
||||||
|
from app.agents.tools.hooks.manager import (
|
||||||
|
HookManager,
|
||||||
|
get_hook_manager,
|
||||||
|
reset_hook_manager,
|
||||||
|
)
|
||||||
|
from app.agents.tools.hooks.executor import (
|
||||||
|
HookExecutor,
|
||||||
|
get_hook_executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Types
|
||||||
|
"HookType",
|
||||||
|
"HookStage",
|
||||||
|
"HookTrigger",
|
||||||
|
"HookDefinition",
|
||||||
|
"HookResult",
|
||||||
|
"ExecutionContext",
|
||||||
|
"HookHandler",
|
||||||
|
"PreToolHook",
|
||||||
|
"PostToolHook",
|
||||||
|
"ErrorToolHook",
|
||||||
|
"SkipToolHook",
|
||||||
|
# Manager
|
||||||
|
"HookManager",
|
||||||
|
"get_hook_manager",
|
||||||
|
"reset_hook_manager",
|
||||||
|
# Executor
|
||||||
|
"HookExecutor",
|
||||||
|
"get_hook_executor",
|
||||||
|
]
|
||||||
11
backend/app/agents/tools/hooks/builtins/__init__.py
Normal file
11
backend/app/agents/tools/hooks/builtins/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""内置 Hook 集合 - Phase 7"""
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.builtins.audit_log import AuditLogHook
|
||||||
|
from app.agents.tools.hooks.builtins.dangerous_confirmation import DangerousConfirmationHook
|
||||||
|
from app.agents.tools.hooks.builtins.security_scan import SecurityScanHook
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AuditLogHook",
|
||||||
|
"DangerousConfirmationHook",
|
||||||
|
"SecurityScanHook",
|
||||||
|
]
|
||||||
115
backend/app/agents/tools/hooks/builtins/audit_log.py
Normal file
115
backend/app/agents/tools/hooks/builtins/audit_log.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""审计日志 Hook - Phase 7.2
|
||||||
|
|
||||||
|
记录所有工具调用到审计日志。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
ExecutionContext,
|
||||||
|
HookResult,
|
||||||
|
HookType,
|
||||||
|
)
|
||||||
|
from app.agents.tools.manifest import ToolCategory
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogHook:
|
||||||
|
"""审计日志 Hook
|
||||||
|
|
||||||
|
记录所有工具调用的详细信息,包括:
|
||||||
|
- 调用时间
|
||||||
|
- 工具名称
|
||||||
|
- 输入参数
|
||||||
|
- 执行结果
|
||||||
|
- 执行时长
|
||||||
|
- 用户 ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, log_path: str | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
log_path: 日志文件路径,None 则输出到 stdout
|
||||||
|
"""
|
||||||
|
self.log_path = log_path
|
||||||
|
self._logs: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||||
|
"""工具执行前记录"""
|
||||||
|
log_entry = {
|
||||||
|
"event": "pre_tool",
|
||||||
|
"tool_name": context.tool_name,
|
||||||
|
"input": context.tool_input,
|
||||||
|
"user_id": context.user_id,
|
||||||
|
"session_id": context.session_id,
|
||||||
|
}
|
||||||
|
self._logs.append(log_entry)
|
||||||
|
self._write_log(log_entry)
|
||||||
|
return HookResult(
|
||||||
|
hook_name="audit_log",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
|
||||||
|
"""工具执行后记录"""
|
||||||
|
log_entry = {
|
||||||
|
"event": "post_tool",
|
||||||
|
"tool_name": context.tool_name,
|
||||||
|
"result": str(result)[:500] if result else None,
|
||||||
|
"duration_ms": (
|
||||||
|
(context.end_time - context.start_time) * 1000
|
||||||
|
if context.start_time and context.end_time
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
self._logs.append(log_entry)
|
||||||
|
self._write_log(log_entry)
|
||||||
|
return HookResult(
|
||||||
|
hook_name="audit_log",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
modified_output=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def tool_error(self, context: ExecutionContext, error: Exception) -> HookResult:
|
||||||
|
"""工具出错时记录"""
|
||||||
|
log_entry = {
|
||||||
|
"event": "tool_error",
|
||||||
|
"tool_name": context.tool_name,
|
||||||
|
"error": str(error),
|
||||||
|
"error_type": type(error).__name__,
|
||||||
|
}
|
||||||
|
self._logs.append(log_entry)
|
||||||
|
self._write_log(log_entry)
|
||||||
|
return HookResult(
|
||||||
|
hook_name="audit_log",
|
||||||
|
success=False,
|
||||||
|
continue_execution=True,
|
||||||
|
error=str(error),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _write_log(self, entry: dict[str, Any]) -> None:
|
||||||
|
"""写入日志"""
|
||||||
|
import json
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
entry["timestamp"] = datetime.datetime.now().isoformat()
|
||||||
|
|
||||||
|
if self.log_path:
|
||||||
|
try:
|
||||||
|
with open(self.log_path, "a", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||||
|
except Exception:
|
||||||
|
# 日志写入失败不影响主流程
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# 输出到 stdout
|
||||||
|
print(f"[AUDIT] {json.dumps(entry, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
def get_logs(self) -> list[dict[str, Any]]:
|
||||||
|
"""获取所有日志"""
|
||||||
|
return self._logs.copy()
|
||||||
|
|
||||||
|
def clear_logs(self) -> None:
|
||||||
|
"""清空日志"""
|
||||||
|
self._logs.clear()
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
"""危险操作确认 Hook - Phase 7.2
|
||||||
|
|
||||||
|
对危险操作要求用户确认。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
ExecutionContext,
|
||||||
|
HookResult,
|
||||||
|
)
|
||||||
|
from app.agents.tools.manifest import SideEffectScope
|
||||||
|
|
||||||
|
|
||||||
|
# 危险操作关键词
|
||||||
|
DANGEROUS_PATTERNS = [
|
||||||
|
# 文件操作
|
||||||
|
"delete",
|
||||||
|
"remove",
|
||||||
|
"rm ",
|
||||||
|
"rmdir",
|
||||||
|
"unlink",
|
||||||
|
"format",
|
||||||
|
"truncate",
|
||||||
|
# 系统操作
|
||||||
|
"shutdown",
|
||||||
|
"reboot",
|
||||||
|
"kill",
|
||||||
|
"pkill",
|
||||||
|
"sudo",
|
||||||
|
"chmod",
|
||||||
|
"chown",
|
||||||
|
# 数据操作
|
||||||
|
"drop",
|
||||||
|
"truncate",
|
||||||
|
"delete from",
|
||||||
|
"delete.*where",
|
||||||
|
"insert into.*select",
|
||||||
|
"update.*set",
|
||||||
|
# 网络操作
|
||||||
|
"curl",
|
||||||
|
"wget",
|
||||||
|
"nc ",
|
||||||
|
"netcat",
|
||||||
|
"ssh ",
|
||||||
|
"scp ",
|
||||||
|
"sftp ",
|
||||||
|
# 环境变量
|
||||||
|
"export.*secret",
|
||||||
|
"export.*key",
|
||||||
|
"export.*token",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DangerousConfirmationHook:
|
||||||
|
"""危险操作确认 Hook
|
||||||
|
|
||||||
|
检查工具调用是否包含危险操作,如是则要求确认。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, auto_block: bool = False):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
auto_block: True 表示自动拦截危险操作,False 表示仅警告
|
||||||
|
"""
|
||||||
|
self.auto_block = auto_block
|
||||||
|
self._pending_confirmations: dict[str, bool] = {}
|
||||||
|
|
||||||
|
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||||
|
"""检查是否为危险操作"""
|
||||||
|
is_dangerous = self._check_dangerous(context.tool_name, context.tool_input)
|
||||||
|
|
||||||
|
if is_dangerous:
|
||||||
|
if self.auto_block:
|
||||||
|
return HookResult(
|
||||||
|
hook_name="dangerous_confirmation",
|
||||||
|
success=False,
|
||||||
|
continue_execution=False,
|
||||||
|
error=f"危险操作被自动拦截: {context.tool_name}",
|
||||||
|
metadata={"dangerous": True, "auto_blocked": True},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 标记需要确认
|
||||||
|
context.metadata["requires_confirmation"] = True
|
||||||
|
context.metadata["dangerous_operation"] = True
|
||||||
|
return HookResult(
|
||||||
|
hook_name="dangerous_confirmation",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
metadata={"dangerous": True, "requires_confirmation": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
return HookResult(
|
||||||
|
hook_name="dangerous_confirmation",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_dangerous(self, tool_name: str, tool_input: dict[str, Any]) -> bool:
|
||||||
|
"""检查是否为危险操作"""
|
||||||
|
# 检查工具名称
|
||||||
|
dangerous_tools = [
|
||||||
|
"delete",
|
||||||
|
"remove",
|
||||||
|
"drop",
|
||||||
|
"truncate",
|
||||||
|
"kill",
|
||||||
|
"shutdown",
|
||||||
|
"reboot",
|
||||||
|
"bash",
|
||||||
|
"powershell",
|
||||||
|
"shell",
|
||||||
|
]
|
||||||
|
|
||||||
|
if tool_name.lower() in dangerous_tools:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 检查输入参数
|
||||||
|
input_str = str(tool_input).lower()
|
||||||
|
|
||||||
|
for pattern in DANGEROUS_PATTERNS:
|
||||||
|
if pattern.lower() in input_str:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def confirm(self, session_id: str, confirmed: bool) -> None:
|
||||||
|
"""确认危险操作
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
confirmed: True 表示用户确认,False 表示取消
|
||||||
|
"""
|
||||||
|
self._pending_confirmations[session_id] = confirmed
|
||||||
|
|
||||||
|
def is_confirmed(self, session_id: str) -> bool:
|
||||||
|
"""检查是否已确认"""
|
||||||
|
return self._pending_confirmations.get(session_id, False)
|
||||||
|
|
||||||
|
def clear_confirmation(self, session_id: str) -> None:
|
||||||
|
"""清除确认状态"""
|
||||||
|
self._pending_confirmations.pop(session_id, None)
|
||||||
183
backend/app/agents/tools/hooks/builtins/security_scan.py
Normal file
183
backend/app/agents/tools/hooks/builtins/security_scan.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""安全扫描 Hook - Phase 7.2
|
||||||
|
|
||||||
|
扫描工具调用和结果中的敏感信息。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
ExecutionContext,
|
||||||
|
HookResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 敏感信息模式
|
||||||
|
SENSITIVE_PATTERNS = {
|
||||||
|
"api_key": [
|
||||||
|
r"api[_-]?key['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||||
|
r"apikey['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||||
|
],
|
||||||
|
"password": [
|
||||||
|
r"password['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
|
||||||
|
r"passwd['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
|
||||||
|
r"secret['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||||
|
],
|
||||||
|
"token": [
|
||||||
|
r"token['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-\.]{20,}",
|
||||||
|
r"bearer\s+[a-zA-Z0-9_\-\.]+",
|
||||||
|
r"ghp_[a-zA-Z0-9]{36}",
|
||||||
|
r"sk-[a-zA-Z0-9]{48}",
|
||||||
|
],
|
||||||
|
"private_key": [
|
||||||
|
r"-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
|
||||||
|
r"-----END (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
|
||||||
|
],
|
||||||
|
"ip_address": [
|
||||||
|
r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
|
||||||
|
],
|
||||||
|
"email": [
|
||||||
|
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityScanHook:
|
||||||
|
"""安全扫描 Hook
|
||||||
|
|
||||||
|
扫描工具输入和输出中的敏感信息,进行脱敏处理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redact: bool = True,
|
||||||
|
block_on_detect: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
redact: 是否对敏感信息进行脱敏
|
||||||
|
block_on_detect: 检测到敏感信息时是否阻止执行
|
||||||
|
"""
|
||||||
|
self.redact = redact
|
||||||
|
self.block_on_detect = block_on_detect
|
||||||
|
self._compiled_patterns = {
|
||||||
|
name: [re.compile(p, re.IGNORECASE) for p in patterns]
|
||||||
|
for name, patterns in SENSITIVE_PATTERNS.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||||
|
"""扫描输入参数"""
|
||||||
|
detected = self._scan_dict(context.tool_input)
|
||||||
|
|
||||||
|
if detected:
|
||||||
|
context.metadata["security_detected"] = detected
|
||||||
|
|
||||||
|
if self.block_on_detect:
|
||||||
|
return HookResult(
|
||||||
|
hook_name="security_scan",
|
||||||
|
success=False,
|
||||||
|
continue_execution=False,
|
||||||
|
error=f"检测到敏感信息: {', '.join(detected.keys())}",
|
||||||
|
metadata={"detected": detected, "blocked": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.redact:
|
||||||
|
redacted_input = self._redact_dict(context.tool_input.copy())
|
||||||
|
return HookResult(
|
||||||
|
hook_name="security_scan",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
modified_input=redacted_input,
|
||||||
|
metadata={"detected": detected, "redacted": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
return HookResult(
|
||||||
|
hook_name="security_scan",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
|
||||||
|
"""扫描输出结果"""
|
||||||
|
if isinstance(result, dict):
|
||||||
|
detected = self._scan_dict(result)
|
||||||
|
|
||||||
|
if detected:
|
||||||
|
context.metadata["security_detected_output"] = detected
|
||||||
|
|
||||||
|
if self.redact:
|
||||||
|
redacted_result = self._redact_dict(result.copy())
|
||||||
|
return HookResult(
|
||||||
|
hook_name="security_scan",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
modified_output=redacted_result,
|
||||||
|
metadata={"detected": detected, "redacted": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(result, str):
|
||||||
|
detected = self._scan_string(result)
|
||||||
|
if detected:
|
||||||
|
context.metadata["security_detected_output"] = detected
|
||||||
|
|
||||||
|
if self.redact:
|
||||||
|
redacted_result = self._redact_string(result)
|
||||||
|
return HookResult(
|
||||||
|
hook_name="security_scan",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
modified_output=redacted_result,
|
||||||
|
metadata={"detected": detected, "redacted": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
return HookResult(
|
||||||
|
hook_name="security_scan",
|
||||||
|
success=True,
|
||||||
|
continue_execution=True,
|
||||||
|
modified_output=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _scan_dict(self, data: dict[str, Any]) -> dict[str, list[str]]:
|
||||||
|
"""扫描字典中的敏感信息"""
|
||||||
|
result: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
for key, value in data.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
found = self._scan_string(value)
|
||||||
|
if found:
|
||||||
|
result[key] = found
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _scan_string(self, text: str) -> list[str]:
|
||||||
|
"""扫描字符串中的敏感信息"""
|
||||||
|
found_types = []
|
||||||
|
|
||||||
|
for name, patterns in self._compiled_patterns.items():
|
||||||
|
for pattern in patterns:
|
||||||
|
if pattern.search(text):
|
||||||
|
if name not in found_types:
|
||||||
|
found_types.append(name)
|
||||||
|
break
|
||||||
|
|
||||||
|
return found_types
|
||||||
|
|
||||||
|
def _redact_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""脱敏字典中的敏感信息"""
|
||||||
|
for key, value in data.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
data[key] = self._redact_string(value)
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
data[key] = self._redact_dict(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
data[key] = [self._redact_string(v) if isinstance(v, str) else v for v in value]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _redact_string(self, text: str) -> str:
|
||||||
|
"""脱敏字符串中的敏感信息"""
|
||||||
|
for name, patterns in self._compiled_patterns.items():
|
||||||
|
for pattern in patterns:
|
||||||
|
text = pattern.sub(f"[REDACTED:{name}]", text)
|
||||||
|
|
||||||
|
return text
|
||||||
105
backend/app/agents/tools/hooks/config.py
Normal file
105
backend/app/agents/tools/hooks/config.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""Hook 配置持久化 - Phase 7.3"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.manager import get_hook_manager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookConfigEntry:
|
||||||
|
"""Hook 配置条目"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
hook_type: str
|
||||||
|
enabled: bool
|
||||||
|
tool_names: list[str] | None = None
|
||||||
|
categories: list[str] | None = None
|
||||||
|
priority: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class HookConfigPersistence:
|
||||||
|
"""Hook 配置持久化"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
config_path: 配置文件路径,None 则使用默认路径
|
||||||
|
"""
|
||||||
|
if config_path is None:
|
||||||
|
config_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), "..", "..", "..", "..", "config", "hooks.json"
|
||||||
|
)
|
||||||
|
self.config_path = config_path
|
||||||
|
|
||||||
|
def load_config(self) -> list[HookConfigEntry]:
|
||||||
|
"""从文件加载 Hook 配置"""
|
||||||
|
if not os.path.exists(self.config_path):
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return [HookConfigEntry(**entry) for entry in data]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def save_config(self, entries: list[HookConfigEntry]) -> bool:
|
||||||
|
"""保存 Hook 配置到文件"""
|
||||||
|
try:
|
||||||
|
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
||||||
|
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump([asdict(e) for e in entries], f, indent=2, ensure_ascii=False)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def apply_config(self) -> int:
|
||||||
|
"""应用配置到 HookManager
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
应用的 Hook 数量
|
||||||
|
"""
|
||||||
|
from app.agents.tools.hooks.types import HookType
|
||||||
|
|
||||||
|
manager = get_hook_manager()
|
||||||
|
entries = self.load_config()
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
if entry.enabled:
|
||||||
|
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
|
||||||
|
|
||||||
|
trigger = HookTrigger(
|
||||||
|
tool_names=entry.tool_names,
|
||||||
|
categories=entry.categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建空的 handler,只是注册配置
|
||||||
|
hook_def = HookDefinition(
|
||||||
|
name=entry.name,
|
||||||
|
hook_type=HookType(entry.hook_type),
|
||||||
|
trigger=trigger,
|
||||||
|
handler=lambda ctx, *args: ctx,
|
||||||
|
priority=entry.priority,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
manager.register(hook_def)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_persistence: HookConfigPersistence | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_hook_config_persistence() -> HookConfigPersistence:
|
||||||
|
"""获取全局 Hook 配置持久化实例"""
|
||||||
|
global _persistence
|
||||||
|
if _persistence is None:
|
||||||
|
_persistence = HookConfigPersistence()
|
||||||
|
return _persistence
|
||||||
5
backend/app/agents/tools/hooks/custom/__init__.py
Normal file
5
backend/app/agents/tools/hooks/custom/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""自定义 Hook 加载器包"""
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.custom.loader import CustomHookLoader, get_custom_hook_loader
|
||||||
|
|
||||||
|
__all__ = ["CustomHookLoader", "get_custom_hook_loader"]
|
||||||
153
backend/app/agents/tools/hooks/custom/loader.py
Normal file
153
backend/app/agents/tools/hooks/custom/loader.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""自定义 Hook 加载器 - Phase 7.4
|
||||||
|
|
||||||
|
支持动态加载用户自定义的 Hook。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.types import HookDefinition, HookType, HookTrigger, HookResult
|
||||||
|
|
||||||
|
|
||||||
|
class CustomHookLoader:
|
||||||
|
"""自定义 Hook 加载器
|
||||||
|
|
||||||
|
从指定目录动态加载自定义 Hook 模块。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hooks_dir: str | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hooks_dir: Hook 目录,None 则使用默认目录
|
||||||
|
"""
|
||||||
|
if hooks_dir is None:
|
||||||
|
hooks_dir = os.path.join(
|
||||||
|
os.path.dirname(__file__), "..", "..", "..", "data", "custom_hooks"
|
||||||
|
)
|
||||||
|
self.hooks_dir = hooks_dir
|
||||||
|
self._loaded_hooks: dict[str, HookDefinition] = {}
|
||||||
|
|
||||||
|
def load_all(self) -> list[HookDefinition]:
|
||||||
|
"""加载所有自定义 Hook
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 定义列表
|
||||||
|
"""
|
||||||
|
hooks = []
|
||||||
|
|
||||||
|
if not os.path.exists(self.hooks_dir):
|
||||||
|
return hooks
|
||||||
|
|
||||||
|
for filename in os.listdir(self.hooks_dir):
|
||||||
|
if filename.endswith(".py") and not filename.startswith("_"):
|
||||||
|
hook_path = os.path.join(self.hooks_dir, filename)
|
||||||
|
hook_def = self._load_hook_from_file(hook_path, filename[:-3])
|
||||||
|
if hook_def:
|
||||||
|
hooks.append(hook_def)
|
||||||
|
self._loaded_hooks[hook_def.name] = hook_def
|
||||||
|
|
||||||
|
return hooks
|
||||||
|
|
||||||
|
def _load_hook_from_file(self, hook_path: str, module_name: str) -> HookDefinition | None:
|
||||||
|
"""从文件加载 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_path: Hook 文件路径
|
||||||
|
module_name: 模块名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 定义或 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, hook_path)
|
||||||
|
if not spec or not spec.loader:
|
||||||
|
return None
|
||||||
|
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
# 查找 HOOK_DEFINITION 或 hook_definition
|
||||||
|
hook_def = getattr(module, "HOOK_DEFINITION", None) or getattr(
|
||||||
|
module, "hook_definition", None
|
||||||
|
)
|
||||||
|
|
||||||
|
if hook_def and isinstance(hook_def, HookDefinition):
|
||||||
|
return hook_def
|
||||||
|
|
||||||
|
# 如果没有定义,尝试从函数自动推断
|
||||||
|
if hasattr(module, "pre_tool_hook") or hasattr(module, "post_tool_hook"):
|
||||||
|
return self._infer_hook_definition(module, module_name)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _infer_hook_definition(self, module: Any, module_name: str) -> HookDefinition | None:
|
||||||
|
"""从模块函数推断 Hook 定义
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: 模块对象
|
||||||
|
module_name: 模块名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 定义或 None
|
||||||
|
"""
|
||||||
|
hook_type = None
|
||||||
|
handler = None
|
||||||
|
|
||||||
|
if hasattr(module, "pre_tool_hook"):
|
||||||
|
handler = module.pre_tool_hook
|
||||||
|
hook_type = HookType.PRE_TOOL_USE
|
||||||
|
elif hasattr(module, "post_tool_hook"):
|
||||||
|
handler = module.post_tool_hook
|
||||||
|
hook_type = HookType.POST_TOOL_USE
|
||||||
|
elif hasattr(module, "error_tool_hook"):
|
||||||
|
handler = module.error_tool_hook
|
||||||
|
hook_type = HookType.TOOL_ERROR
|
||||||
|
|
||||||
|
if not handler or not hook_type:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return HookDefinition(
|
||||||
|
name=module_name,
|
||||||
|
hook_type=hook_type,
|
||||||
|
trigger=HookTrigger(),
|
||||||
|
handler=handler,
|
||||||
|
priority=0,
|
||||||
|
enabled=True,
|
||||||
|
description=f"Auto-loaded hook from {module_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_hook(self, name: str) -> HookDefinition | None:
|
||||||
|
"""获取已加载的 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Hook 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 定义或 None
|
||||||
|
"""
|
||||||
|
return self._loaded_hooks.get(name)
|
||||||
|
|
||||||
|
def reload(self) -> list[HookDefinition]:
|
||||||
|
"""重新加载所有 Hook
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
重新加载的 Hook 列表
|
||||||
|
"""
|
||||||
|
self._loaded_hooks.clear()
|
||||||
|
return self.load_all()
|
||||||
|
|
||||||
|
|
||||||
|
# 全局加载器
|
||||||
|
_loader: CustomHookLoader | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_hook_loader() -> CustomHookLoader:
|
||||||
|
"""获取全局自定义 Hook 加载器"""
|
||||||
|
global _loader
|
||||||
|
if _loader is None:
|
||||||
|
_loader = CustomHookLoader()
|
||||||
|
return _loader
|
||||||
170
backend/app/agents/tools/hooks/executor.py
Normal file
170
backend/app/agents/tools/hooks/executor.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""Hook 执行器 - Phase 6.2
|
||||||
|
|
||||||
|
执行 Hook 拦截逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.manager import get_hook_manager
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
HookDefinition,
|
||||||
|
HookResult,
|
||||||
|
HookType,
|
||||||
|
ExecutionContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HookExecutor:
|
||||||
|
"""Hook 执行器
|
||||||
|
|
||||||
|
负责在工具执行前后执行 Hook 逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._manager = get_hook_manager()
|
||||||
|
|
||||||
|
async def execute_pre_hooks(
|
||||||
|
self, context: ExecutionContext
|
||||||
|
) -> tuple[bool, dict[str, Any] | None]:
|
||||||
|
"""执行 pre-tool Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 执行上下文
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否继续执行, 修改后的输入)
|
||||||
|
"""
|
||||||
|
hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name)
|
||||||
|
modified_input = context.tool_input
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
# 调用 hook handler
|
||||||
|
handler = hook.handler
|
||||||
|
if callable(handler):
|
||||||
|
result = await self._call_hook(handler, context)
|
||||||
|
if result and not result.continue_execution:
|
||||||
|
# Hook 决定中断执行
|
||||||
|
return False, modified_input
|
||||||
|
if result.modified_input is not None:
|
||||||
|
modified_input = result.modified_input
|
||||||
|
except Exception as e:
|
||||||
|
# Hook 出错,默认继续执行
|
||||||
|
pass
|
||||||
|
|
||||||
|
return True, modified_input
|
||||||
|
|
||||||
|
async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any:
|
||||||
|
"""执行 post-tool Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 执行上下文
|
||||||
|
result: 工具执行结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
修改后的结果
|
||||||
|
"""
|
||||||
|
hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name)
|
||||||
|
modified_result = result
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
handler = hook.handler
|
||||||
|
if callable(handler):
|
||||||
|
hook_result = await self._call_hook(handler, context, modified_result)
|
||||||
|
if hook_result and hook_result.modified_output is not None:
|
||||||
|
modified_result = hook_result.modified_output
|
||||||
|
except Exception:
|
||||||
|
# Hook 出错,默认保留原结果
|
||||||
|
pass
|
||||||
|
|
||||||
|
return modified_result
|
||||||
|
|
||||||
|
async def execute_error_hooks(
|
||||||
|
self, context: ExecutionContext, error: Exception
|
||||||
|
) -> HookResult | None:
|
||||||
|
"""执行 error Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 执行上下文
|
||||||
|
error: 异常
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 结果,如果返回 None 则继续传播错误
|
||||||
|
"""
|
||||||
|
hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
handler = hook.handler
|
||||||
|
if callable(handler):
|
||||||
|
result = await self._call_hook(handler, context, error)
|
||||||
|
if result is not None and result.continue_execution:
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
# Hook 出错,继续执行其他 error hooks
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def execute_skip_check(self, context: ExecutionContext) -> bool:
|
||||||
|
"""检查是否应跳过工具执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 执行上下文
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 表示跳过,False 表示执行
|
||||||
|
"""
|
||||||
|
hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
try:
|
||||||
|
handler = hook.handler
|
||||||
|
if callable(handler):
|
||||||
|
result = await self._call_hook(handler, context)
|
||||||
|
if result is not None and isinstance(result, bool):
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
# Hook 出错,默认不跳过
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _call_hook(
|
||||||
|
self, handler: Any, context: ExecutionContext, *args: Any
|
||||||
|
) -> HookResult | None:
|
||||||
|
"""调用 Hook 处理函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler: Hook 处理函数
|
||||||
|
context: 执行上下文
|
||||||
|
*args: 额外参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 结果
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# 如果是普通函数,直接调用
|
||||||
|
if asyncio.iscoroutinefunction(handler):
|
||||||
|
return await handler(context, *args)
|
||||||
|
else:
|
||||||
|
return handler(context, *args)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_executor: HookExecutor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_hook_executor() -> HookExecutor:
|
||||||
|
"""获取全局 Hook 执行器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
全局 HookExecutor 实例
|
||||||
|
"""
|
||||||
|
global _executor
|
||||||
|
if _executor is None:
|
||||||
|
_executor = HookExecutor()
|
||||||
|
return _executor
|
||||||
174
backend/app/agents/tools/hooks/manager.py
Normal file
174
backend/app/agents/tools/hooks/manager.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""Hook 管理器 - Phase 6.2
|
||||||
|
|
||||||
|
管理 Hook 的注册、查找和配置。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.types import (
|
||||||
|
HookDefinition,
|
||||||
|
HookResult,
|
||||||
|
HookTrigger,
|
||||||
|
HookType,
|
||||||
|
ExecutionContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HookManager:
|
||||||
|
"""Hook 管理器
|
||||||
|
|
||||||
|
管理全局 Hook 的注册和配置。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._hooks: dict[HookType, list[HookDefinition]] = {
|
||||||
|
HookType.PRE_TOOL_USE: [],
|
||||||
|
HookType.POST_TOOL_USE: [],
|
||||||
|
HookType.TOOL_ERROR: [],
|
||||||
|
HookType.TOOL_SKIP: [],
|
||||||
|
}
|
||||||
|
self._global_hooks: list[HookDefinition] = [] # 全局 Hook(对所有工具生效)
|
||||||
|
|
||||||
|
def register(self, definition: HookDefinition) -> None:
|
||||||
|
"""注册 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
definition: Hook 定义
|
||||||
|
"""
|
||||||
|
if definition.trigger.tool_names is None and definition.trigger.categories is None:
|
||||||
|
# 全局 Hook
|
||||||
|
self._global_hooks.append(definition)
|
||||||
|
else:
|
||||||
|
# 特定工具 Hook
|
||||||
|
self._hooks[definition.hook_type].append(definition)
|
||||||
|
|
||||||
|
# 按优先级排序
|
||||||
|
self._hooks[definition.hook_type].sort(key=lambda h: h.priority, reverse=True)
|
||||||
|
self._global_hooks.sort(key=lambda h: h.priority, reverse=True)
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> bool:
|
||||||
|
"""注销 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Hook 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功注销
|
||||||
|
"""
|
||||||
|
# 从特定工具 Hook 中移除
|
||||||
|
for hooks in self._hooks.values():
|
||||||
|
for i, hook in enumerate(hooks):
|
||||||
|
if hook.name == name:
|
||||||
|
hooks.pop(i)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 从全局 Hook 中移除
|
||||||
|
for i, hook in enumerate(self._global_hooks):
|
||||||
|
if hook.name == name:
|
||||||
|
self._global_hooks.pop(i)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_hooks(self, hook_type: HookType, tool_name: str | None = None) -> list[HookDefinition]:
|
||||||
|
"""获取指定类型和工具的 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_type: Hook 类型
|
||||||
|
tool_name: 工具名称(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
匹配的 Hook 列表
|
||||||
|
"""
|
||||||
|
result: list[HookDefinition] = []
|
||||||
|
|
||||||
|
# 添加全局 Hook
|
||||||
|
for hook in self._global_hooks:
|
||||||
|
if hook.hook_type == hook_type and hook.enabled:
|
||||||
|
result.append(hook)
|
||||||
|
|
||||||
|
# 添加特定工具 Hook
|
||||||
|
for hook in self._hooks[hook_type]:
|
||||||
|
if not hook.enabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if hook.trigger.tool_names is None and hook.trigger.categories is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否匹配
|
||||||
|
if hook.trigger.tool_names and tool_name not in hook.trigger.tool_names:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.append(hook)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def list_all(self) -> list[HookDefinition]:
|
||||||
|
"""列出所有已注册的 Hook
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 列表
|
||||||
|
"""
|
||||||
|
all_hooks = list(self._global_hooks)
|
||||||
|
for hooks in self._hooks.values():
|
||||||
|
all_hooks.extend(hooks)
|
||||||
|
return all_hooks
|
||||||
|
|
||||||
|
def enable(self, name: str) -> bool:
|
||||||
|
"""启用 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Hook 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功启用
|
||||||
|
"""
|
||||||
|
for hook in self.list_all():
|
||||||
|
if hook.name == name:
|
||||||
|
hook.enabled = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disable(self, name: str) -> bool:
|
||||||
|
"""禁用 Hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Hook 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功禁用
|
||||||
|
"""
|
||||||
|
for hook in self.list_all():
|
||||||
|
if hook.name == name:
|
||||||
|
hook.enabled = False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清除所有 Hook"""
|
||||||
|
self._hooks = {ht: [] for ht in HookType}
|
||||||
|
self._global_hooks = []
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_global_hook_manager: HookManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_hook_manager() -> HookManager:
|
||||||
|
"""获取全局 Hook 管理器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
全局 HookManager 实例
|
||||||
|
"""
|
||||||
|
global _global_hook_manager
|
||||||
|
if _global_hook_manager is None:
|
||||||
|
_global_hook_manager = HookManager()
|
||||||
|
return _global_hook_manager
|
||||||
|
|
||||||
|
|
||||||
|
def reset_hook_manager() -> None:
|
||||||
|
"""重置全局 Hook 管理器(用于测试)"""
|
||||||
|
global _global_hook_manager
|
||||||
|
if _global_hook_manager is not None:
|
||||||
|
_global_hook_manager.clear()
|
||||||
|
_global_hook_manager = None
|
||||||
90
backend/app/agents/tools/hooks/types.py
Normal file
90
backend/app/agents/tools/hooks/types.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Hook 类型定义 - Phase 6.2
|
||||||
|
|
||||||
|
Hook 拦截系统类型定义。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
|
class HookType(Enum):
|
||||||
|
"""Hook 类型"""
|
||||||
|
|
||||||
|
PRE_TOOL_USE = "pre_tool_use" # 工具执行前
|
||||||
|
POST_TOOL_USE = "post_tool_use" # 工具执行后
|
||||||
|
TOOL_ERROR = "tool_error" # 工具执行出错
|
||||||
|
TOOL_SKIP = "tool_skip" # 工具跳过(条件执行)
|
||||||
|
|
||||||
|
|
||||||
|
class HookStage(Enum):
|
||||||
|
"""Hook 执行阶段"""
|
||||||
|
|
||||||
|
BEFORE = "before"
|
||||||
|
AFTER = "after"
|
||||||
|
ON_ERROR = "on_error"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookTrigger:
|
||||||
|
"""Hook 触发条件"""
|
||||||
|
|
||||||
|
tool_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||||
|
categories: list[str] | None = None # 只对特定类别生效
|
||||||
|
conditions: dict[str, Any] | None = None # 自定义条件
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookDefinition:
|
||||||
|
"""Hook 定义"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
hook_type: HookType
|
||||||
|
trigger: HookTrigger
|
||||||
|
handler: Callable[..., Any] # Hook 处理函数
|
||||||
|
priority: int = 0 # 优先级,数字越大越先执行
|
||||||
|
enabled: bool = True
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookResult:
|
||||||
|
"""Hook 执行结果"""
|
||||||
|
|
||||||
|
hook_name: str
|
||||||
|
success: bool
|
||||||
|
continue_execution: bool = True # False 表示中断执行
|
||||||
|
modified_input: Any = None # 修改后的输入
|
||||||
|
modified_output: Any = None # 修改后的输出
|
||||||
|
error: str | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionContext:
|
||||||
|
"""工具执行上下文"""
|
||||||
|
|
||||||
|
tool_name: str
|
||||||
|
tool_input: dict[str, Any]
|
||||||
|
user_id: str | None = None
|
||||||
|
session_id: str | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# 执行结果(由 HookExecutor 填充)
|
||||||
|
result: Any = None
|
||||||
|
error: Exception | None = None
|
||||||
|
start_time: float | None = None
|
||||||
|
end_time: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Hook 处理函数类型
|
||||||
|
HookHandler = Callable[[ExecutionContext, HookDefinition], HookResult]
|
||||||
|
|
||||||
|
# Pre-hook: 在工具执行前调用,可以修改输入或决定是否跳过
|
||||||
|
PreToolHook = Callable[[ExecutionContext], tuple[bool, dict[str, Any] | None]]
|
||||||
|
# post-hook: 在工具执行后调用,可以修改输出
|
||||||
|
PostToolHook = Callable[[ExecutionContext, Any], Any]
|
||||||
|
# Error hook: 在工具出错时调用
|
||||||
|
ErrorToolHook = Callable[[ExecutionContext, Exception], HookResult | None]
|
||||||
|
# Skip hook: 决定是否跳过工具执行
|
||||||
|
SkipToolHook = Callable[[ExecutionContext], bool]
|
||||||
77
backend/app/agents/tools/manifest.py
Normal file
77
backend/app/agents/tools/manifest.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""工具元数据和数据类型定义"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCategory(Enum):
|
||||||
|
"""工具类别"""
|
||||||
|
|
||||||
|
READ = "read"
|
||||||
|
WRITE = "write"
|
||||||
|
EXTERNAL = "external"
|
||||||
|
DB_WRITE = "db_write"
|
||||||
|
NETWORK = "network"
|
||||||
|
|
||||||
|
|
||||||
|
class SideEffectScope(Enum):
|
||||||
|
"""副作用范围"""
|
||||||
|
|
||||||
|
NONE = "none"
|
||||||
|
LOCAL_STATE = "local_state"
|
||||||
|
DB_WRITE = "db_write"
|
||||||
|
NETWORK = "network"
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionClass(Enum):
|
||||||
|
"""权限级别"""
|
||||||
|
|
||||||
|
READ = "read"
|
||||||
|
WRITE = "write"
|
||||||
|
EXTERNAL = "external"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolManifest:
|
||||||
|
"""工具元数据"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
category: ToolCategory
|
||||||
|
parameters: dict[str, Any] # JSON Schema
|
||||||
|
return_schema: dict[str, Any]
|
||||||
|
permission_class: PermissionClass
|
||||||
|
side_effect_scope: SideEffectScope
|
||||||
|
requires_confirmation: bool = False
|
||||||
|
is_streaming: bool = False
|
||||||
|
tags: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"category": self.category.value,
|
||||||
|
"parameters": self.parameters,
|
||||||
|
"return_schema": self.return_schema,
|
||||||
|
"permission_class": self.permission_class.value,
|
||||||
|
"side_effect_scope": self.side_effect_scope.value,
|
||||||
|
"requires_confirmation": self.requires_confirmation,
|
||||||
|
"is_streaming": self.is_streaming,
|
||||||
|
"tags": self.tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookConfig:
|
||||||
|
"""Hook 配置"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
hook_type: str # "pre_tool_use", "post_tool_use", "tool_error", "tool_skip"
|
||||||
|
filter_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||||
|
|
||||||
|
def matches_tool(self, tool_name: str) -> bool:
|
||||||
|
"""检查 Hook 是否对指定工具生效"""
|
||||||
|
if self.filter_names is None:
|
||||||
|
return True
|
||||||
|
return tool_name in self.filter_names
|
||||||
251
backend/app/agents/tools/migration.py
Normal file
251
backend/app/agents/tools/migration.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
"""工具迁移和向后兼容层 - Phase 6.1
|
||||||
|
|
||||||
|
将现有 @tool 装饰的工具迁移到 ToolRegistry,同时保持向后兼容。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from app.agents.tools.manifest import (
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
|
ToolCategory,
|
||||||
|
ToolManifest,
|
||||||
|
)
|
||||||
|
from app.agents.tools.registry import get_tool_registry
|
||||||
|
|
||||||
|
|
||||||
|
# 现有工具的类别映射
|
||||||
|
_TOOL_CATEGORY_MAP: dict[str, tuple[ToolCategory, PermissionClass, SideEffectScope]] = {
|
||||||
|
# 知识检索 - 只读
|
||||||
|
"search_knowledge": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"get_knowledge_graph_context": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"hybrid_search": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"web_search": (ToolCategory.NETWORK, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||||
|
# 知识构建 - 写入
|
||||||
|
"build_knowledge_graph": (
|
||||||
|
ToolCategory.WRITE,
|
||||||
|
PermissionClass.WRITE,
|
||||||
|
SideEffectScope.LOCAL_STATE,
|
||||||
|
),
|
||||||
|
# 任务工具
|
||||||
|
"get_tasks": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"create_task": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"update_task_status": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
# 日程工具
|
||||||
|
"get_schedule_day": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"create_todo": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"create_schedule_task": (
|
||||||
|
ToolCategory.WRITE,
|
||||||
|
PermissionClass.WRITE,
|
||||||
|
SideEffectScope.LOCAL_STATE,
|
||||||
|
),
|
||||||
|
"create_reminder": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"create_goal": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"resolve_time_expression": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
# 论坛工具
|
||||||
|
"get_forum_posts": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
"create_forum_post": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||||
|
"scan_forum_for_instructions": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_category(name: str) -> tuple[ToolCategory, PermissionClass, SideEffectScope]:
|
||||||
|
"""获取工具的类别信息"""
|
||||||
|
return _TOOL_CATEGORY_MAP.get(
|
||||||
|
name,
|
||||||
|
(ToolCategory.EXTERNAL, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_tags_from_docstring(docstring: str | None) -> list[str]:
|
||||||
|
"""从 docstring 推断工具标签"""
|
||||||
|
if not docstring:
|
||||||
|
return []
|
||||||
|
tags = []
|
||||||
|
doc_lower = docstring.lower()
|
||||||
|
if "搜索" in docstring or "查询" in docstring or "search" in doc_lower:
|
||||||
|
tags.append("search")
|
||||||
|
if "创建" in docstring or "新建" in docstring or "create" in doc_lower:
|
||||||
|
tags.append("create")
|
||||||
|
if "获取" in docstring or "读取" in docstring or "get" in doc_lower:
|
||||||
|
tags.append("read")
|
||||||
|
if "更新" in docstring or "修改" in docstring or "update" in doc_lower:
|
||||||
|
tags.append("update")
|
||||||
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_tool(tool_func: Callable) -> Callable:
|
||||||
|
"""将现有 @tool 装饰的函数迁移到 ToolRegistry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_func: LangChain @tool 装饰的函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
原函数(已注册到 registry)
|
||||||
|
"""
|
||||||
|
registry = get_tool_registry()
|
||||||
|
|
||||||
|
# 如果已经注册,跳过
|
||||||
|
if registry.get(tool_func.name):
|
||||||
|
return tool_func
|
||||||
|
|
||||||
|
# 获取类别信息
|
||||||
|
category, permission, side_effect = get_tool_category(tool_func.name)
|
||||||
|
|
||||||
|
# 从 docstring 提取 description
|
||||||
|
description = tool_func.description if hasattr(tool_func, "description") else ""
|
||||||
|
|
||||||
|
# 推断 tags
|
||||||
|
tags = infer_tags_from_docstring(description)
|
||||||
|
tags.append("migrated")
|
||||||
|
|
||||||
|
# 创建 manifest
|
||||||
|
manifest = ToolManifest(
|
||||||
|
name=tool_func.name,
|
||||||
|
description=description,
|
||||||
|
category=category,
|
||||||
|
parameters={}, # LangChain @tool 动态处理参数
|
||||||
|
return_schema={},
|
||||||
|
permission_class=permission,
|
||||||
|
side_effect_scope=side_effect,
|
||||||
|
requires_confirmation=side_effect != SideEffectScope.NONE,
|
||||||
|
is_streaming=False,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册到 registry
|
||||||
|
registry.register(manifest, tool_func)
|
||||||
|
|
||||||
|
return tool_func
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_all_tools() -> int:
|
||||||
|
"""迁移所有现有工具到 ToolRegistry
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
迁移的工具数量
|
||||||
|
"""
|
||||||
|
from app.agents.tools import (
|
||||||
|
ALL_TOOLS,
|
||||||
|
KNOWLEDGE_GRAPH_TOOLS,
|
||||||
|
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||||
|
SCHEDULE_READ_TOOLS,
|
||||||
|
SCHEDULE_WRITE_TOOLS,
|
||||||
|
TASK_TOOLS,
|
||||||
|
FORUM_TOOLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_tools = (
|
||||||
|
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||||
|
+ KNOWLEDGE_GRAPH_TOOLS
|
||||||
|
+ TASK_TOOLS
|
||||||
|
+ SCHEDULE_READ_TOOLS
|
||||||
|
+ SCHEDULE_WRITE_TOOLS
|
||||||
|
+ FORUM_TOOLS
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for tool in all_tools:
|
||||||
|
try:
|
||||||
|
migrate_tool(tool)
|
||||||
|
count += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to migrate tool {getattr(tool, 'name', 'unknown')}: {e}")
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
class BackwardCompatTool:
|
||||||
|
"""向后兼容工具包装器
|
||||||
|
|
||||||
|
确保现有代码通过 registry.get_executor() 仍能正常调用工具。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self._registry = get_tool_registry()
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
executor = self._registry.get_executor(self.name)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||||
|
return executor(*args, **kwargs)
|
||||||
|
|
||||||
|
def invoke(self, tool_input: dict[str, Any]) -> Any:
|
||||||
|
"""LangChain 风格的 invoke 调用"""
|
||||||
|
executor = self._registry.get_executor(self.name)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||||
|
|
||||||
|
# 处理位置参数
|
||||||
|
if isinstance(tool_input, dict):
|
||||||
|
return executor(**tool_input)
|
||||||
|
return executor(tool_input)
|
||||||
|
|
||||||
|
|
||||||
|
def create_compat_layer() -> dict[str, BackwardCompatTool]:
|
||||||
|
"""创建向后兼容层
|
||||||
|
|
||||||
|
返回一个字典,允许通过名称访问兼容的工具包装器。
|
||||||
|
"""
|
||||||
|
registry = get_tool_registry()
|
||||||
|
tools = registry.list_all()
|
||||||
|
|
||||||
|
return {tool.name: BackwardCompatTool(tool.name) for tool in tools}
|
||||||
|
|
||||||
|
|
||||||
|
# 自动迁移装饰器
|
||||||
|
def auto_migrate(func: Callable) -> Callable:
|
||||||
|
"""自动迁移装饰器
|
||||||
|
|
||||||
|
用于装饰新的 @tool 函数,自动注册到 registry。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
# 迁移到 registry
|
||||||
|
migrate_tool(wrapper)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
# 便捷函数:获取兼容的工具执行器
|
||||||
|
def get_tool_executor(name: str) -> Callable | None:
|
||||||
|
"""获取工具执行器(兼容层)
|
||||||
|
|
||||||
|
优先从 registry 获取,fallback 到直接导入。
|
||||||
|
"""
|
||||||
|
registry = get_tool_registry()
|
||||||
|
executor = registry.get_executor(name)
|
||||||
|
|
||||||
|
if executor is not None:
|
||||||
|
return executor
|
||||||
|
|
||||||
|
# Fallback: 直接从模块导入(仅用于迁移期间)
|
||||||
|
try:
|
||||||
|
from app.agents.tools import (
|
||||||
|
TASK_TOOLS,
|
||||||
|
SCHEDULE_READ_TOOLS,
|
||||||
|
SCHEDULE_WRITE_TOOLS,
|
||||||
|
FORUM_TOOLS,
|
||||||
|
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_tools = (
|
||||||
|
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||||
|
+ TASK_TOOLS
|
||||||
|
+ SCHEDULE_READ_TOOLS
|
||||||
|
+ SCHEDULE_WRITE_TOOLS
|
||||||
|
+ FORUM_TOOLS
|
||||||
|
)
|
||||||
|
|
||||||
|
for tool in all_tools:
|
||||||
|
if hasattr(tool, "name") and tool.name == name:
|
||||||
|
return tool
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
206
backend/app/agents/tools/registry.py
Normal file
206
backend/app/agents/tools/registry.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""工具注册表 - 工具系统重构 Phase 6.1"""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from app.agents.tools.manifest import HookConfig, ToolManifest
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""工具注册表
|
||||||
|
|
||||||
|
统一管理所有工具的注册、发现和调用。
|
||||||
|
支持工具元数据、权限分类、Hook 拦截。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tools: dict[str, ToolManifest] = {}
|
||||||
|
self._executors: dict[str, Callable] = {}
|
||||||
|
self._hooks: dict[str, list[HookConfig]] = defaultdict(list)
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self, manifest: ToolManifest, executor: Callable, hooks: list[HookConfig] | None = None
|
||||||
|
) -> None:
|
||||||
|
"""注册工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
manifest: 工具元数据
|
||||||
|
executor: 工具执行函数
|
||||||
|
hooks: 可选的 Hook 配置列表
|
||||||
|
"""
|
||||||
|
if manifest.name in self._tools:
|
||||||
|
raise ValueError(f"Tool already registered: {manifest.name}")
|
||||||
|
|
||||||
|
self._tools[manifest.name] = manifest
|
||||||
|
self._executors[manifest.name] = executor
|
||||||
|
|
||||||
|
if hooks:
|
||||||
|
for hook in hooks:
|
||||||
|
self._hooks[manifest.name].append(hook)
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> bool:
|
||||||
|
"""注销工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功注销
|
||||||
|
"""
|
||||||
|
if name not in self._tools:
|
||||||
|
return False
|
||||||
|
|
||||||
|
del self._tools[name]
|
||||||
|
del self._executors[name]
|
||||||
|
if name in self._hooks:
|
||||||
|
del self._hooks[name]
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get(self, name: str) -> ToolManifest | None:
|
||||||
|
"""获取工具元数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具元数据,不存在返回 None
|
||||||
|
"""
|
||||||
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
def get_executor(self, name: str) -> Callable | None:
|
||||||
|
"""获取工具执行器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具执行函数,不存在返回 None
|
||||||
|
"""
|
||||||
|
return self._executors.get(name)
|
||||||
|
|
||||||
|
def get_hooks(self, name: str) -> list[HookConfig]:
|
||||||
|
"""获取工具的 Hook 配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hook 配置列表
|
||||||
|
"""
|
||||||
|
return self._hooks.get(name, [])
|
||||||
|
|
||||||
|
def list_all(self) -> list[ToolManifest]:
|
||||||
|
"""列出所有已注册的工具
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具元数据列表
|
||||||
|
"""
|
||||||
|
return list(self._tools.values())
|
||||||
|
|
||||||
|
def list_by_category(self, category: Any) -> list[ToolManifest]:
|
||||||
|
"""按类别列出工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: 工具类别
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
该类别下的所有工具
|
||||||
|
"""
|
||||||
|
return [t for t in self._tools.values() if t.category == category]
|
||||||
|
|
||||||
|
def list_by_permission(self, permission: Any) -> list[ToolManifest]:
|
||||||
|
"""按权限级别列出工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
permission: 权限级别
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
该权限级别下的所有工具
|
||||||
|
"""
|
||||||
|
return [t for t in self._tools.values() if t.permission_class == permission]
|
||||||
|
|
||||||
|
def search_by_tag(self, tag: str) -> list[ToolManifest]:
|
||||||
|
"""按标签搜索工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: 标签
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含该标签的工具
|
||||||
|
"""
|
||||||
|
return [t for t in self._tools.values() if tag in t.tags]
|
||||||
|
|
||||||
|
def search_by_name(self, keyword: str) -> list[ToolManifest]:
|
||||||
|
"""按名称关键词搜索工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keyword: 关键词
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
名称包含关键词的工具
|
||||||
|
"""
|
||||||
|
keyword = keyword.lower()
|
||||||
|
return [t for t in self._tools.values() if keyword in t.name.lower()]
|
||||||
|
|
||||||
|
def get_requires_confirmation(self, name: str) -> bool:
|
||||||
|
"""检查工具是否需要确认
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否需要确认
|
||||||
|
"""
|
||||||
|
manifest = self._tools.get(name)
|
||||||
|
return manifest.requires_confirmation if manifest else False
|
||||||
|
|
||||||
|
def get_is_streaming(self, name: str) -> bool:
|
||||||
|
"""检查工具是否支持流式执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 工具名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否支持流式
|
||||||
|
"""
|
||||||
|
manifest = self._tools.get(name)
|
||||||
|
return manifest.is_streaming if manifest else False
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清空注册表"""
|
||||||
|
self._tools.clear()
|
||||||
|
self._executors.clear()
|
||||||
|
self._hooks.clear()
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._tools)
|
||||||
|
|
||||||
|
def __contains__(self, name: str) -> bool:
|
||||||
|
return name in self._tools
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._tools.values())
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例实例
|
||||||
|
_global_registry: ToolRegistry | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_registry() -> ToolRegistry:
|
||||||
|
"""获取全局工具注册表单例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
全局 ToolRegistry 实例
|
||||||
|
"""
|
||||||
|
global _global_registry
|
||||||
|
if _global_registry is None:
|
||||||
|
_global_registry = ToolRegistry()
|
||||||
|
return _global_registry
|
||||||
|
|
||||||
|
|
||||||
|
def reset_tool_registry() -> None:
|
||||||
|
"""重置全局工具注册表(用于测试)"""
|
||||||
|
global _global_registry
|
||||||
|
if _global_registry is not None:
|
||||||
|
_global_registry.clear()
|
||||||
|
_global_registry = None
|
||||||
@@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
@@ -11,21 +9,16 @@ from langchain_core.tools import tool
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.agents.context import get_current_user
|
from app.agents.context import get_current_user
|
||||||
|
from app.agents.tools.async_bridge import run_async
|
||||||
from app.database import async_session
|
from app.database import async_session
|
||||||
from app.models.goal import Goal, GoalStatus
|
from app.models.goal import Goal, GoalStatus
|
||||||
from app.models.reminder import Reminder
|
from app.models.reminder import Reminder
|
||||||
from app.models.task import Task, TaskPriority, TaskStatus
|
from app.models.task import Task, TaskPriority, TaskStatus
|
||||||
from app.models.todo import DailyTodo, TodoSource
|
from app.models.todo import DailyTodo, TodoSource
|
||||||
|
|
||||||
_executor = ThreadPoolExecutor(max_workers=4)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_async(coro, timeout: int = 30):
|
def _run_async(coro, timeout: int = 30):
|
||||||
try:
|
return run_async(coro, timeout=timeout)
|
||||||
asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
return asyncio.run(coro)
|
|
||||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_date(value: str | None) -> date:
|
def _parse_date(value: str | None) -> date:
|
||||||
|
|||||||
@@ -5,25 +5,16 @@ Agent 工具集 - 知识库 & 图谱相关
|
|||||||
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
|
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.agents.context import get_current_user
|
from app.agents.context import get_current_user
|
||||||
|
from app.agents.tools.async_bridge import run_async
|
||||||
from app.database import async_session
|
from app.database import async_session
|
||||||
|
|
||||||
_executor = ThreadPoolExecutor(max_workers=4)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_async(coro, timeout: int = 30):
|
def _run_async(coro, timeout: int = 30):
|
||||||
"""在同步上下文中运行 async 代码"""
|
"""在同步上下文中运行 async 代码"""
|
||||||
try:
|
return run_async(coro, timeout=timeout)
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
future = loop.run_in_executor(_executor, lambda: asyncio.run(coro))
|
|
||||||
return future.result(timeout=timeout)
|
|
||||||
except RuntimeError:
|
|
||||||
return asyncio.run(coro)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|||||||
210
backend/app/agents/tools/streaming.py
Normal file
210
backend/app/agents/tools/streaming.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""流式工具执行器 - Phase 6.3
|
||||||
|
|
||||||
|
支持流式输出的工具执行器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
from app.agents.tools.hooks.executor import get_hook_executor
|
||||||
|
from app.agents.tools.hooks.types import ExecutionContext
|
||||||
|
from app.agents.tools.registry import get_tool_registry
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingToolExecutor:
|
||||||
|
"""流式工具执行器
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- 普通工具的同步/异步执行
|
||||||
|
- 流式工具的流式输出
|
||||||
|
- Hook 拦截(pre/post/error)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._registry = get_tool_registry()
|
||||||
|
self._hook_executor = get_hook_executor()
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: dict[str, Any],
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""执行工具(非流式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
tool_input: 工具输入参数
|
||||||
|
user_id: 用户 ID(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具执行结果
|
||||||
|
"""
|
||||||
|
# 创建执行上下文
|
||||||
|
context = ExecutionContext(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取工具和执行器
|
||||||
|
manifest = self._registry.get(tool_name)
|
||||||
|
if manifest is None:
|
||||||
|
raise ValueError(f"Tool not found: {tool_name}")
|
||||||
|
|
||||||
|
executor = self._registry.get_executor(tool_name)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||||
|
|
||||||
|
# 检查是否跳过
|
||||||
|
if await self._hook_executor.execute_skip_check(context):
|
||||||
|
return {"skipped": True, "tool": tool_name}
|
||||||
|
|
||||||
|
# 执行 pre-hooks
|
||||||
|
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||||
|
if not continue_execution:
|
||||||
|
return {"pre_hook_aborted": True, "tool": tool_name}
|
||||||
|
|
||||||
|
# 执行工具
|
||||||
|
try:
|
||||||
|
context.start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
# 判断是同步还是异步执行
|
||||||
|
if asyncio.iscoroutinefunction(executor):
|
||||||
|
result = await executor(**modified_input)
|
||||||
|
else:
|
||||||
|
# 同步函数在线程池中执行
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = await loop.run_in_executor(None, lambda: executor(**modified_input))
|
||||||
|
|
||||||
|
context.result = result
|
||||||
|
context.end_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
# 执行 post-hooks
|
||||||
|
result = await self._hook_executor.execute_post_hooks(context, result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
context.error = e
|
||||||
|
context.end_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
# 执行 error-hooks
|
||||||
|
error_result = await self._hook_executor.execute_error_hooks(context, e)
|
||||||
|
if error_result is not None:
|
||||||
|
return error_result
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def execute_streaming(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: dict[str, Any],
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
|
"""执行流式工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
tool_input: 工具输入参数
|
||||||
|
user_id: 用户 ID(可选)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
流式输出片段
|
||||||
|
"""
|
||||||
|
# 创建执行上下文
|
||||||
|
context = ExecutionContext(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取工具和执行器
|
||||||
|
manifest = self._registry.get(tool_name)
|
||||||
|
if manifest is None:
|
||||||
|
raise ValueError(f"Tool not found: {tool_name}")
|
||||||
|
|
||||||
|
if not manifest.is_streaming:
|
||||||
|
raise ValueError(f"Tool is not streaming: {tool_name}")
|
||||||
|
|
||||||
|
executor = self._registry.get_executor(tool_name)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"Executor not found for tool: {tool_name}")
|
||||||
|
|
||||||
|
# 检查是否跳过
|
||||||
|
if await self._hook_executor.execute_skip_check(context):
|
||||||
|
yield {"type": "skipped", "tool": tool_name}
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行 pre-hooks
|
||||||
|
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
||||||
|
if not continue_execution:
|
||||||
|
yield {"type": "pre_hook_aborted", "tool": tool_name}
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行流式工具
|
||||||
|
try:
|
||||||
|
context.start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
# 调用执行器(应该返回 AsyncGenerator)
|
||||||
|
result = executor(**modified_input)
|
||||||
|
|
||||||
|
# 如果是 async generator
|
||||||
|
if asyncio.isasyncgen(result):
|
||||||
|
async for chunk in result:
|
||||||
|
yield {"type": "chunk", "data": chunk}
|
||||||
|
else:
|
||||||
|
# 普通协程
|
||||||
|
data = await result
|
||||||
|
yield {"type": "chunk", "data": data}
|
||||||
|
|
||||||
|
context.end_time = asyncio.get_event_loop().time()
|
||||||
|
yield {"type": "done", "tool": tool_name}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
context.error = e
|
||||||
|
context.end_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
yield {"type": "error", "error": str(e), "tool": tool_name}
|
||||||
|
|
||||||
|
# 执行 error-hooks
|
||||||
|
await self._hook_executor.execute_error_hooks(context, e)
|
||||||
|
|
||||||
|
async def execute_batch(
|
||||||
|
self,
|
||||||
|
tool_calls: list[dict[str, Any]],
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""批量执行工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input
|
||||||
|
user_id: 用户 ID(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果列表
|
||||||
|
"""
|
||||||
|
tasks = []
|
||||||
|
for call in tool_calls:
|
||||||
|
tool_name = call.get("tool_name") or call.get("name")
|
||||||
|
tool_input = call.get("tool_input") or call.get("input") or {}
|
||||||
|
tasks.append(self.execute(tool_name, tool_input, user_id))
|
||||||
|
|
||||||
|
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_executor: StreamingToolExecutor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_streaming_executor() -> StreamingToolExecutor:
|
||||||
|
"""获取全局流式执行器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
全局 StreamingToolExecutor 实例
|
||||||
|
"""
|
||||||
|
global _executor
|
||||||
|
if _executor is None:
|
||||||
|
_executor = StreamingToolExecutor()
|
||||||
|
return _executor
|
||||||
@@ -8,21 +8,13 @@ from langchain_core.tools import tool
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.agents.context import get_current_user
|
from app.agents.context import get_current_user
|
||||||
|
from app.agents.tools.async_bridge import run_async
|
||||||
from app.database import async_session
|
from app.database import async_session
|
||||||
from app.models.task import Task, TaskPriority, TaskStatus
|
from app.models.task import Task, TaskPriority, TaskStatus
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
_executor = ThreadPoolExecutor(max_workers=4)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_async(coro, timeout: int = 30):
|
def _run_async(coro, timeout: int = 30):
|
||||||
try:
|
return run_async(coro, timeout=timeout)
|
||||||
asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
return asyncio.run(coro)
|
|
||||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||||
|
|||||||
@@ -241,6 +241,10 @@ def normalize_tool_time_arguments(tool_name: str, args: dict, current_datetime_c
|
|||||||
if raw_value and not _is_iso_datetime(raw_value):
|
if raw_value and not _is_iso_datetime(raw_value):
|
||||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
|
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
|
||||||
normalized["reminder_at"] = payload["resolved_datetime"]
|
normalized["reminder_at"] = payload["resolved_datetime"]
|
||||||
|
raw_date = normalized.get("date")
|
||||||
|
if isinstance(raw_date, str) and raw_date.strip() and not _is_iso_date(raw_date):
|
||||||
|
payload = resolve_time_expression_data(raw_date, current_datetime_context=current_datetime_context, prefer="date")
|
||||||
|
normalized["date"] = payload["resolved_date"]
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
if tool_name in {"create_schedule_task", "create_task"}:
|
if tool_name in {"create_schedule_task", "create_task"}:
|
||||||
|
|||||||
113
backend/app/agents/transport/remote.py
Normal file
113
backend/app/agents/transport/remote.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""远程传输层 - Phase 10.2"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StructuredMessage:
|
||||||
|
"""结构化消息"""
|
||||||
|
|
||||||
|
type: str # response, event, tool_call, error
|
||||||
|
data: dict[str, Any]
|
||||||
|
session_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteTransport:
|
||||||
|
"""远程传输层
|
||||||
|
|
||||||
|
处理与远程 Agent 的通信。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._connections: dict[str, Any] = {}
|
||||||
|
self._handlers: dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def send_response(self, session_id: str, response: dict[str, Any]) -> bool:
|
||||||
|
"""发送响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
response: 响应数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否发送成功
|
||||||
|
"""
|
||||||
|
message = StructuredMessage(
|
||||||
|
type="response",
|
||||||
|
data=response,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
return await self._send(session_id, message)
|
||||||
|
|
||||||
|
async def send_event(self, session_id: str, event: dict[str, Any]) -> bool:
|
||||||
|
"""发送事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
event: 事件数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否发送成功
|
||||||
|
"""
|
||||||
|
message = StructuredMessage(
|
||||||
|
type="event",
|
||||||
|
data=event,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
return await self._send(session_id, message)
|
||||||
|
|
||||||
|
async def send_tool_call(self, session_id: str, tool_call: dict[str, Any]) -> bool:
|
||||||
|
"""发送工具调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
tool_call: 工具调用数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否发送成功
|
||||||
|
"""
|
||||||
|
message = StructuredMessage(
|
||||||
|
type="tool_call",
|
||||||
|
data=tool_call,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
return await self._send(session_id, message)
|
||||||
|
|
||||||
|
async def _send(self, session_id: str, message: StructuredMessage) -> bool:
|
||||||
|
"""内部发送方法"""
|
||||||
|
if session_id not in self._connections:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
connection = self._connections[session_id]
|
||||||
|
if hasattr(connection, "send"):
|
||||||
|
await connection.send(json.dumps(message.__dict__))
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def register_handler(self, event_type: str, handler: Any) -> None:
|
||||||
|
"""注册消息处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: 事件类型
|
||||||
|
handler: 处理函数
|
||||||
|
"""
|
||||||
|
self._handlers[event_type] = handler
|
||||||
|
|
||||||
|
async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
|
||||||
|
"""处理收到的消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
message: 消息数据
|
||||||
|
"""
|
||||||
|
msg_type = message.get("type")
|
||||||
|
handler = self._handlers.get(msg_type)
|
||||||
|
if handler:
|
||||||
|
await handler(session_id, message.get("data"))
|
||||||
86
backend/app/agents/transport/structured_io.py
Normal file
86
backend/app/agents/transport/structured_io.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Structured IO for typed input/output - Phase 10.2"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StructuredInput:
|
||||||
|
"""Structured input wrapper"""
|
||||||
|
|
||||||
|
skill_name: str
|
||||||
|
parameters: dict[str, Any]
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StructuredOutput:
|
||||||
|
"""Structured output wrapper"""
|
||||||
|
|
||||||
|
skill_name: str
|
||||||
|
result: Any
|
||||||
|
success: bool
|
||||||
|
error: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredIO:
|
||||||
|
"""Handles structured input/output for agent communication"""
|
||||||
|
|
||||||
|
def parse_input(self, data: dict[str, Any]) -> StructuredInput:
|
||||||
|
"""Parse structured input from dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary containing skill_name, parameters, and metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StructuredInput instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required fields are missing
|
||||||
|
"""
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError("Input data must be a dictionary")
|
||||||
|
|
||||||
|
skill_name = data.get("skill_name")
|
||||||
|
if not skill_name:
|
||||||
|
raise ValueError("Missing required field: skill_name")
|
||||||
|
if not isinstance(skill_name, str):
|
||||||
|
raise ValueError("skill_name must be a string")
|
||||||
|
|
||||||
|
parameters = data.get("parameters")
|
||||||
|
if parameters is None:
|
||||||
|
raise ValueError("Missing required field: parameters")
|
||||||
|
if not isinstance(parameters, dict):
|
||||||
|
raise ValueError("parameters must be a dictionary")
|
||||||
|
|
||||||
|
metadata = data.get("metadata", {})
|
||||||
|
if not isinstance(metadata, dict):
|
||||||
|
raise ValueError("metadata must be a dictionary")
|
||||||
|
|
||||||
|
return StructuredInput(skill_name=skill_name, parameters=parameters, metadata=metadata)
|
||||||
|
|
||||||
|
def format_output(self, output: StructuredOutput) -> dict[str, Any]:
|
||||||
|
"""Format structured output to dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: StructuredOutput instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation of the output
|
||||||
|
"""
|
||||||
|
result = {
|
||||||
|
"skill_name": output.skill_name,
|
||||||
|
"result": output.result,
|
||||||
|
"success": output.success,
|
||||||
|
}
|
||||||
|
|
||||||
|
if output.error is not None:
|
||||||
|
result["error"] = output.error
|
||||||
|
|
||||||
|
if output.metadata is not None:
|
||||||
|
result["metadata"] = output.metadata
|
||||||
|
|
||||||
|
return result
|
||||||
207
backend/app/agents/transport/websocket.py
Normal file
207
backend/app/agents/transport/websocket.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""WebSocket 连接管理 - Phase 10.2
|
||||||
|
|
||||||
|
管理 WebSocket 连接的生命周期。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WSConnection:
|
||||||
|
"""WebSocket 连接"""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
websocket: Any # WebSocket 连接
|
||||||
|
user_id: str | None = None
|
||||||
|
created_at: float | None = None
|
||||||
|
last_ping: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketManager:
|
||||||
|
"""WebSocket 连接管理器
|
||||||
|
|
||||||
|
管理所有 WebSocket 连接的生命周期。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ping_interval: float = 30.0):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
ping_interval: 心跳间隔(秒)
|
||||||
|
"""
|
||||||
|
self._connections: dict[str, WSConnection] = {}
|
||||||
|
self._handlers: dict[str, Callable] = {}
|
||||||
|
self._ping_interval = ping_interval
|
||||||
|
self._ping_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
async def connect(self, session_id: str, websocket: Any, user_id: str | None = None) -> bool:
|
||||||
|
"""建立连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
websocket: WebSocket 连接
|
||||||
|
user_id: 用户 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否连接成功
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
if session_id in self._connections:
|
||||||
|
return False
|
||||||
|
|
||||||
|
conn = WSConnection(
|
||||||
|
session_id=session_id,
|
||||||
|
websocket=websocket,
|
||||||
|
user_id=user_id,
|
||||||
|
created_at=time.time(),
|
||||||
|
last_ping=time.time(),
|
||||||
|
)
|
||||||
|
self._connections[session_id] = conn
|
||||||
|
|
||||||
|
# 启动心跳
|
||||||
|
self._ping_tasks[session_id] = asyncio.create_task(self._ping_loop(session_id))
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def disconnect(self, session_id: str) -> bool:
|
||||||
|
"""断开连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否断开成功
|
||||||
|
"""
|
||||||
|
if session_id not in self._connections:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 停止心跳
|
||||||
|
if session_id in self._ping_tasks:
|
||||||
|
self._ping_tasks[session_id].cancel()
|
||||||
|
del self._ping_tasks[session_id]
|
||||||
|
|
||||||
|
del self._connections[session_id]
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def send(self, session_id: str, message: dict[str, Any]) -> bool:
|
||||||
|
"""发送消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
message: 消息内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否发送成功
|
||||||
|
"""
|
||||||
|
if session_id not in self._connections:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = self._connections[session_id]
|
||||||
|
await conn.websocket.send_json(message)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def broadcast(self, message: dict[str, Any]) -> int:
|
||||||
|
"""广播消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 消息内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
发送成功的数量
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
for session_id in list(self._connections.keys()):
|
||||||
|
if await self.send(session_id, message):
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def _ping_loop(self, session_id: str) -> None:
|
||||||
|
"""心跳循环
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
while session_id in self._connections:
|
||||||
|
await asyncio.sleep(self._ping_interval)
|
||||||
|
|
||||||
|
if session_id not in self._connections:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = self._connections[session_id]
|
||||||
|
await conn.websocket.send_json({"type": "ping", "timestamp": time.time()})
|
||||||
|
conn.last_ping = time.time()
|
||||||
|
except Exception:
|
||||||
|
await self.disconnect(session_id)
|
||||||
|
break
|
||||||
|
|
||||||
|
def register_handler(self, event_type: str, handler: Callable) -> None:
|
||||||
|
"""注册消息处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: 事件类型
|
||||||
|
handler: 处理函数
|
||||||
|
"""
|
||||||
|
self._handlers[event_type] = handler
|
||||||
|
|
||||||
|
async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
|
||||||
|
"""处理消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
message: 消息内容
|
||||||
|
"""
|
||||||
|
msg_type = message.get("type")
|
||||||
|
handler = self._handlers.get(msg_type)
|
||||||
|
if handler:
|
||||||
|
await handler(session_id, message.get("data"))
|
||||||
|
|
||||||
|
def get_connection(self, session_id: str) -> WSConnection | None:
|
||||||
|
"""获取连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
连接信息或 None
|
||||||
|
"""
|
||||||
|
return self._connections.get(session_id)
|
||||||
|
|
||||||
|
def list_connections(self) -> list[WSConnection]:
|
||||||
|
"""列出所有连接
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
连接列表
|
||||||
|
"""
|
||||||
|
return list(self._connections.values())
|
||||||
|
|
||||||
|
def is_connected(self, session_id: str) -> bool:
|
||||||
|
"""检查是否连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否已连接
|
||||||
|
"""
|
||||||
|
return session_id in self._connections
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_ws_manager: WebSocketManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_websocket_manager() -> WebSocketManager:
|
||||||
|
"""获取全局 WebSocket 管理器"""
|
||||||
|
global _ws_manager
|
||||||
|
if _ws_manager is None:
|
||||||
|
_ws_manager = WebSocketManager()
|
||||||
|
return _ws_manager
|
||||||
93
backend/app/agents/verifier.py
Normal file
93
backend/app/agents/verifier.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.agents.schemas.task import AgentTask, TaskResult, TaskResultStatus, VerificationStatus
|
||||||
|
from app.agents.state import AgentState
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationVerdict(BaseModel):
|
||||||
|
status: VerificationStatus
|
||||||
|
summary: str | None = None
|
||||||
|
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_task_result(
|
||||||
|
task_result: TaskResult | dict[str, Any],
|
||||||
|
*,
|
||||||
|
default_task_id: str | None = None,
|
||||||
|
) -> TaskResult:
|
||||||
|
payload = task_result.model_dump(mode="json") if isinstance(task_result, TaskResult) else dict(task_result or {})
|
||||||
|
normalized_status = payload.get("status")
|
||||||
|
if normalized_status not in {"completed", "failed", "blocked", "passed", "skipped"}:
|
||||||
|
normalized_status = "failed"
|
||||||
|
return TaskResult(
|
||||||
|
task_id=str(payload.get("task_id") or default_task_id or "unknown-task"),
|
||||||
|
status=cast(TaskResultStatus, normalized_status),
|
||||||
|
summary=payload.get("summary"),
|
||||||
|
evidence=list(payload.get("evidence") or []),
|
||||||
|
owner_agent_id=payload.get("owner_agent_id"),
|
||||||
|
parent_task_id=payload.get("parent_task_id"),
|
||||||
|
child_task_ids=list(payload.get("child_task_ids") or []),
|
||||||
|
thread_id=payload.get("thread_id"),
|
||||||
|
message_id=payload.get("message_id"),
|
||||||
|
message_index=payload.get("message_index") if isinstance(payload.get("message_index"), int) else None,
|
||||||
|
interrupt_records=list(payload.get("interrupt_records") or []),
|
||||||
|
recovery_records=list(payload.get("recovery_records") or []),
|
||||||
|
budget_snapshot=payload.get("budget_snapshot") if isinstance(payload.get("budget_snapshot"), dict) else None,
|
||||||
|
next_action=payload.get("next_action"),
|
||||||
|
output_data=payload.get("output_data") if isinstance(payload.get("output_data"), dict) else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_task_result(
|
||||||
|
*,
|
||||||
|
task: AgentTask | dict[str, Any] | None = None,
|
||||||
|
result: TaskResult | dict[str, Any] | None = None,
|
||||||
|
summary: str | None = None,
|
||||||
|
evidence: list[dict[str, Any]] | None = None,
|
||||||
|
status: VerificationStatus | None = None,
|
||||||
|
) -> VerificationVerdict:
|
||||||
|
normalized_result = result.model_dump() if isinstance(result, TaskResult) else dict(result or {})
|
||||||
|
normalized_task = task.model_dump() if isinstance(task, AgentTask) else dict(task or {})
|
||||||
|
normalized_summary = summary or normalized_result.get("summary") or normalized_task.get("result_summary")
|
||||||
|
normalized_evidence = list(evidence or normalized_result.get("evidence") or normalized_task.get("evidence") or [])
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
return VerificationVerdict(status=status, summary=normalized_summary, evidence=normalized_evidence)
|
||||||
|
|
||||||
|
normalized_status = normalized_result.get("status")
|
||||||
|
if normalized_status in {"passed", "failed", "skipped"}:
|
||||||
|
inferred_status = normalized_status
|
||||||
|
elif normalized_status == "completed":
|
||||||
|
inferred_status = "passed"
|
||||||
|
elif normalized_status == "blocked":
|
||||||
|
inferred_status = "skipped"
|
||||||
|
elif normalized_result.get("success") is True:
|
||||||
|
inferred_status = "passed"
|
||||||
|
elif normalized_result.get("success") is False:
|
||||||
|
inferred_status = "failed"
|
||||||
|
elif normalized_summary or normalized_evidence:
|
||||||
|
inferred_status = "skipped"
|
||||||
|
else:
|
||||||
|
inferred_status = "failed"
|
||||||
|
normalized_summary = "No verification input available."
|
||||||
|
|
||||||
|
return VerificationVerdict(
|
||||||
|
status=inferred_status,
|
||||||
|
summary=normalized_summary,
|
||||||
|
evidence=normalized_evidence,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_verification_verdict(state: AgentState, verdict: VerificationVerdict) -> AgentState:
|
||||||
|
next_state = dict(state)
|
||||||
|
next_state["verification_status"] = verdict.status
|
||||||
|
next_state["verification_summary"] = verdict.summary
|
||||||
|
next_state["verification_evidence"] = list(verdict.evidence)
|
||||||
|
return AgentState(**next_state)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["VerificationVerdict", "apply_verification_verdict", "normalize_task_result", "verify_task_result"]
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
from sqlalchemy import text
|
from collections.abc import AsyncGenerator
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
from app.config import settings
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||||
|
|
||||||
engine = create_async_engine(
|
engine = create_async_engine(
|
||||||
@@ -24,12 +27,9 @@ class Base(DeclarativeBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def get_db() -> AsyncSession:
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
async with async_session() as session:
|
async with async_session() as session:
|
||||||
try:
|
yield session
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def init_db():
|
async def init_db():
|
||||||
@@ -37,6 +37,7 @@ async def init_db():
|
|||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
await ensure_log_columns(conn)
|
await ensure_log_columns(conn)
|
||||||
await ensure_message_columns(conn)
|
await ensure_message_columns(conn)
|
||||||
|
await ensure_conversation_columns(conn)
|
||||||
await ensure_document_columns(conn)
|
await ensure_document_columns(conn)
|
||||||
await ensure_user_columns(conn)
|
await ensure_user_columns(conn)
|
||||||
await ensure_forum_columns(conn)
|
await ensure_forum_columns(conn)
|
||||||
@@ -79,6 +80,20 @@ async def ensure_message_columns(conn):
|
|||||||
await conn.execute(text(ddl))
|
await conn.execute(text(ddl))
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_conversation_columns(conn):
|
||||||
|
rows = await _get_table_info(conn, 'conversations')
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
columns = {row[1] for row in rows}
|
||||||
|
required_columns = {
|
||||||
|
'agent_state': "ALTER TABLE conversations ADD COLUMN agent_state JSON",
|
||||||
|
}
|
||||||
|
for column, ddl in required_columns.items():
|
||||||
|
if column not in columns:
|
||||||
|
await conn.execute(text(ddl))
|
||||||
|
|
||||||
|
|
||||||
async def ensure_document_columns(conn):
|
async def ensure_document_columns(conn):
|
||||||
result = await conn.execute(text("PRAGMA table_info(documents)"))
|
result = await conn.execute(text("PRAGMA table_info(documents)"))
|
||||||
rows = result.fetchall()
|
rows = result.fetchall()
|
||||||
|
|||||||
@@ -23,6 +23,11 @@ from app.routers import (
|
|||||||
log_router,
|
log_router,
|
||||||
system_router,
|
system_router,
|
||||||
brain_router,
|
brain_router,
|
||||||
|
hooks_router,
|
||||||
|
plugins_router,
|
||||||
|
marketplace_router,
|
||||||
|
agent_skills_router,
|
||||||
|
agent_sessions_router,
|
||||||
)
|
)
|
||||||
from app.routers.scheduler import router as scheduler_router
|
from app.routers.scheduler import router as scheduler_router
|
||||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||||
@@ -40,15 +45,15 @@ import os
|
|||||||
|
|
||||||
|
|
||||||
INSECURE_SECRET_KEYS = {
|
INSECURE_SECRET_KEYS = {
|
||||||
'change-me-in-production',
|
"change-me-in-production",
|
||||||
'change-me-to-a-random-secret-key',
|
"change-me-to-a-random-secret-key",
|
||||||
'jarvis-secret-key-change-in-production',
|
"jarvis-secret-key-change-in-production",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_startup_security() -> None:
|
def validate_startup_security() -> None:
|
||||||
if not settings.DEBUG and settings.SECRET_KEY in INSECURE_SECRET_KEYS:
|
if not settings.DEBUG and settings.SECRET_KEY in INSECURE_SECRET_KEYS:
|
||||||
raise RuntimeError('SECRET_KEY must be changed before running with DEBUG disabled')
|
raise RuntimeError("SECRET_KEY must be changed before running with DEBUG disabled")
|
||||||
|
|
||||||
|
|
||||||
async def run_startup() -> None:
|
async def run_startup() -> None:
|
||||||
@@ -117,6 +122,11 @@ app.include_router(log_router)
|
|||||||
app.include_router(system_router)
|
app.include_router(system_router)
|
||||||
app.include_router(brain_router)
|
app.include_router(brain_router)
|
||||||
app.include_router(scheduler_router)
|
app.include_router(scheduler_router)
|
||||||
|
app.include_router(hooks_router)
|
||||||
|
app.include_router(plugins_router)
|
||||||
|
app.include_router(marketplace_router)
|
||||||
|
app.include_router(agent_skills_router)
|
||||||
|
app.include_router(agent_sessions_router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/health")
|
@app.get("/api/health")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ class Conversation(BaseModel):
|
|||||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||||
title = Column(String(500), nullable=True)
|
title = Column(String(500), nullable=True)
|
||||||
message_count = Column(Integer, default=0)
|
message_count = Column(Integer, default=0)
|
||||||
|
agent_state = Column(JSON, nullable=True)
|
||||||
|
|
||||||
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
||||||
|
|
||||||
|
|||||||
@@ -15,3 +15,8 @@ from app.routers.skill import router as skill_router
|
|||||||
from app.routers.log import router as log_router
|
from app.routers.log import router as log_router
|
||||||
from app.routers.system import router as system_router
|
from app.routers.system import router as system_router
|
||||||
from app.routers.brain import router as brain_router
|
from app.routers.brain import router as brain_router
|
||||||
|
from app.routers.hooks import router as hooks_router
|
||||||
|
from app.routers.plugins import router as plugins_router
|
||||||
|
from app.routers.plugins import _marketplace_router as marketplace_router
|
||||||
|
from app.routers.agent_skills import router as agent_skills_router
|
||||||
|
from app.routers.agent_sessions import router as agent_sessions_router
|
||||||
|
|||||||
@@ -1,12 +1,42 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from datetime import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
from app.agents.registry import load_builtin_registry_indexes
|
||||||
|
from app.agents.runtime_metrics import coerce_cost_thresholds, estimate_token_cost, is_cost_budget_warning
|
||||||
from app.models.agent import Agent
|
from app.models.agent import Agent
|
||||||
|
from app.models.conversation import Conversation
|
||||||
from app.models.skill import Skill
|
from app.models.skill import Skill
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.routers.auth import get_current_user
|
from app.routers.auth import get_current_user
|
||||||
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
|
from app.schemas.agent import (
|
||||||
|
AgentConfigOut,
|
||||||
|
AgentConfigUpdate,
|
||||||
|
AgentCreate,
|
||||||
|
AgentOut,
|
||||||
|
AgentStats,
|
||||||
|
AgentVisibilityCostByAgentOut,
|
||||||
|
AgentVisibilityCostOut,
|
||||||
|
AgentVisibilityCostSummaryOut,
|
||||||
|
AgentVisibilityEvidenceOut,
|
||||||
|
AgentVisibilityEventsResponse,
|
||||||
|
AgentVisibilityEventOut,
|
||||||
|
AgentVisibilityIsolationOut,
|
||||||
|
AgentVisibilityRuntimeSummaryOut,
|
||||||
|
AgentVisibilityTaskSummaryOut,
|
||||||
|
AgentVisibilityThreadMessageOut,
|
||||||
|
AgentVisibilityThreadOut,
|
||||||
|
AgentVisibilityTopologyNodeOut,
|
||||||
|
AgentVisibilityTopologyOut,
|
||||||
|
AgentVisibilityToolGovernanceItemOut,
|
||||||
|
AgentVisibilityToolGovernanceOut,
|
||||||
|
AgentVisibilityVerifierOut,
|
||||||
|
)
|
||||||
|
from app.services.agent_service import _extract_continuity_snapshot
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/agents", tags=["Agent"])
|
router = APIRouter(prefix="/api/agents", tags=["Agent"])
|
||||||
|
|
||||||
@@ -21,6 +51,295 @@ SUB_COMMANDERS_BY_ROLE = {
|
|||||||
"librarian": ["librarian_retrieval", "librarian_graph"],
|
"librarian": ["librarian_retrieval", "librarian_graph"],
|
||||||
"analyst": ["analyst_progress", "analyst_insights"],
|
"analyst": ["analyst_progress", "analyst_insights"],
|
||||||
}
|
}
|
||||||
|
ALLOWED_AGENT_ROLES = set(DEFAULT_AGENT_ROLES) | {
|
||||||
|
role
|
||||||
|
for sub_roles in SUB_COMMANDERS_BY_ROLE.values()
|
||||||
|
for role in sub_roles
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_visibility_datetime(value: str | None) -> datetime | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail="时间参数必须是 ISO 8601 格式") from exc
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_visibility_state(
|
||||||
|
conversation_id: str,
|
||||||
|
*,
|
||||||
|
current_user: User,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Conversation).where(
|
||||||
|
Conversation.id == conversation_id,
|
||||||
|
Conversation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conversation = result.scalar_one_or_none()
|
||||||
|
if conversation is None:
|
||||||
|
raise HTTPException(status_code=404, detail="对话不存在")
|
||||||
|
snapshot = _extract_continuity_snapshot(conversation.agent_state)
|
||||||
|
if snapshot is None:
|
||||||
|
raise HTTPException(status_code=404, detail="当前会话暂无可视化运行时数据")
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_event_payload(event: dict[str, Any]) -> AgentVisibilityEventOut:
|
||||||
|
return AgentVisibilityEventOut.model_validate(event)
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_events(
|
||||||
|
events: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
agent_id: str | None,
|
||||||
|
thread_id: str | None,
|
||||||
|
event_type: str | None,
|
||||||
|
started_after: datetime | None,
|
||||||
|
ended_before: datetime | None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
filtered: list[dict[str, Any]] = []
|
||||||
|
for event in events:
|
||||||
|
if agent_id and event.get("agent_id") != agent_id:
|
||||||
|
continue
|
||||||
|
if thread_id and event.get("thread_id") != thread_id:
|
||||||
|
continue
|
||||||
|
if event_type and event.get("event_type") != event_type:
|
||||||
|
continue
|
||||||
|
timestamp_raw = event.get("timestamp")
|
||||||
|
timestamp = None
|
||||||
|
if isinstance(timestamp_raw, str):
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00"))
|
||||||
|
except ValueError:
|
||||||
|
timestamp = None
|
||||||
|
if started_after and timestamp and timestamp < started_after:
|
||||||
|
continue
|
||||||
|
if ended_before and timestamp and timestamp > ended_before:
|
||||||
|
continue
|
||||||
|
filtered.append(event)
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def _summarize_tasks(tasks: list[dict[str, Any]], task_results: list[dict[str, Any]]) -> list[AgentVisibilityTaskSummaryOut]:
|
||||||
|
result_by_task_id = {item.get("task_id"): item for item in task_results}
|
||||||
|
summaries: list[AgentVisibilityTaskSummaryOut] = []
|
||||||
|
for task in tasks:
|
||||||
|
task_id = str(task.get("task_id") or "")
|
||||||
|
result = result_by_task_id.get(task_id) or {}
|
||||||
|
evidence = result.get("evidence") or task.get("evidence") or []
|
||||||
|
summaries.append(
|
||||||
|
AgentVisibilityTaskSummaryOut(
|
||||||
|
task_id=task_id,
|
||||||
|
role=task.get("role"),
|
||||||
|
owner_agent_id=task.get("owner_agent_id") or result.get("owner_agent_id"),
|
||||||
|
status=result.get("status") or task.get("status"),
|
||||||
|
summary=result.get("summary") or task.get("result_summary"),
|
||||||
|
evidence_count=len(evidence),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return summaries
|
||||||
|
|
||||||
|
|
||||||
|
def _build_topology_nodes(
|
||||||
|
state: dict[str, Any],
|
||||||
|
tasks: list[dict[str, Any]],
|
||||||
|
task_results: list[dict[str, Any]],
|
||||||
|
) -> list[AgentVisibilityTopologyNodeOut]:
|
||||||
|
task_counts: dict[str, int] = {}
|
||||||
|
completed_counts: dict[str, int] = {}
|
||||||
|
for task in tasks:
|
||||||
|
owner = str(task.get("owner_agent_id") or "")
|
||||||
|
if owner:
|
||||||
|
task_counts[owner] = task_counts.get(owner, 0) + 1
|
||||||
|
for result in task_results:
|
||||||
|
owner = str(result.get("owner_agent_id") or "")
|
||||||
|
if owner and result.get("status") == "completed":
|
||||||
|
completed_counts[owner] = completed_counts.get(owner, 0) + 1
|
||||||
|
|
||||||
|
root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None
|
||||||
|
current_agent = str(state.get("current_agent") or "") or None
|
||||||
|
parent_agent_id = str(state.get("parent_agent_id") or "") or None
|
||||||
|
nodes: dict[str, AgentVisibilityTopologyNodeOut] = {}
|
||||||
|
if root_agent_id:
|
||||||
|
nodes[root_agent_id] = AgentVisibilityTopologyNodeOut(
|
||||||
|
agent_id=root_agent_id,
|
||||||
|
role=root_agent_id.split("-")[0],
|
||||||
|
parent_agent_id=parent_agent_id if root_agent_id != state.get("agent_id") else None,
|
||||||
|
source="root",
|
||||||
|
task_count=task_counts.get(root_agent_id, 0),
|
||||||
|
completed_task_count=completed_counts.get(root_agent_id, 0),
|
||||||
|
)
|
||||||
|
for agent_id in state.get("spawned_agent_ids") or []:
|
||||||
|
agent_id = str(agent_id)
|
||||||
|
nodes[agent_id] = AgentVisibilityTopologyNodeOut(
|
||||||
|
agent_id=agent_id,
|
||||||
|
role=agent_id.split("-")[0],
|
||||||
|
parent_agent_id=root_agent_id,
|
||||||
|
source="spawned",
|
||||||
|
task_count=task_counts.get(agent_id, 0),
|
||||||
|
completed_task_count=completed_counts.get(agent_id, 0),
|
||||||
|
)
|
||||||
|
if current_agent and current_agent not in nodes:
|
||||||
|
nodes[current_agent] = AgentVisibilityTopologyNodeOut(
|
||||||
|
agent_id=current_agent,
|
||||||
|
role=current_agent.split("-")[0],
|
||||||
|
parent_agent_id=None if current_agent == root_agent_id else root_agent_id,
|
||||||
|
source="current",
|
||||||
|
task_count=task_counts.get(current_agent, 0),
|
||||||
|
completed_task_count=completed_counts.get(current_agent, 0),
|
||||||
|
)
|
||||||
|
return list(nodes.values())
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_runtime_cost(input_tokens: int, output_tokens: int) -> float | None:
|
||||||
|
return estimate_token_cost(input_tokens, output_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_cost_summary(
|
||||||
|
state: dict[str, Any],
|
||||||
|
*,
|
||||||
|
conversation_id: str,
|
||||||
|
) -> AgentVisibilityCostSummaryOut:
|
||||||
|
input_tokens = int(state.get("input_tokens") or 0)
|
||||||
|
output_tokens = int(state.get("output_tokens") or 0)
|
||||||
|
estimated_cost = _estimate_runtime_cost(input_tokens, output_tokens)
|
||||||
|
thresholds = coerce_cost_thresholds(state.get("cost_thresholds"))
|
||||||
|
total_budget_warning = bool(state.get("budget_warning") or False) or is_cost_budget_warning(
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
estimated_cost,
|
||||||
|
thresholds,
|
||||||
|
)
|
||||||
|
|
||||||
|
by_agent_items: list[AgentVisibilityCostByAgentOut] = []
|
||||||
|
for agent_id, payload in dict(state.get("cost_by_agent") or {}).items():
|
||||||
|
payload_dict = dict(payload or {})
|
||||||
|
agent_input_tokens = int(payload_dict.get("input_tokens") or 0)
|
||||||
|
agent_output_tokens = int(payload_dict.get("output_tokens") or 0)
|
||||||
|
agent_estimated_cost = payload_dict.get("estimated_cost")
|
||||||
|
if agent_estimated_cost is None:
|
||||||
|
agent_estimated_cost = _estimate_runtime_cost(agent_input_tokens, agent_output_tokens)
|
||||||
|
by_agent_items.append(
|
||||||
|
AgentVisibilityCostByAgentOut(
|
||||||
|
agent_id=str(payload_dict.get("agent_id") or agent_id),
|
||||||
|
input_tokens=agent_input_tokens,
|
||||||
|
output_tokens=agent_output_tokens,
|
||||||
|
total_tokens=int(payload_dict.get("total_tokens") or (agent_input_tokens + agent_output_tokens)),
|
||||||
|
estimated_cost=agent_estimated_cost,
|
||||||
|
budget_warning=bool(payload_dict.get("budget_warning") or False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
by_agent_items.sort(key=lambda item: item.total_tokens, reverse=True)
|
||||||
|
|
||||||
|
return AgentVisibilityCostSummaryOut(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
total=AgentVisibilityCostOut(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
total_tokens=input_tokens + output_tokens,
|
||||||
|
estimated_cost=estimated_cost,
|
||||||
|
budget_warning=total_budget_warning,
|
||||||
|
),
|
||||||
|
thresholds=thresholds,
|
||||||
|
by_agent=by_agent_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tool_governance(
|
||||||
|
state: dict[str, Any],
|
||||||
|
*,
|
||||||
|
conversation_id: str,
|
||||||
|
) -> AgentVisibilityToolGovernanceOut:
|
||||||
|
indexes = load_builtin_registry_indexes()
|
||||||
|
tool_outcomes = [dict(item) for item in state.get("tool_outcomes") or [] if isinstance(item, dict)]
|
||||||
|
usage_count_by_tool: dict[str, int] = {}
|
||||||
|
last_result_preview_by_tool: dict[str, str | None] = {}
|
||||||
|
for item in tool_outcomes:
|
||||||
|
tool_name = str(item.get("tool_name") or "")
|
||||||
|
if tool_name == "search_web":
|
||||||
|
tool_name = "web_search"
|
||||||
|
if not tool_name:
|
||||||
|
continue
|
||||||
|
usage_count_by_tool[tool_name] = usage_count_by_tool.get(tool_name, 0) + 1
|
||||||
|
preview = item.get("result_preview")
|
||||||
|
if isinstance(preview, str) and preview:
|
||||||
|
last_result_preview_by_tool[tool_name] = preview
|
||||||
|
|
||||||
|
items = [
|
||||||
|
AgentVisibilityToolGovernanceItemOut(
|
||||||
|
capability_id=capability.capability_id,
|
||||||
|
tool_name=capability.tool_name,
|
||||||
|
permission_class=capability.permission_class.value,
|
||||||
|
side_effect_scope=capability.side_effect_scope.value,
|
||||||
|
supports_retry=capability.supports_retry,
|
||||||
|
idempotent=capability.idempotent,
|
||||||
|
safe_for_parallel_use=capability.safe_for_parallel_use,
|
||||||
|
requires_confirmation=capability.requires_confirmation,
|
||||||
|
usage_count=usage_count_by_tool.get(capability.tool_name, 0),
|
||||||
|
last_result_preview=last_result_preview_by_tool.get(capability.tool_name),
|
||||||
|
)
|
||||||
|
for capability in indexes.capability_by_id.values()
|
||||||
|
]
|
||||||
|
items.sort(key=lambda item: (-item.usage_count, item.tool_name))
|
||||||
|
|
||||||
|
return AgentVisibilityToolGovernanceOut(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
total_tools=len(items),
|
||||||
|
used_tools=sum(1 for item in items if item.usage_count > 0),
|
||||||
|
items=items,
|
||||||
|
upgrade_candidates=[
|
||||||
|
"worktree_manager",
|
||||||
|
"cost_inspector",
|
||||||
|
"runtime_event_drilldown",
|
||||||
|
"tool_policy_explorer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_runtime_summary(
|
||||||
|
state: dict[str, Any],
|
||||||
|
*,
|
||||||
|
conversation_id: str,
|
||||||
|
) -> AgentVisibilityRuntimeSummaryOut:
|
||||||
|
tasks = [dict(item) for item in state.get("active_tasks") or []]
|
||||||
|
task_results = [dict(item) for item in state.get("task_results") or []]
|
||||||
|
topology_nodes = _build_topology_nodes(state, tasks, task_results)
|
||||||
|
cost_summary = _build_cost_summary(state, conversation_id=conversation_id)
|
||||||
|
input_tokens = cost_summary.total.input_tokens
|
||||||
|
output_tokens = cost_summary.total.output_tokens
|
||||||
|
recent_events_raw = [dict(item) for item in (state.get("event_trace") or [])[-10:]]
|
||||||
|
isolation_mode = str(state.get("isolation_mode") or "none")
|
||||||
|
|
||||||
|
return AgentVisibilityRuntimeSummaryOut(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
execution_mode=state.get("execution_mode"),
|
||||||
|
current_phase=state.get("current_phase"),
|
||||||
|
current_checkpoint=state.get("current_checkpoint"),
|
||||||
|
phase_history=list(state.get("phase_history") or []),
|
||||||
|
checkpoint_history=list(state.get("checkpoint_history") or []),
|
||||||
|
verifier=AgentVisibilityVerifierOut(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
status=state.get("verification_status"),
|
||||||
|
summary=state.get("verification_summary"),
|
||||||
|
evidence=list(state.get("verification_evidence") or []),
|
||||||
|
),
|
||||||
|
isolation=AgentVisibilityIsolationOut(
|
||||||
|
mode=isolation_mode,
|
||||||
|
isolation_id=state.get("isolation_id"),
|
||||||
|
workspace_path=state.get("isolation_workspace_path"),
|
||||||
|
parent_conversation_id=state.get("isolation_parent_conversation_id") or state.get("parent_conversation_id"),
|
||||||
|
metadata=dict(state.get("isolation_metadata") or {}),
|
||||||
|
),
|
||||||
|
cost=cost_summary.total,
|
||||||
|
topology_node_count=len(topology_nodes),
|
||||||
|
active_task_count=len(tasks),
|
||||||
|
completed_task_count=sum(1 for item in task_results if item.get("status") == "completed"),
|
||||||
|
recent_events=[_coerce_event_payload(item) for item in recent_events_raw],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def record_agent_call(agent_id: str):
|
def record_agent_call(agent_id: str):
|
||||||
@@ -83,6 +402,7 @@ async def get_agent_hierarchy_stats(
|
|||||||
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
|
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
|
||||||
async def get_agent_config(
|
async def get_agent_config(
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
result = await db.execute(select(Agent).where(Agent.role == agent_id))
|
result = await db.execute(select(Agent).where(Agent.role == agent_id))
|
||||||
@@ -172,12 +492,189 @@ async def update_agent_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/events", response_model=AgentVisibilityEventsResponse)
|
||||||
|
async def get_visibility_events(
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
agent_id: str | None = None,
|
||||||
|
thread_id: str | None = None,
|
||||||
|
event_type: str | None = None,
|
||||||
|
started_after: str | None = None,
|
||||||
|
ended_before: str | None = None,
|
||||||
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
offset: int = Query(default=0, ge=0),
|
||||||
|
):
|
||||||
|
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||||
|
events = [dict(item) for item in state.get("event_trace") or []]
|
||||||
|
filtered = _filter_events(
|
||||||
|
events,
|
||||||
|
agent_id=agent_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
event_type=event_type,
|
||||||
|
started_after=_parse_visibility_datetime(started_after),
|
||||||
|
ended_before=_parse_visibility_datetime(ended_before),
|
||||||
|
)
|
||||||
|
paged = filtered[offset:offset + limit]
|
||||||
|
return AgentVisibilityEventsResponse(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
total=len(filtered),
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
items=[_coerce_event_payload(item) for item in paged],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/topology", response_model=AgentVisibilityTopologyOut)
|
||||||
|
async def get_visibility_topology(
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||||
|
tasks = [dict(item) for item in state.get("active_tasks") or []]
|
||||||
|
task_results = [dict(item) for item in state.get("task_results") or []]
|
||||||
|
nodes = _build_topology_nodes(state, tasks, task_results)
|
||||||
|
root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None
|
||||||
|
edges = [
|
||||||
|
{"parent_agent_id": root_agent_id, "child_agent_id": node.agent_id}
|
||||||
|
for node in nodes
|
||||||
|
if node.parent_agent_id and root_agent_id and node.agent_id != root_agent_id
|
||||||
|
]
|
||||||
|
return AgentVisibilityTopologyOut(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
root_agent_id=root_agent_id,
|
||||||
|
current_agent=str(state.get("current_agent") or "") or None,
|
||||||
|
nodes=nodes,
|
||||||
|
edges=edges,
|
||||||
|
tasks=_summarize_tasks(tasks, task_results),
|
||||||
|
task_hierarchy=dict(state.get("task_hierarchy") or {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/tasks/{task_id}/evidence", response_model=AgentVisibilityEvidenceOut)
|
||||||
|
async def get_visibility_task_evidence(
|
||||||
|
task_id: str,
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||||
|
tasks = [dict(item) for item in state.get("active_tasks") or []]
|
||||||
|
task = next((item for item in tasks if item.get("task_id") == task_id), None)
|
||||||
|
task_results = [dict(item) for item in state.get("task_results") or []]
|
||||||
|
result = next((item for item in task_results if item.get("task_id") == task_id), None)
|
||||||
|
if task is None and result is None:
|
||||||
|
raise HTTPException(status_code=404, detail="任务不存在")
|
||||||
|
tool_outcomes = [
|
||||||
|
dict(evidence)
|
||||||
|
for evidence in (result or {}).get("evidence") or []
|
||||||
|
if isinstance(evidence, dict) and evidence.get("tool_name")
|
||||||
|
]
|
||||||
|
verification_entry = next(
|
||||||
|
(
|
||||||
|
dict(evidence)
|
||||||
|
for evidence in (result or {}).get("evidence") or []
|
||||||
|
if isinstance(evidence, dict) and evidence.get("type") == "verification"
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
verifier = {
|
||||||
|
"status": (verification_entry or {}).get("status"),
|
||||||
|
"summary": (verification_entry or {}).get("summary"),
|
||||||
|
"evidence": [dict(item) for item in state.get("verification_evidence") or [] if item.get("task_id") == task_id],
|
||||||
|
}
|
||||||
|
return AgentVisibilityEvidenceOut(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
task_id=task_id,
|
||||||
|
task=task,
|
||||||
|
result=result,
|
||||||
|
tool_outcomes=tool_outcomes,
|
||||||
|
verifier=verifier,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/threads/{thread_id}/messages", response_model=AgentVisibilityThreadOut)
|
||||||
|
async def get_visibility_thread_messages(
|
||||||
|
thread_id: str,
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||||
|
items = [
|
||||||
|
AgentVisibilityThreadMessageOut.model_validate(item)
|
||||||
|
for item in state.get("message_trace") or []
|
||||||
|
if isinstance(item, dict) and item.get("thread_id") == thread_id
|
||||||
|
]
|
||||||
|
if not items:
|
||||||
|
raise HTTPException(status_code=404, detail="线程不存在")
|
||||||
|
return AgentVisibilityThreadOut(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
total=len(items),
|
||||||
|
items=items,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/verifier", response_model=AgentVisibilityVerifierOut)
|
||||||
|
async def get_visibility_verifier(
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||||
|
return AgentVisibilityVerifierOut(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
status=state.get("verification_status"),
|
||||||
|
summary=state.get("verification_summary"),
|
||||||
|
evidence=list(state.get("verification_evidence") or []),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/runtime-summary", response_model=AgentVisibilityRuntimeSummaryOut)
|
||||||
|
async def get_visibility_runtime_summary(
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||||
|
return _build_runtime_summary(state, conversation_id=conversation_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/cost", response_model=AgentVisibilityCostSummaryOut)
|
||||||
|
async def get_visibility_cost(
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||||
|
return _build_cost_summary(state, conversation_id=conversation_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/visibility/tools", response_model=AgentVisibilityToolGovernanceOut)
|
||||||
|
async def get_visibility_tools(
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
|
||||||
|
return _build_tool_governance(state, conversation_id=conversation_id)
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=AgentOut, status_code=201)
|
@router.post("", response_model=AgentOut, status_code=201)
|
||||||
async def create_agent(
|
async def create_agent(
|
||||||
data: AgentCreate,
|
data: AgentCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
|
if not current_user.is_superuser:
|
||||||
|
raise HTTPException(status_code=403, detail="仅管理员可创建 Agent")
|
||||||
|
if not data.spawn_permission:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少 spawn_permission,禁止直接创建 runtime agent")
|
||||||
|
if data.role not in ALLOWED_AGENT_ROLES:
|
||||||
|
raise HTTPException(status_code=400, detail="不支持的 Agent 角色")
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
name=data.name,
|
name=data.name,
|
||||||
role=data.role,
|
role=data.role,
|
||||||
@@ -193,6 +690,7 @@ async def create_agent(
|
|||||||
@router.get("/{agent_id}", response_model=AgentOut)
|
@router.get("/{agent_id}", response_model=AgentOut)
|
||||||
async def get_agent(
|
async def get_agent(
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
result = await db.execute(select(Agent).where(Agent.id == agent_id))
|
result = await db.execute(select(Agent).where(Agent.id == agent_id))
|
||||||
|
|||||||
113
backend/app/routers/agent_sessions.py
Normal file
113
backend/app/routers/agent_sessions.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Agent Session API 路由 - Phase 10.3"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/agent/sessions", tags=["Agent Sessions"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=dict[str, Any])
|
||||||
|
async def create_session(
|
||||||
|
user_id: str | None = None,
|
||||||
|
parent_session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""创建新会话"""
|
||||||
|
session = create_agent_session(
|
||||||
|
user_id=user_id,
|
||||||
|
parent_session_id=parent_session_id,
|
||||||
|
)
|
||||||
|
return await session.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{session_id}", response_model=dict[str, Any])
|
||||||
|
async def get_session(session_id: str) -> dict[str, Any]:
|
||||||
|
"""获取会话信息"""
|
||||||
|
session = get_agent_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||||
|
return await session.get_session_summary()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{session_id}/message", response_model=dict[str, str])
|
||||||
|
async def process_message(
|
||||||
|
session_id: str,
|
||||||
|
message: str,
|
||||||
|
response: str,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""处理消息"""
|
||||||
|
session = get_agent_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||||
|
await session.process_message(message, response)
|
||||||
|
return {"status": "recorded", "session_id": session_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{session_id}/spawn", response_model=dict[str, Any])
|
||||||
|
async def spawn_child_session(
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""创建子会话"""
|
||||||
|
session = get_agent_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||||
|
child = await session.spawn_child_session(user_id=user_id)
|
||||||
|
return await child.get_session_summary()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{session_id}/history", response_model=dict[str, Any])
|
||||||
|
async def get_session_history(session_id: str) -> dict[str, Any]:
|
||||||
|
"""获取会话历史"""
|
||||||
|
session = get_agent_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"history": session.get_history(),
|
||||||
|
"count": len(session._history),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{session_id}/persist", response_model=dict[str, str])
|
||||||
|
async def persist_session(session_id: str) -> dict[str, str]:
|
||||||
|
"""持久化会话"""
|
||||||
|
session = get_agent_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||||
|
success = await session.persist()
|
||||||
|
if success:
|
||||||
|
return {"status": "persisted", "session_id": session_id}
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to persist session")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{session_id}/metadata", response_model=dict[str, Any])
|
||||||
|
async def set_session_metadata(
|
||||||
|
session_id: str,
|
||||||
|
key: str,
|
||||||
|
value: Any,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""设置会话元数据"""
|
||||||
|
session = get_agent_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||||
|
session.add_metadata(key, value)
|
||||||
|
await session.persist()
|
||||||
|
return {"key": key, "value": value}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{session_id}/metadata/{key}", response_model=dict[str, Any])
|
||||||
|
async def get_session_metadata(
|
||||||
|
session_id: str,
|
||||||
|
key: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""获取会话元数据"""
|
||||||
|
session = get_agent_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
||||||
|
value = session.get_metadata(key)
|
||||||
|
if value is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Metadata key '{key}' not found")
|
||||||
|
return {"key": key, "value": value}
|
||||||
126
backend/app/routers/agent_skills.py
Normal file
126
backend/app/routers/agent_skills.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""Agent Skills API 路由 - Phase 9.6
|
||||||
|
|
||||||
|
使用新的 SkillRegistry (file-based) 而不是 DB-based skill 系统。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from app.agents.skills.registry import get_skill_registry, SkillRegistry
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/agent/skills", tags=["Agent Skills"])
|
||||||
|
|
||||||
|
|
||||||
|
def _skill_to_dict(skill) -> dict[str, Any]:
|
||||||
|
"""将 SkillMetadata 转换为字典"""
|
||||||
|
return {
|
||||||
|
"name": skill.name,
|
||||||
|
"description": skill.description,
|
||||||
|
"tags": skill.tags,
|
||||||
|
"triggers": skill.triggers,
|
||||||
|
"enabled": skill.enabled,
|
||||||
|
"content_preview": skill.content[:200] + "..."
|
||||||
|
if len(skill.content) > 200
|
||||||
|
else skill.content,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=dict[str, Any])
|
||||||
|
async def list_agent_skills() -> dict[str, Any]:
|
||||||
|
"""列出所有已加载的 Agent Skills"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
skills = registry.list_all()
|
||||||
|
return {
|
||||||
|
"skills": [_skill_to_dict(s) for s in skills],
|
||||||
|
"count": len(skills),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search", response_model=dict[str, Any])
|
||||||
|
async def search_agent_skills(
|
||||||
|
query: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""搜索 Skills"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
results = registry.search(query)
|
||||||
|
return {
|
||||||
|
"skills": [_skill_to_dict(s) for s in results],
|
||||||
|
"count": len(results),
|
||||||
|
"query": query,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{skill_name}", response_model=dict[str, Any])
|
||||||
|
async def get_agent_skill(skill_name: str) -> dict[str, Any]:
|
||||||
|
"""获取指定 Skill 详情"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
skill = registry.get_skill(skill_name)
|
||||||
|
if not skill:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
return {
|
||||||
|
"name": skill.name,
|
||||||
|
"description": skill.description,
|
||||||
|
"tags": skill.tags,
|
||||||
|
"triggers": skill.triggers,
|
||||||
|
"enabled": skill.enabled,
|
||||||
|
"content": skill.content,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{skill_name}/context", response_model=dict[str, str])
|
||||||
|
async def get_skill_context(skill_name: str) -> dict[str, str]:
|
||||||
|
"""获取 Skill 上下文字符串"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
context = registry.get_skill_context([skill_name])
|
||||||
|
if not context:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"Skill '{skill_name}' not found or not enabled"
|
||||||
|
)
|
||||||
|
return {"skill_name": skill_name, "context": context}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/context/batch", response_model=dict[str, str])
|
||||||
|
async def get_batch_skill_context(
|
||||||
|
skill_names: list[str],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""批量获取多个 Skill 的上下文"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
context = registry.get_skill_context(skill_names)
|
||||||
|
return {"skills": skill_names, "context": context}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/reload", response_model=dict[str, Any])
|
||||||
|
async def reload_skills(
|
||||||
|
skills_dir: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""重新加载所有 Skills"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
# 清除旧 skills
|
||||||
|
for name in list(registry._skills.keys()):
|
||||||
|
registry.unregister(name)
|
||||||
|
# 重新加载
|
||||||
|
count = registry.load_all(skills_dir)
|
||||||
|
return {"loaded": count, "message": f"Loaded {count} skills"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{skill_name}/enable", response_model=dict[str, str])
|
||||||
|
async def enable_skill(skill_name: str) -> dict[str, str]:
|
||||||
|
"""启用 Skill"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
skill = registry.get_skill(skill_name)
|
||||||
|
if not skill:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
skill.enabled = True
|
||||||
|
return {"status": "enabled", "skill_name": skill_name}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{skill_name}/disable", response_model=dict[str, str])
|
||||||
|
async def disable_skill(skill_name: str) -> dict[str, str]:
|
||||||
|
"""禁用 Skill"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
skill = registry.get_skill(skill_name)
|
||||||
|
if not skill:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
skill.enabled = False
|
||||||
|
return {"status": "disabled", "skill_name": skill_name}
|
||||||
241
backend/app/routers/hooks.py
Normal file
241
backend/app/routers/hooks.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
"""Hook API 路由 - Phase 7.5"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.agents.tools.hooks import HookType
|
||||||
|
from app.agents.tools.hooks.builtins import (
|
||||||
|
AuditLogHook,
|
||||||
|
DangerousConfirmationHook,
|
||||||
|
SecurityScanHook,
|
||||||
|
)
|
||||||
|
from app.agents.tools.hooks.config import (
|
||||||
|
HookConfigEntry,
|
||||||
|
get_hook_config_persistence,
|
||||||
|
)
|
||||||
|
from app.agents.tools.hooks.manager import get_hook_manager
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/hooks", tags=["Hooks"])
|
||||||
|
|
||||||
|
|
||||||
|
class HookInfo(BaseModel):
|
||||||
|
"""Hook 信息"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
hook_type: str
|
||||||
|
description: str
|
||||||
|
builtin: bool
|
||||||
|
|
||||||
|
|
||||||
|
class HookConfigUpdate(BaseModel):
|
||||||
|
"""更新 Hook 配置"""
|
||||||
|
|
||||||
|
entries: list[HookConfigEntry]
|
||||||
|
|
||||||
|
|
||||||
|
class HookConfigResponse(BaseModel):
|
||||||
|
"""Hook 配置响应"""
|
||||||
|
|
||||||
|
entries: list[dict[str, Any]]
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
|
class HookStatusResponse(BaseModel):
|
||||||
|
"""Hook 状态响应"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
enabled: bool
|
||||||
|
hook_type: str
|
||||||
|
registered: bool
|
||||||
|
|
||||||
|
|
||||||
|
# 内置 Hook 注册表
|
||||||
|
BUILTIN_HOOKS: dict[str, dict[str, str]] = {
|
||||||
|
"audit_log": {
|
||||||
|
"name": "audit_log",
|
||||||
|
"hook_type": "pre_tool_use,post_tool_use,tool_error",
|
||||||
|
"description": "审计日志 Hook - 记录所有工具调用",
|
||||||
|
"class": "AuditLogHook",
|
||||||
|
},
|
||||||
|
"dangerous_confirmation": {
|
||||||
|
"name": "dangerous_confirmation",
|
||||||
|
"hook_type": "pre_tool_use",
|
||||||
|
"description": "危险操作确认 Hook - 拦截危险工具调用",
|
||||||
|
"class": "DangerousConfirmationHook",
|
||||||
|
},
|
||||||
|
"security_scan": {
|
||||||
|
"name": "security_scan",
|
||||||
|
"hook_type": "post_tool_use",
|
||||||
|
"description": "安全扫描 Hook - 检测敏感信息泄露",
|
||||||
|
"class": "SecurityScanHook",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/available", response_model=list[HookInfo])
|
||||||
|
async def list_available_hooks() -> list[HookInfo]:
|
||||||
|
"""列出所有可用的内置 Hook"""
|
||||||
|
return [
|
||||||
|
HookInfo(
|
||||||
|
name=info["name"],
|
||||||
|
hook_type=info["hook_type"],
|
||||||
|
description=info["description"],
|
||||||
|
builtin=True,
|
||||||
|
)
|
||||||
|
for info in BUILTIN_HOOKS.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config", response_model=HookConfigResponse)
|
||||||
|
async def get_hook_config() -> HookConfigResponse:
|
||||||
|
"""获取当前 Hook 配置"""
|
||||||
|
persistence = get_hook_config_persistence()
|
||||||
|
entries = persistence.load_config()
|
||||||
|
return HookConfigResponse(
|
||||||
|
entries=[vars(e) if isinstance(e, HookConfigEntry) else e for e in entries],
|
||||||
|
count=len(entries),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/config", response_model=HookConfigResponse)
|
||||||
|
async def update_hook_config(
|
||||||
|
entries: list[HookConfigEntry],
|
||||||
|
) -> HookConfigResponse:
|
||||||
|
"""更新 Hook 配置"""
|
||||||
|
persistence = get_hook_config_persistence()
|
||||||
|
success = persistence.save_config(entries)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to save hook config")
|
||||||
|
|
||||||
|
# 应用配置到 HookManager
|
||||||
|
manager = get_hook_manager()
|
||||||
|
manager.clear() # 清除旧配置
|
||||||
|
persistence.apply_config() # 应用新配置
|
||||||
|
|
||||||
|
return HookConfigResponse(
|
||||||
|
entries=[vars(e) if isinstance(e, HookConfigEntry) else e for e in entries],
|
||||||
|
count=len(entries),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/apply-config", response_model=dict[str, Any])
|
||||||
|
async def apply_hook_config() -> dict[str, Any]:
|
||||||
|
"""应用配置文件到 HookManager"""
|
||||||
|
persistence = get_hook_config_persistence()
|
||||||
|
manager = get_hook_manager()
|
||||||
|
manager.clear()
|
||||||
|
count = persistence.apply_config()
|
||||||
|
return {"applied": count, "message": f"Applied {count} hook configurations"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status", response_model=list[HookStatusResponse])
|
||||||
|
async def get_hook_status() -> list[HookStatusResponse]:
|
||||||
|
"""获取所有已注册 Hook 的状态"""
|
||||||
|
manager = get_hook_manager()
|
||||||
|
all_hooks = manager.list_all()
|
||||||
|
|
||||||
|
# 按名称索引已注册的 hooks
|
||||||
|
registered: dict[str, dict[str, Any]] = {}
|
||||||
|
for hook in all_hooks:
|
||||||
|
registered[hook.name] = {
|
||||||
|
"name": hook.name,
|
||||||
|
"enabled": hook.enabled,
|
||||||
|
"hook_type": hook.hook_type.value,
|
||||||
|
"registered": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 合并内置 Hook 信息
|
||||||
|
result: list[HookStatusResponse] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
# 先添加已注册的
|
||||||
|
for hook in all_hooks:
|
||||||
|
result.append(
|
||||||
|
HookStatusResponse(
|
||||||
|
name=hook.name,
|
||||||
|
enabled=hook.enabled,
|
||||||
|
hook_type=hook.hook_type.value,
|
||||||
|
registered=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen.add(hook.name)
|
||||||
|
|
||||||
|
# 再添加内置但未注册的
|
||||||
|
for name, info in BUILTIN_HOOKS.items():
|
||||||
|
if name not in seen:
|
||||||
|
result.append(
|
||||||
|
HookStatusResponse(
|
||||||
|
name=name,
|
||||||
|
enabled=False,
|
||||||
|
hook_type=info["hook_type"],
|
||||||
|
registered=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{name}/enable", response_model=dict[str, str])
|
||||||
|
async def enable_hook(name: str) -> dict[str, str]:
|
||||||
|
"""启用指定 Hook"""
|
||||||
|
manager = get_hook_manager()
|
||||||
|
if manager.enable(name):
|
||||||
|
return {"status": "enabled", "name": name}
|
||||||
|
raise HTTPException(status_code=404, detail=f"Hook '{name}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{name}/disable", response_model=dict[str, str])
|
||||||
|
async def disable_hook(name: str) -> dict[str, str]:
|
||||||
|
"""禁用指定 Hook"""
|
||||||
|
manager = get_hook_manager()
|
||||||
|
if manager.disable(name):
|
||||||
|
return {"status": "disabled", "name": name}
|
||||||
|
raise HTTPException(status_code=404, detail=f"Hook '{name}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register-builtin", response_model=dict[str, str])
|
||||||
|
async def register_builtin_hook(
|
||||||
|
name: str,
|
||||||
|
hook_type: str = "pre_tool_use",
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""注册内置 Hook 到 HookManager"""
|
||||||
|
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
|
||||||
|
|
||||||
|
manager = get_hook_manager()
|
||||||
|
|
||||||
|
if name == "audit_log":
|
||||||
|
hook_instance = AuditLogHook()
|
||||||
|
handler = hook_instance.pre_tool_use
|
||||||
|
hook_types = [HookType.PRE_TOOL_USE, HookType.POST_TOOL_USE, HookType.TOOL_ERROR]
|
||||||
|
elif name == "dangerous_confirmation":
|
||||||
|
hook_instance = DangerousConfirmationHook()
|
||||||
|
handler = hook_instance.pre_tool_use
|
||||||
|
hook_types = [HookType.PRE_TOOL_USE]
|
||||||
|
elif name == "security_scan":
|
||||||
|
hook_instance = SecurityScanHook()
|
||||||
|
handler = hook_instance.post_tool_use
|
||||||
|
hook_types = [HookType.POST_TOOL_USE]
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Unknown builtin hook: {name}")
|
||||||
|
|
||||||
|
registered = []
|
||||||
|
for ht in hook_types:
|
||||||
|
hook_def = HookDefinition(
|
||||||
|
name=f"{name}_{ht.value}",
|
||||||
|
hook_type=ht,
|
||||||
|
trigger=HookTrigger(),
|
||||||
|
handler=handler,
|
||||||
|
priority=0,
|
||||||
|
enabled=True,
|
||||||
|
description=f"Built-in {name} hook",
|
||||||
|
)
|
||||||
|
manager.register(hook_def)
|
||||||
|
registered.append(ht.value)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "registered",
|
||||||
|
"name": name,
|
||||||
|
"hook_types": ", ".join(registered),
|
||||||
|
}
|
||||||
222
backend/app/routers/plugins.py
Normal file
222
backend/app/routers/plugins.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""Plugin API 路由 - Phase 8.6"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.agents.plugins import get_plugin_manager, PluginManifest
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/plugins", tags=["Plugins"])
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInfo(BaseModel):
|
||||||
|
"""插件信息"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
version: str
|
||||||
|
description: str
|
||||||
|
author: str
|
||||||
|
enabled: bool
|
||||||
|
main: str
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallRequest(BaseModel):
|
||||||
|
"""插件安装请求"""
|
||||||
|
|
||||||
|
plugin_path: str
|
||||||
|
|
||||||
|
|
||||||
|
class PluginListResponse(BaseModel):
|
||||||
|
"""插件列表响应"""
|
||||||
|
|
||||||
|
plugins: list[dict[str, Any]]
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
|
# 全局插件市场(简单内存实现)
|
||||||
|
_plugin_marketplace: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
|
||||||
|
def _manifest_to_dict(manifest: PluginManifest, enabled: bool) -> dict[str, Any]:
|
||||||
|
"""将 PluginManifest 转换为字典"""
|
||||||
|
return {
|
||||||
|
"id": manifest.id,
|
||||||
|
"name": manifest.name,
|
||||||
|
"version": manifest.version,
|
||||||
|
"description": manifest.description,
|
||||||
|
"author": manifest.author,
|
||||||
|
"enabled": enabled,
|
||||||
|
"main": manifest.main,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=PluginListResponse)
|
||||||
|
async def list_plugins() -> PluginListResponse:
|
||||||
|
"""列出所有已安装的插件"""
|
||||||
|
manager = get_plugin_manager()
|
||||||
|
plugins = manager.list_plugins()
|
||||||
|
result = []
|
||||||
|
for p in plugins:
|
||||||
|
enabled = manager.is_enabled(p.id)
|
||||||
|
result.append(_manifest_to_dict(p, enabled))
|
||||||
|
return PluginListResponse(plugins=result, count=len(result))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{plugin_id}", response_model=dict[str, Any])
|
||||||
|
async def get_plugin(plugin_id: str) -> dict[str, Any]:
|
||||||
|
"""获取指定插件信息"""
|
||||||
|
manager = get_plugin_manager()
|
||||||
|
manifest = manager.get_plugin(plugin_id)
|
||||||
|
if not manifest:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||||
|
enabled = manager.is_enabled(plugin_id)
|
||||||
|
return _manifest_to_dict(manifest, enabled)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/install", response_model=dict[str, str])
|
||||||
|
async def install_plugin(request: PluginInstallRequest) -> dict[str, str]:
|
||||||
|
"""安装插件"""
|
||||||
|
manager = get_plugin_manager()
|
||||||
|
if not os.path.exists(request.plugin_path):
|
||||||
|
raise HTTPException(status_code=400, detail="Plugin path does not exist")
|
||||||
|
|
||||||
|
if manager.install(request.plugin_path):
|
||||||
|
return {"status": "installed", "path": request.plugin_path}
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to install plugin")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{plugin_id}/enable", response_model=dict[str, str])
|
||||||
|
async def enable_plugin(plugin_id: str) -> dict[str, str]:
|
||||||
|
"""启用插件"""
|
||||||
|
manager = get_plugin_manager()
|
||||||
|
if manager.enable(plugin_id):
|
||||||
|
return {"status": "enabled", "plugin_id": plugin_id}
|
||||||
|
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{plugin_id}/disable", response_model=dict[str, str])
|
||||||
|
async def disable_plugin(plugin_id: str) -> dict[str, str]:
|
||||||
|
"""禁用插件"""
|
||||||
|
manager = get_plugin_manager()
|
||||||
|
if manager.disable(plugin_id):
|
||||||
|
return {"status": "disabled", "plugin_id": plugin_id}
|
||||||
|
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{plugin_id}", response_model=dict[str, str])
|
||||||
|
async def uninstall_plugin(plugin_id: str) -> dict[str, str]:
|
||||||
|
"""卸载插件"""
|
||||||
|
manager = get_plugin_manager()
|
||||||
|
if manager.uninstall(plugin_id):
|
||||||
|
return {"status": "uninstalled", "plugin_id": plugin_id}
|
||||||
|
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{plugin_id}/reload", response_model=dict[str, str])
|
||||||
|
async def reload_plugin(plugin_id: str) -> dict[str, str]:
|
||||||
|
"""重新加载插件"""
|
||||||
|
manager = get_plugin_manager()
|
||||||
|
if manager.reload(plugin_id):
|
||||||
|
return {"status": "reloaded", "plugin_id": plugin_id}
|
||||||
|
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
# === Plugin Marketplace ===
|
||||||
|
|
||||||
|
_marketplace_router = APIRouter(prefix="/api/marketplace", tags=["Plugin Marketplace"])
|
||||||
|
|
||||||
|
|
||||||
|
@_marketplace_router.get("/plugins", response_model=dict[str, Any])
|
||||||
|
async def search_marketplace_plugins(
|
||||||
|
query: str | None = None,
|
||||||
|
category: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""搜索插件市场"""
|
||||||
|
results = _plugin_marketplace
|
||||||
|
if query:
|
||||||
|
results = [
|
||||||
|
p
|
||||||
|
for p in results
|
||||||
|
if query.lower() in p.get("name", "").lower()
|
||||||
|
or query.lower() in p.get("description", "").lower()
|
||||||
|
]
|
||||||
|
if category:
|
||||||
|
results = [p for p in results if p.get("category") == category]
|
||||||
|
return {"plugins": results, "count": len(results)}
|
||||||
|
|
||||||
|
|
||||||
|
@_marketplace_router.get("/plugins/{plugin_id}", response_model=dict[str, Any])
|
||||||
|
async def get_marketplace_plugin(plugin_id: str) -> dict[str, Any]:
|
||||||
|
"""获取市场中的插件详情"""
|
||||||
|
for plugin in _plugin_marketplace:
|
||||||
|
if plugin.get("id") == plugin_id:
|
||||||
|
return plugin
|
||||||
|
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found in marketplace")
|
||||||
|
|
||||||
|
|
||||||
|
@_marketplace_router.post("/plugins", response_model=dict[str, str])
|
||||||
|
async def add_to_marketplace(plugin: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""添加插件到市场(仅供测试/开发)"""
|
||||||
|
if "id" not in plugin or "name" not in plugin:
|
||||||
|
raise HTTPException(status_code=400, detail="Plugin must have id and name")
|
||||||
|
# 移除已存在的同 ID 插件
|
||||||
|
global _plugin_marketplace
|
||||||
|
_plugin_marketplace = [p for p in _plugin_marketplace if p.get("id") != plugin["id"]]
|
||||||
|
_plugin_marketplace.append(plugin)
|
||||||
|
return {"status": "added", "id": plugin["id"]}
|
||||||
|
|
||||||
|
|
||||||
|
@_marketplace_router.post("/plugins/{plugin_id}/download", response_model=dict[str, str])
|
||||||
|
async def download_plugin(plugin_id: str) -> dict[str, str]:
|
||||||
|
"""从市场下载并安装插件"""
|
||||||
|
# Find plugin in marketplace
|
||||||
|
plugin = None
|
||||||
|
for p in _plugin_marketplace:
|
||||||
|
if p.get("id") == plugin_id:
|
||||||
|
plugin = p
|
||||||
|
break
|
||||||
|
|
||||||
|
if not plugin:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"Plugin '{plugin_id}' not found in marketplace"
|
||||||
|
)
|
||||||
|
|
||||||
|
download_url = plugin.get("download_url")
|
||||||
|
if not download_url:
|
||||||
|
raise HTTPException(status_code=400, detail="Plugin has no download URL")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download the plugin archive
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(download_url, timeout=60.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
archive_content = response.content
|
||||||
|
|
||||||
|
# Extract to temp directory and install
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
archive_path = os.path.join(temp_dir, "plugin.zip")
|
||||||
|
with open(archive_path, "wb") as f:
|
||||||
|
f.write(archive_content)
|
||||||
|
|
||||||
|
extract_dir = os.path.join(temp_dir, "extracted")
|
||||||
|
with zipfile.ZipFile(archive_path, "r") as zf:
|
||||||
|
zf.extractall(extract_dir)
|
||||||
|
|
||||||
|
# Install the plugin
|
||||||
|
manager = get_plugin_manager()
|
||||||
|
if manager.install(extract_dir):
|
||||||
|
return {"status": "installed", "plugin_id": plugin_id}
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to install plugin")
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise HTTPException(status_code=502, detail=f"Download failed: {str(e)}")
|
||||||
|
except zipfile.BadZipFile:
|
||||||
|
raise HTTPException(status_code=502, detail="Invalid plugin archive")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"Installation failed: {str(e)}")
|
||||||
@@ -1,4 +1,7 @@
|
|||||||
from pydantic import BaseModel
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class AgentCreate(BaseModel):
|
class AgentCreate(BaseModel):
|
||||||
@@ -6,6 +9,7 @@ class AgentCreate(BaseModel):
|
|||||||
role: str
|
role: str
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
system_prompt: str
|
system_prompt: str
|
||||||
|
spawn_permission: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AgentOut(BaseModel):
|
class AgentOut(BaseModel):
|
||||||
@@ -55,3 +59,163 @@ class AgentConfigOut(BaseModel):
|
|||||||
selected_skill_ids: list[str]
|
selected_skill_ids: list[str]
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityEventOut(BaseModel):
|
||||||
|
event_id: str
|
||||||
|
event_type: str
|
||||||
|
timestamp: datetime
|
||||||
|
conversation_id: str | None = None
|
||||||
|
agent_id: str | None = None
|
||||||
|
sub_commander_id: str | None = None
|
||||||
|
task_id: str | None = None
|
||||||
|
parent_task_id: str | None = None
|
||||||
|
child_task_id: str | None = None
|
||||||
|
thread_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
interrupt_id: str | None = None
|
||||||
|
recovery_id: str | None = None
|
||||||
|
payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
severity: str = "info"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityEventsResponse(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
total: int
|
||||||
|
limit: int
|
||||||
|
offset: int
|
||||||
|
items: list[AgentVisibilityEventOut]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityTaskSummaryOut(BaseModel):
|
||||||
|
task_id: str
|
||||||
|
role: str | None = None
|
||||||
|
owner_agent_id: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
summary: str | None = None
|
||||||
|
evidence_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityTopologyNodeOut(BaseModel):
|
||||||
|
agent_id: str
|
||||||
|
role: str | None = None
|
||||||
|
parent_agent_id: str | None = None
|
||||||
|
source: str
|
||||||
|
task_count: int = 0
|
||||||
|
completed_task_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityTopologyOut(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
root_agent_id: str | None = None
|
||||||
|
current_agent: str | None = None
|
||||||
|
nodes: list[AgentVisibilityTopologyNodeOut]
|
||||||
|
edges: list[dict[str, str]]
|
||||||
|
tasks: list[AgentVisibilityTaskSummaryOut]
|
||||||
|
task_hierarchy: dict[str, list[str]] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityEvidenceOut(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
task_id: str
|
||||||
|
task: dict[str, Any] | None = None
|
||||||
|
result: dict[str, Any] | None = None
|
||||||
|
tool_outcomes: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
verifier: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityThreadMessageOut(BaseModel):
|
||||||
|
message_id: str
|
||||||
|
thread_id: str
|
||||||
|
from_agent_id: str
|
||||||
|
to_agent_id: str
|
||||||
|
task_id: str | None = None
|
||||||
|
reply_to_message_id: str | None = None
|
||||||
|
message_type: str
|
||||||
|
content_summary: str
|
||||||
|
created_at: datetime
|
||||||
|
payload: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityThreadOut(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
thread_id: str
|
||||||
|
total: int
|
||||||
|
items: list[AgentVisibilityThreadMessageOut]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityVerifierOut(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
status: str | None = None
|
||||||
|
summary: str | None = None
|
||||||
|
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityIsolationOut(BaseModel):
|
||||||
|
mode: str = "none"
|
||||||
|
isolation_id: str | None = None
|
||||||
|
workspace_path: str | None = None
|
||||||
|
parent_conversation_id: str | None = None
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityCostOut(BaseModel):
|
||||||
|
input_tokens: int = 0
|
||||||
|
output_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
estimated_cost: float | None = None
|
||||||
|
budget_warning: bool = False
|
||||||
|
currency: str = "USD"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityCostByAgentOut(BaseModel):
|
||||||
|
agent_id: str
|
||||||
|
input_tokens: int = 0
|
||||||
|
output_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
estimated_cost: float | None = None
|
||||||
|
budget_warning: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityCostSummaryOut(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
total: AgentVisibilityCostOut
|
||||||
|
thresholds: dict[str, float] = Field(default_factory=dict)
|
||||||
|
by_agent: list[AgentVisibilityCostByAgentOut] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityToolGovernanceItemOut(BaseModel):
|
||||||
|
capability_id: str
|
||||||
|
tool_name: str
|
||||||
|
permission_class: str
|
||||||
|
side_effect_scope: str
|
||||||
|
supports_retry: bool = False
|
||||||
|
idempotent: bool = False
|
||||||
|
safe_for_parallel_use: bool = False
|
||||||
|
requires_confirmation: bool = False
|
||||||
|
usage_count: int = 0
|
||||||
|
last_result_preview: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityToolGovernanceOut(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
total_tools: int = 0
|
||||||
|
used_tools: int = 0
|
||||||
|
items: list[AgentVisibilityToolGovernanceItemOut] = Field(default_factory=list)
|
||||||
|
upgrade_candidates: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentVisibilityRuntimeSummaryOut(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
execution_mode: str | None = None
|
||||||
|
current_phase: str | None = None
|
||||||
|
current_checkpoint: str | None = None
|
||||||
|
phase_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
checkpoint_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
verifier: AgentVisibilityVerifierOut
|
||||||
|
isolation: AgentVisibilityIsolationOut
|
||||||
|
cost: AgentVisibilityCostOut
|
||||||
|
topology_node_count: int = 0
|
||||||
|
active_task_count: int = 0
|
||||||
|
completed_task_count: int = 0
|
||||||
|
recent_events: list[AgentVisibilityEventOut] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from app.models.conversation import Conversation, Message
|
|||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.agents.graph import get_agent_graph
|
from app.agents.graph import get_agent_graph
|
||||||
from app.agents.context import set_current_user, clear_current_user
|
from app.agents.context import set_current_user, clear_current_user
|
||||||
|
from app.agents.skills.registry import get_skill_registry
|
||||||
from app.services import memory_service
|
from app.services import memory_service
|
||||||
from app.services.brain_service import BrainService
|
from app.services.brain_service import BrainService
|
||||||
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
|
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
|
||||||
@@ -30,6 +31,56 @@ from app.agents.state import initial_state
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
MEMORY_SECTION_HEADERS = (
|
||||||
|
"【用户记忆】",
|
||||||
|
"【之前对话摘要】",
|
||||||
|
"【知识大脑】",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_memory_context_sections(memory_context: str | None) -> dict[str, str]:
|
||||||
|
text = (memory_context or "").strip()
|
||||||
|
if not text:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
sections: dict[str, str] = {}
|
||||||
|
current_header: str | None = None
|
||||||
|
current_lines: list[str] = []
|
||||||
|
|
||||||
|
for line in text.splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped in MEMORY_SECTION_HEADERS:
|
||||||
|
if current_header and current_lines:
|
||||||
|
sections[current_header] = "\n".join(current_lines).strip()
|
||||||
|
current_header = stripped
|
||||||
|
current_lines = [stripped]
|
||||||
|
continue
|
||||||
|
if current_header:
|
||||||
|
current_lines.append(line)
|
||||||
|
|
||||||
|
if current_header and current_lines:
|
||||||
|
sections[current_header] = "\n".join(current_lines).strip()
|
||||||
|
|
||||||
|
return sections
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_role_memory_contexts(memory_context: str | None) -> dict[str, str | None]:
|
||||||
|
sections = _split_memory_context_sections(memory_context)
|
||||||
|
user_memory = sections.get("【用户记忆】")
|
||||||
|
summaries = sections.get("【之前对话摘要】")
|
||||||
|
knowledge = sections.get("【知识大脑】")
|
||||||
|
|
||||||
|
def _join_parts(*parts: str | None) -> str | None:
|
||||||
|
values = [part for part in parts if part]
|
||||||
|
return "\n\n".join(values) if values else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"schedule_context_summary": _join_parts(user_memory, summaries),
|
||||||
|
"knowledge_context": knowledge,
|
||||||
|
"analysis_report": _join_parts(summaries, knowledge),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
|
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
|
||||||
capabilities = resolve_provider_capabilities(user_llm_config)
|
capabilities = resolve_provider_capabilities(user_llm_config)
|
||||||
error_text = str(error).lower()
|
error_text = str(error).lower()
|
||||||
@@ -45,9 +96,8 @@ def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None
|
|||||||
]
|
]
|
||||||
|
|
||||||
if isinstance(error, BadRequestError):
|
if isinstance(error, BadRequestError):
|
||||||
return (
|
return getattr(capabilities, "provider", None) not in {"openai", "claude"} and any(
|
||||||
getattr(capabilities, "provider", None) not in {"openai", "claude"}
|
marker in error_text for marker in markers
|
||||||
and any(marker in error_text for marker in markers)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return any(marker in error_text for marker in markers)
|
return any(marker in error_text for marker in markers)
|
||||||
@@ -84,14 +134,165 @@ _CONTINUITY_SNAPSHOT_FIELDS = (
|
|||||||
"current_agent",
|
"current_agent",
|
||||||
"next_step",
|
"next_step",
|
||||||
"agent_trace",
|
"agent_trace",
|
||||||
|
"agent_id",
|
||||||
|
"parent_agent_id",
|
||||||
|
"root_agent_id",
|
||||||
|
"collaboration_depth",
|
||||||
|
"thread_id",
|
||||||
|
"last_message_id",
|
||||||
|
"message_sequence",
|
||||||
|
"spawned_agent_ids",
|
||||||
|
"current_sub_commander",
|
||||||
|
"active_sub_commanders",
|
||||||
|
"sub_commander_trace",
|
||||||
|
"event_trace",
|
||||||
|
"message_trace",
|
||||||
|
"active_tasks",
|
||||||
|
"task_results",
|
||||||
|
"task_hierarchy",
|
||||||
|
"verification_status",
|
||||||
|
"verification_summary",
|
||||||
|
"verification_evidence",
|
||||||
|
"isolation_mode",
|
||||||
|
"isolation_id",
|
||||||
|
"isolation_workspace_path",
|
||||||
|
"isolation_parent_conversation_id",
|
||||||
|
"isolation_metadata",
|
||||||
|
"input_tokens",
|
||||||
|
"output_tokens",
|
||||||
|
"estimated_cost",
|
||||||
|
"budget_warning",
|
||||||
|
"cost_by_agent",
|
||||||
|
"cost_thresholds",
|
||||||
|
"budget_state",
|
||||||
|
"collaboration_budget_history",
|
||||||
|
"current_phase",
|
||||||
|
"phase_history",
|
||||||
|
"current_checkpoint",
|
||||||
|
"checkpoint_history",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_legacy_turn_context(turn_context: Any, current_agent: Any) -> dict[str, Any] | None:
|
||||||
|
if not isinstance(turn_context, dict):
|
||||||
|
return None
|
||||||
|
normalized = dict(turn_context)
|
||||||
|
active_agent = normalized.pop("active_agent", None)
|
||||||
|
active_sub_flow = normalized.pop("active_sub_flow", None)
|
||||||
|
if isinstance(active_agent, str) and active_agent and "active_agent" not in normalized:
|
||||||
|
normalized["active_agent"] = active_agent
|
||||||
|
if (
|
||||||
|
isinstance(active_sub_flow, str)
|
||||||
|
and active_sub_flow
|
||||||
|
and "active_sub_commander" not in normalized
|
||||||
|
):
|
||||||
|
normalized["active_sub_commander"] = active_sub_flow
|
||||||
|
if not normalized.get("active_agent") and isinstance(current_agent, str) and current_agent:
|
||||||
|
normalized["active_agent"] = current_agent
|
||||||
|
return normalized or None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_legacy_pending_action(pending_action: Any) -> dict[str, Any] | None:
|
||||||
|
if not isinstance(pending_action, dict):
|
||||||
|
return None
|
||||||
|
normalized = dict(pending_action)
|
||||||
|
legacy_action_type = normalized.pop("action_type", None)
|
||||||
|
if legacy_action_type and "type" not in normalized:
|
||||||
|
normalized["type"] = legacy_action_type
|
||||||
|
legacy_agent = normalized.pop("agent", None)
|
||||||
|
legacy_sub_flow = normalized.pop("sub_flow", None)
|
||||||
|
if legacy_agent and "owner_agent" not in normalized:
|
||||||
|
normalized["owner_agent"] = legacy_agent
|
||||||
|
if legacy_sub_flow and "owner_sub_commander" not in normalized:
|
||||||
|
normalized["owner_sub_commander"] = legacy_sub_flow
|
||||||
|
legacy_status = normalized.get("status")
|
||||||
|
if legacy_status == "awaiting_confirmation":
|
||||||
|
normalized["status"] = "pending"
|
||||||
|
elif legacy_status == "awaiting_clarification":
|
||||||
|
normalized["status"] = "blocked_on_clarification"
|
||||||
|
return normalized or None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_legacy_clarification_context(
|
||||||
|
clarification_context: Any,
|
||||||
|
pending_action: dict[str, Any] | None,
|
||||||
|
current_agent: Any,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
if not isinstance(clarification_context, dict):
|
||||||
|
return None
|
||||||
|
normalized = dict(clarification_context)
|
||||||
|
active_agent = normalized.pop("active_agent", None)
|
||||||
|
sub_flow = normalized.pop("sub_flow", None)
|
||||||
|
awaiting_user_input = normalized.pop("awaiting_user_input", None)
|
||||||
|
if isinstance(active_agent, str) and active_agent and "owning_agent" not in normalized:
|
||||||
|
normalized["owning_agent"] = active_agent
|
||||||
|
if isinstance(sub_flow, str) and sub_flow and "owning_sub_commander" not in normalized:
|
||||||
|
normalized["owning_sub_commander"] = sub_flow
|
||||||
|
if "target_action" not in normalized:
|
||||||
|
target_action = None
|
||||||
|
if pending_action:
|
||||||
|
pending_type = pending_action.get("type")
|
||||||
|
if isinstance(pending_type, str) and pending_type and pending_type != "clarification":
|
||||||
|
target_action = pending_type
|
||||||
|
if target_action is None and isinstance(sub_flow, str) and sub_flow.startswith("create_"):
|
||||||
|
target_action = sub_flow
|
||||||
|
if target_action:
|
||||||
|
normalized["target_action"] = target_action
|
||||||
|
if not normalized.get("owning_agent") and isinstance(current_agent, str) and current_agent:
|
||||||
|
normalized["owning_agent"] = current_agent
|
||||||
|
if awaiting_user_input is True and "status" not in normalized:
|
||||||
|
normalized["status"] = "pending"
|
||||||
|
return normalized or None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_legacy_continuity_state(
|
||||||
|
continuity_state: Any,
|
||||||
|
clarification_context: dict[str, Any] | None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
if not isinstance(continuity_state, dict):
|
||||||
|
return None
|
||||||
|
normalized = dict(continuity_state)
|
||||||
|
normalized.pop("active_agent", None)
|
||||||
|
normalized.pop("active_sub_flow", None)
|
||||||
|
legacy_status = normalized.get("status")
|
||||||
|
if legacy_status == "awaiting_clarification":
|
||||||
|
normalized["status"] = "fresh"
|
||||||
|
if clarification_context and "mode" not in normalized:
|
||||||
|
normalized["mode"] = "resume_after_clarification"
|
||||||
|
return normalized or None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
normalized = dict(state)
|
||||||
|
current_agent = normalized.get("current_agent")
|
||||||
|
pending_action = _normalize_legacy_pending_action(normalized.get("pending_action"))
|
||||||
|
clarification_context = _normalize_legacy_clarification_context(
|
||||||
|
normalized.get("clarification_context"),
|
||||||
|
pending_action,
|
||||||
|
current_agent,
|
||||||
|
)
|
||||||
|
continuity_state = _normalize_legacy_continuity_state(
|
||||||
|
normalized.get("continuity_state"),
|
||||||
|
clarification_context,
|
||||||
|
)
|
||||||
|
turn_context = _normalize_legacy_turn_context(normalized.get("turn_context"), current_agent)
|
||||||
|
if pending_action is not None:
|
||||||
|
normalized["pending_action"] = pending_action
|
||||||
|
if clarification_context is not None:
|
||||||
|
normalized["clarification_context"] = clarification_context
|
||||||
|
if continuity_state is not None:
|
||||||
|
normalized["continuity_state"] = continuity_state
|
||||||
|
if turn_context is not None:
|
||||||
|
normalized["turn_context"] = turn_context
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
def _build_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any] | None:
|
def _build_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
normalized_state = _normalize_continuity_snapshot(state)
|
||||||
snapshot = {
|
snapshot = {
|
||||||
field: state.get(field)
|
field: normalized_state.get(field)
|
||||||
for field in _CONTINUITY_SNAPSHOT_FIELDS
|
for field in _CONTINUITY_SNAPSHOT_FIELDS
|
||||||
if state.get(field) is not None
|
if normalized_state.get(field) is not None
|
||||||
}
|
}
|
||||||
if not snapshot:
|
if not snapshot:
|
||||||
return None
|
return None
|
||||||
@@ -116,7 +317,7 @@ def _extract_continuity_snapshot(payload: Any) -> dict[str, Any] | None:
|
|||||||
return None
|
return None
|
||||||
state = payload.get("state")
|
state = payload.get("state")
|
||||||
if isinstance(state, dict):
|
if isinstance(state, dict):
|
||||||
return state
|
return _normalize_continuity_snapshot(state)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -160,11 +361,32 @@ class AgentService:
|
|||||||
"【当前时间】\n"
|
"【当前时间】\n"
|
||||||
f"- current_time_utc: {reference['current_time_iso']}\n"
|
f"- current_time_utc: {reference['current_time_iso']}\n"
|
||||||
f"- current_date_utc: {reference['current_date_iso']}\n"
|
f"- current_date_utc: {reference['current_date_iso']}\n"
|
||||||
"说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。"
|
"说明:解析'今天/明天/后天/本周/下周'等相对时间时,请以 current_time_utc 为准。"
|
||||||
)
|
)
|
||||||
return context, reference
|
return context, reference
|
||||||
|
|
||||||
async def _get_user_llm_config(self, user_id: str, model_name: str | None = None) -> dict | None:
|
def build_skill_context(self, skill_names: list[str]) -> dict:
|
||||||
|
"""构建 Skills 上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_names: Skill 名称列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 skills 上下文的字典
|
||||||
|
"""
|
||||||
|
registry = get_skill_registry()
|
||||||
|
merged_context = registry.get_skill_context(skill_names)
|
||||||
|
return {
|
||||||
|
"skills_context": merged_context,
|
||||||
|
"skills_metadata": {
|
||||||
|
"skills": skill_names,
|
||||||
|
"count": len(skill_names),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _get_user_llm_config(
|
||||||
|
self, user_id: str, model_name: str | None = None
|
||||||
|
) -> dict | None:
|
||||||
"""获取用户的 LLM 模型配置"""
|
"""获取用户的 LLM 模型配置"""
|
||||||
user = await self.db.get(User, user_id)
|
user = await self.db.get(User, user_id)
|
||||||
if not user or not user.llm_config:
|
if not user or not user.llm_config:
|
||||||
@@ -187,7 +409,7 @@ class AgentService:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def _load_continuity_snapshot(self, conversation: Conversation) -> dict[str, Any] | None:
|
async def _load_continuity_snapshot(self, conversation: Conversation) -> dict[str, Any] | None:
|
||||||
snapshot = _extract_continuity_snapshot(conversation.agent_state)
|
snapshot = _extract_continuity_snapshot(getattr(conversation, "agent_state", None))
|
||||||
if snapshot:
|
if snapshot:
|
||||||
return snapshot
|
return snapshot
|
||||||
|
|
||||||
@@ -214,13 +436,15 @@ class AgentService:
|
|||||||
user_llm_config: dict | None,
|
user_llm_config: dict | None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
state = initial_state(user_id, conversation.id)
|
state = initial_state(user_id, conversation.id)
|
||||||
state.update({
|
state.update(
|
||||||
"messages": [HumanMessage(content=full_message)],
|
{
|
||||||
"memory_context": memory_context,
|
"messages": [HumanMessage(content=full_message)],
|
||||||
"current_datetime_context": current_datetime_context,
|
"memory_context": memory_context,
|
||||||
"current_datetime_reference": current_datetime_reference,
|
"current_datetime_context": current_datetime_context,
|
||||||
"user_llm_config": user_llm_config,
|
"current_datetime_reference": current_datetime_reference,
|
||||||
})
|
"user_llm_config": user_llm_config,
|
||||||
|
}
|
||||||
|
)
|
||||||
previous_snapshot = await self._load_continuity_snapshot(conversation)
|
previous_snapshot = await self._load_continuity_snapshot(conversation)
|
||||||
if previous_snapshot:
|
if previous_snapshot:
|
||||||
state.update(previous_snapshot)
|
state.update(previous_snapshot)
|
||||||
@@ -282,6 +506,7 @@ class AgentService:
|
|||||||
file_context = ""
|
file_context = ""
|
||||||
if file_ids:
|
if file_ids:
|
||||||
from app.services.document_service import DocumentService
|
from app.services.document_service import DocumentService
|
||||||
|
|
||||||
doc_svc = DocumentService(self.db)
|
doc_svc = DocumentService(self.db)
|
||||||
for file_id in file_ids:
|
for file_id in file_ids:
|
||||||
content = await doc_svc.get_document_content(user_id, file_id)
|
content = await doc_svc.get_document_content(user_id, file_id)
|
||||||
@@ -347,7 +572,9 @@ class AgentService:
|
|||||||
set_current_user(user_id)
|
set_current_user(user_id)
|
||||||
try:
|
try:
|
||||||
graph = get_agent_graph()
|
graph = get_agent_graph()
|
||||||
current_datetime_context, current_datetime_reference = self._build_current_datetime_context()
|
current_datetime_context, current_datetime_reference = (
|
||||||
|
self._build_current_datetime_context()
|
||||||
|
)
|
||||||
|
|
||||||
state = await self._build_agent_state(
|
state = await self._build_agent_state(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -358,8 +585,11 @@ class AgentService:
|
|||||||
current_datetime_reference=current_datetime_reference,
|
current_datetime_reference=current_datetime_reference,
|
||||||
user_llm_config=user_llm_config,
|
user_llm_config=user_llm_config,
|
||||||
)
|
)
|
||||||
|
state.update(_derive_role_memory_contexts(memory_ctx))
|
||||||
|
|
||||||
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
|
yield self._build_progress_event(
|
||||||
|
"thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for event in graph.astream_events(state, version="v2"):
|
async for event in graph.astream_events(state, version="v2"):
|
||||||
@@ -368,7 +598,13 @@ class AgentService:
|
|||||||
metadata = event.get("metadata", {})
|
metadata = event.get("metadata", {})
|
||||||
data = event.get("data", {})
|
data = event.get("data", {})
|
||||||
|
|
||||||
if kind == "on_chain_start" and event_name in {"master", "schedule_planner", "executor", "librarian", "analyst"}:
|
if kind == "on_chain_start" and event_name in {
|
||||||
|
"master",
|
||||||
|
"schedule_planner",
|
||||||
|
"executor",
|
||||||
|
"librarian",
|
||||||
|
"analyst",
|
||||||
|
}:
|
||||||
stage_map = {
|
stage_map = {
|
||||||
"master": ("thinking", "Jarvis 正在理解请求"),
|
"master": ("thinking", "Jarvis 正在理解请求"),
|
||||||
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
|
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
|
||||||
@@ -376,9 +612,13 @@ class AgentService:
|
|||||||
"librarian": ("tool", "Jarvis 正在检索知识"),
|
"librarian": ("tool", "Jarvis 正在检索知识"),
|
||||||
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
||||||
}
|
}
|
||||||
stage, label = stage_map.get(event_name, ("thinking", "Jarvis 正在思考"))
|
stage, label = stage_map.get(
|
||||||
yield self._build_progress_event(stage, label, agent=event_name, step=label)
|
event_name, ("thinking", "Jarvis 正在思考")
|
||||||
|
)
|
||||||
|
yield self._build_progress_event(
|
||||||
|
stage, label, agent=event_name, step=label
|
||||||
|
)
|
||||||
|
|
||||||
elif kind == "on_tool_start":
|
elif kind == "on_tool_start":
|
||||||
yield self._build_progress_event(
|
yield self._build_progress_event(
|
||||||
"tool",
|
"tool",
|
||||||
@@ -387,7 +627,7 @@ class AgentService:
|
|||||||
tool_name=event_name,
|
tool_name=event_name,
|
||||||
step=f"正在执行 {event_name}",
|
step=f"正在执行 {event_name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
elif kind == "on_tool_end":
|
elif kind == "on_tool_end":
|
||||||
tool_result = data.get("output")
|
tool_result = data.get("output")
|
||||||
step = f"已完成 {event_name}"
|
step = f"已完成 {event_name}"
|
||||||
@@ -400,14 +640,16 @@ class AgentService:
|
|||||||
tool_name=event_name,
|
tool_name=event_name,
|
||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif kind == "on_chat_model_stream":
|
elif kind == "on_chat_model_stream":
|
||||||
chunk = data.get("chunk")
|
chunk = data.get("chunk")
|
||||||
content = _coerce_event_text(getattr(chunk, "content", "") if chunk else "")
|
content = _coerce_event_text(
|
||||||
|
getattr(chunk, "content", "") if chunk else ""
|
||||||
|
)
|
||||||
if content:
|
if content:
|
||||||
collected += content
|
collected += content
|
||||||
yield {"type": "chunk", "content": content}
|
yield {"type": "chunk", "content": content}
|
||||||
|
|
||||||
elif kind == "on_chain_end":
|
elif kind == "on_chain_end":
|
||||||
output = data.get("output")
|
output = data.get("output")
|
||||||
final_resp = None
|
final_resp = None
|
||||||
@@ -422,7 +664,9 @@ class AgentService:
|
|||||||
|
|
||||||
elif kind == "on_chat_model_end":
|
elif kind == "on_chat_model_end":
|
||||||
output = data.get("output")
|
output = data.get("output")
|
||||||
final_content = _coerce_event_text(getattr(output, "content", "") if output else "")
|
final_content = _coerce_event_text(
|
||||||
|
getattr(output, "content", "") if output else ""
|
||||||
|
)
|
||||||
if final_content:
|
if final_content:
|
||||||
final_text = final_content
|
final_text = final_content
|
||||||
if final_text != collected:
|
if final_text != collected:
|
||||||
@@ -431,12 +675,16 @@ class AgentService:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
|
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
|
||||||
yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
|
yield self._build_progress_event(
|
||||||
|
"responding", "Jarvis 正在生成回复", agent="master", step="fallback"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
result_state = await graph.ainvoke(state)
|
result_state = await graph.ainvoke(state)
|
||||||
if isinstance(result_state, dict):
|
if isinstance(result_state, dict):
|
||||||
state.update(result_state)
|
state.update(result_state)
|
||||||
fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
fallback_content = result_state.get("final_response") or str(
|
||||||
|
result_state.get("messages", [AIMessage(content="")])[-1].content
|
||||||
|
)
|
||||||
collected = str(fallback_content)
|
collected = str(fallback_content)
|
||||||
yield {"type": "chunk", "content": collected}
|
yield {"type": "chunk", "content": collected}
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -460,11 +708,24 @@ class AgentService:
|
|||||||
if collected:
|
if collected:
|
||||||
assistant_msg.content = collected
|
assistant_msg.content = collected
|
||||||
continuity_snapshot = _build_continuity_snapshot(state or {})
|
continuity_snapshot = _build_continuity_snapshot(state or {})
|
||||||
assistant_msg.attachments = ([{
|
assistant_msg.attachments = (
|
||||||
"kind": "agent_continuity_state",
|
[
|
||||||
**continuity_snapshot,
|
{
|
||||||
}] if continuity_snapshot else None)
|
"kind": "agent_continuity_state",
|
||||||
conv.agent_state = continuity_snapshot
|
**continuity_snapshot,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
if continuity_snapshot
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
conv.agent_state = (
|
||||||
|
{
|
||||||
|
"kind": "agent_continuity_state",
|
||||||
|
**continuity_snapshot,
|
||||||
|
}
|
||||||
|
if continuity_snapshot
|
||||||
|
else None
|
||||||
|
)
|
||||||
await BrainService(self.db).create_event(
|
await BrainService(self.db).create_event(
|
||||||
user_id,
|
user_id,
|
||||||
**_build_assistant_event_payload(collected),
|
**_build_assistant_event_payload(collected),
|
||||||
@@ -542,12 +803,16 @@ class AgentService:
|
|||||||
importance_signal=1.0,
|
importance_signal=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)
|
memory_ctx = await memory_service.build_memory_context(
|
||||||
|
self.db, user_id, conversation_id, message
|
||||||
|
)
|
||||||
|
|
||||||
set_current_user(user_id)
|
set_current_user(user_id)
|
||||||
try:
|
try:
|
||||||
graph = get_agent_graph()
|
graph = get_agent_graph()
|
||||||
current_datetime_context, current_datetime_reference = self._build_current_datetime_context()
|
current_datetime_context, current_datetime_reference = (
|
||||||
|
self._build_current_datetime_context()
|
||||||
|
)
|
||||||
state = await self._build_agent_state(
|
state = await self._build_agent_state(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
conversation=conv,
|
conversation=conv,
|
||||||
@@ -557,9 +822,11 @@ class AgentService:
|
|||||||
current_datetime_reference=current_datetime_reference,
|
current_datetime_reference=current_datetime_reference,
|
||||||
user_llm_config=user_llm_config,
|
user_llm_config=user_llm_config,
|
||||||
)
|
)
|
||||||
|
state.update(_derive_role_memory_contexts(memory_ctx))
|
||||||
result_state = await graph.ainvoke(state)
|
result_state = await graph.ainvoke(state)
|
||||||
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
response_content = result_state.get("final_response") or str(
|
||||||
|
result_state.get("messages", [AIMessage(content="")])[-1].content
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("agent_chat_simple_failed")
|
logger.exception("agent_chat_simple_failed")
|
||||||
response_content = "抱歉,发生错误。"
|
response_content = "抱歉,发生错误。"
|
||||||
@@ -580,12 +847,27 @@ class AgentService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assistant_msg.content = response_content
|
assistant_msg.content = response_content
|
||||||
continuity_snapshot = _build_continuity_snapshot(result_state) if 'result_state' in locals() else None
|
continuity_snapshot = (
|
||||||
assistant_msg.attachments = ([{
|
_build_continuity_snapshot(result_state) if "result_state" in locals() else None
|
||||||
"kind": "agent_continuity_state",
|
)
|
||||||
**continuity_snapshot,
|
assistant_msg.attachments = (
|
||||||
}] if continuity_snapshot else None)
|
[
|
||||||
conv.agent_state = continuity_snapshot
|
{
|
||||||
|
"kind": "agent_continuity_state",
|
||||||
|
**continuity_snapshot,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
if continuity_snapshot
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
conv.agent_state = (
|
||||||
|
{
|
||||||
|
"kind": "agent_continuity_state",
|
||||||
|
**continuity_snapshot,
|
||||||
|
}
|
||||||
|
if continuity_snapshot
|
||||||
|
else None
|
||||||
|
)
|
||||||
await self.db.commit()
|
await self.db.commit()
|
||||||
await self.db.refresh(assistant_msg)
|
await self.db.refresh(assistant_msg)
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ Jarvis 记忆系统 (基于 Mem0)
|
|||||||
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
import re
|
||||||
|
from datetime import UTC, datetime
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
from sqlalchemy import select, desc, func
|
from sqlalchemy import select, desc, func
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from app.models.conversation import Conversation, Message
|
from app.models.conversation import Conversation, Message
|
||||||
|
from app.models.memory import UserMemory
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.brain_service import BrainService
|
from app.services.brain_service import BrainService
|
||||||
from app.config import settings as _settings
|
from app.config import settings as _settings
|
||||||
@@ -23,6 +26,9 @@ except ImportError:
|
|||||||
Memory = None
|
Memory = None
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
|
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||||
"""从用户配置中获取 embedding 模型配置"""
|
"""从用户配置中获取 embedding 模型配置"""
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
@@ -296,6 +302,23 @@ async def extract_user_memories(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_memory_query_tokens(query: str) -> list[str]:
|
||||||
|
normalized_query = (query or "").lower()
|
||||||
|
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
|
||||||
|
|
||||||
|
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
|
||||||
|
stripped_chunk = chunk.strip()
|
||||||
|
if len(stripped_chunk) >= 4:
|
||||||
|
tokens.append(stripped_chunk)
|
||||||
|
if len(stripped_chunk) > 6:
|
||||||
|
tokens.extend(
|
||||||
|
stripped_chunk[index:index + 4]
|
||||||
|
for index in range(len(stripped_chunk) - 3)
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(dict.fromkeys(tokens))
|
||||||
|
|
||||||
|
|
||||||
async def recall_user_memories(
|
async def recall_user_memories(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@@ -304,7 +327,7 @@ async def recall_user_memories(
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
根据当前输入召回相关的用户记忆。
|
根据当前输入召回相关的用户记忆。
|
||||||
使用 Mem0 的语义搜索。
|
使用 Mem0 的语义搜索;如果 Mem0 不可用或失败,则回退到本地 UserMemory。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
mem0 = await get_mem0(db, user_id)
|
mem0 = await get_mem0(db, user_id)
|
||||||
@@ -313,10 +336,56 @@ async def recall_user_memories(
|
|||||||
filters={"user_id": user_id},
|
filters={"user_id": user_id},
|
||||||
limit=top_k,
|
limit=top_k,
|
||||||
)
|
)
|
||||||
return results.get("results", [])
|
mem0_results = results.get("results", [])
|
||||||
|
if mem0_results:
|
||||||
|
return mem0_results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Mem0 search error: {e}")
|
print(f"Mem0 search error: {e}")
|
||||||
return []
|
|
||||||
|
query_tokens = _extract_memory_query_tokens(query)
|
||||||
|
statement = select(UserMemory).where(UserMemory.user_id == user_id)
|
||||||
|
result = await db.execute(statement.order_by(UserMemory.importance.desc(), UserMemory.created_at.desc()))
|
||||||
|
fallback_memories = list(result.scalars().all())
|
||||||
|
|
||||||
|
if _contains_hint(_normalize_query(query), MEMORY_QUERY_HINTS) or _matches_memory_query_pattern(_normalize_query(query)):
|
||||||
|
return fallback_memories[:top_k]
|
||||||
|
|
||||||
|
if query_tokens:
|
||||||
|
matched_memories = [
|
||||||
|
memory for memory in fallback_memories
|
||||||
|
if any(token in (memory.content or '').lower() for token in query_tokens)
|
||||||
|
]
|
||||||
|
return matched_memories[:top_k]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
|
||||||
|
recalled_at = datetime.now(UTC)
|
||||||
|
updated = False
|
||||||
|
for memory in memories:
|
||||||
|
memory.is_recalled = True
|
||||||
|
memory.recall_count = (memory.recall_count or 0) + 1
|
||||||
|
memory.last_recalled_at = recalled_at
|
||||||
|
updated = True
|
||||||
|
if updated:
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_tolerated_section(
|
||||||
|
db: AsyncSession,
|
||||||
|
section_name: str,
|
||||||
|
builder,
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
return await builder()
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"[MemoryService] %s失败,继续构建剩余上下文",
|
||||||
|
section_name,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||||
@@ -339,6 +408,131 @@ async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
|||||||
|
|
||||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||||
|
|
||||||
|
MEMORY_QUERY_HINTS = (
|
||||||
|
"记住",
|
||||||
|
"记下",
|
||||||
|
"记一下",
|
||||||
|
"记着",
|
||||||
|
"提醒",
|
||||||
|
"偏好",
|
||||||
|
"习惯",
|
||||||
|
)
|
||||||
|
MEMORY_QUERY_PATTERNS = (
|
||||||
|
re.compile(r"\bremember\s+(?:that\s+)?i\b"),
|
||||||
|
)
|
||||||
|
GROUNDING_QUERY_HINTS = (
|
||||||
|
"根据文档",
|
||||||
|
"严格根据",
|
||||||
|
"只根据",
|
||||||
|
"文档内容",
|
||||||
|
"grounded",
|
||||||
|
"strictly based on",
|
||||||
|
"based on the document",
|
||||||
|
"based on the docs",
|
||||||
|
"document only",
|
||||||
|
"docs only",
|
||||||
|
"only use the document",
|
||||||
|
"only use the docs",
|
||||||
|
)
|
||||||
|
AVOID_USER_MEMORY_HINTS = (
|
||||||
|
"不要结合我的个人偏好",
|
||||||
|
"不要结合个人偏好",
|
||||||
|
"不要结合偏好",
|
||||||
|
"不要结合我的记忆",
|
||||||
|
"不要结合记忆",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_query(text: str) -> str:
|
||||||
|
return text.strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
|
||||||
|
return any(hint in text for hint in hints)
|
||||||
|
|
||||||
|
|
||||||
|
def _matches_memory_query_pattern(text: str) -> bool:
|
||||||
|
return any(pattern.search(text) for pattern in MEMORY_QUERY_PATTERNS)
|
||||||
|
|
||||||
|
|
||||||
|
def _should_include_user_memories(query: str) -> bool:
|
||||||
|
normalized_query = _normalize_query(query)
|
||||||
|
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
|
||||||
|
return False
|
||||||
|
if _contains_hint(normalized_query, AVOID_USER_MEMORY_HINTS):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _should_include_summaries(query: str) -> bool:
|
||||||
|
normalized_query = _normalize_query(query)
|
||||||
|
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
|
||||||
|
return False
|
||||||
|
if _contains_hint(normalized_query, MEMORY_QUERY_HINTS):
|
||||||
|
return False
|
||||||
|
if _matches_memory_query_pattern(normalized_query):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_user_memory_section(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
current_query: str,
|
||||||
|
) -> str:
|
||||||
|
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||||
|
if not memories:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
recalled_user_memories: list[UserMemory] = []
|
||||||
|
for memory in memories:
|
||||||
|
if isinstance(memory, UserMemory):
|
||||||
|
memory_text = memory.content
|
||||||
|
memory_type = memory.memory_type
|
||||||
|
recalled_user_memories.append(memory)
|
||||||
|
else:
|
||||||
|
memory_text = memory.get("memory", memory.get("text", ""))
|
||||||
|
memory_type = memory.get("memory_type")
|
||||||
|
|
||||||
|
if not memory_text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if memory_type:
|
||||||
|
lines.append(f" [{memory_type}] {memory_text}")
|
||||||
|
else:
|
||||||
|
lines.append(f" - {memory_text}")
|
||||||
|
|
||||||
|
if not lines:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if recalled_user_memories:
|
||||||
|
await _mark_memories_recalled(db, recalled_user_memories)
|
||||||
|
return "【用户记忆】\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
|
||||||
|
summaries = await get_summaries(db, conversation_id)
|
||||||
|
if not summaries:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
recent = summaries[-2:]
|
||||||
|
lines = [f"[对话摘要{i + 1}] {summary.summary_text}" for i, summary in enumerate(recent)]
|
||||||
|
return "【之前对话摘要】\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_brain_section(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
current_query: str,
|
||||||
|
) -> str:
|
||||||
|
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||||
|
if not brain_memories:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = [f"- {memory.title}: {memory.content}" for memory in brain_memories]
|
||||||
|
return "【知识大脑】\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
async def build_memory_context(
|
async def build_memory_context(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
@@ -350,30 +544,33 @@ async def build_memory_context(
|
|||||||
构建完整的记忆上下文字符串,
|
构建完整的记忆上下文字符串,
|
||||||
供注入到 Agent system prompt 中使用。
|
供注入到 Agent system prompt 中使用。
|
||||||
"""
|
"""
|
||||||
parts = []
|
parts: list[str] = []
|
||||||
|
|
||||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
if _should_include_user_memories(current_query):
|
||||||
if memories:
|
user_memory_section = await _run_tolerated_section(
|
||||||
lines = []
|
db,
|
||||||
for m in memories:
|
"用户记忆召回",
|
||||||
memory_text = m.get("memory", m.get("text", ""))
|
lambda: _build_user_memory_section(db, user_id, current_query),
|
||||||
if memory_text:
|
)
|
||||||
lines.append(f" - {memory_text}")
|
if user_memory_section:
|
||||||
if lines:
|
parts.append(user_memory_section)
|
||||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
|
||||||
|
|
||||||
summaries = await get_summaries(db, conversation_id)
|
if _should_include_summaries(current_query):
|
||||||
if summaries:
|
summary_section = await _run_tolerated_section(
|
||||||
recent = summaries[-2:]
|
db,
|
||||||
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
|
"对话摘要加载",
|
||||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
lambda: _build_summary_section(db, conversation_id),
|
||||||
|
)
|
||||||
|
if summary_section:
|
||||||
|
parts.append(summary_section)
|
||||||
|
|
||||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
brain_section = await _run_tolerated_section(
|
||||||
if brain_memories:
|
db,
|
||||||
lines = []
|
"知识大脑召回",
|
||||||
for memory in brain_memories:
|
lambda: _build_brain_section(db, user_id, current_query),
|
||||||
lines.append(f"- {memory.title}: {memory.content}")
|
)
|
||||||
parts.append("【知识大脑】\n" + "\n".join(lines))
|
if brain_section:
|
||||||
|
parts.append(brain_section)
|
||||||
|
|
||||||
if not parts:
|
if not parts:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
167
backend/tests/backend/app/agents/test_agent_schemas.py
Normal file
167
backend/tests/backend/app/agents/test_agent_schemas.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
from app.agents.schemas.event import AgentEvent
|
||||||
|
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_task_accepts_day1_fields():
|
||||||
|
task = AgentTask(
|
||||||
|
task_id="task-1",
|
||||||
|
title="Verify foundation",
|
||||||
|
status="in_progress",
|
||||||
|
owner_agent_id="executor",
|
||||||
|
role="verifier",
|
||||||
|
goal="check output",
|
||||||
|
expected_evidence=[{"type": "assertion"}],
|
||||||
|
evidence=[{"type": "log"}],
|
||||||
|
result_summary="running",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.task_id == "task-1"
|
||||||
|
assert task.owner_agent_id == "executor"
|
||||||
|
assert task.status == "in_progress"
|
||||||
|
assert task.expected_evidence == [{"type": "assertion"}]
|
||||||
|
assert task.evidence == [{"type": "log"}]
|
||||||
|
assert task.result_summary == "running"
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_task_accepts_day3_runtime_fields():
|
||||||
|
task = AgentTask(
|
||||||
|
task_id="task-2",
|
||||||
|
title="Recover interrupted collaboration",
|
||||||
|
owner_agent_id="executor",
|
||||||
|
parent_task_id="task-1",
|
||||||
|
child_task_ids=["task-2a"],
|
||||||
|
thread_id="thread-1",
|
||||||
|
message_id="msg-1",
|
||||||
|
message_index=2,
|
||||||
|
interrupt_records=[
|
||||||
|
InterruptRecord(
|
||||||
|
interrupt_id="interrupt-1",
|
||||||
|
reason="manual stop",
|
||||||
|
requested_by="coordinator",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
recovery_records=[
|
||||||
|
RecoveryRecord(
|
||||||
|
recovery_id="recovery-1",
|
||||||
|
source_interrupt_id="interrupt-1",
|
||||||
|
resumed_from_task_id="task-2",
|
||||||
|
resumed_from_thread_id="thread-1",
|
||||||
|
strategy="resume_from_checkpoint",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
collaboration_budget=CollaborationBudget(
|
||||||
|
mode="collaboration",
|
||||||
|
max_parallel_tasks=2,
|
||||||
|
remaining_parallel_tasks=1,
|
||||||
|
max_tool_calls=4,
|
||||||
|
remaining_tool_calls=3,
|
||||||
|
max_iterations=5,
|
||||||
|
remaining_iterations=4,
|
||||||
|
escalation_threshold=1,
|
||||||
|
metadata={"max_spawn_depth": 2},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.parent_task_id == "task-1"
|
||||||
|
assert task.child_task_ids == ["task-2a"]
|
||||||
|
assert task.thread_id == "thread-1"
|
||||||
|
assert task.message_id == "msg-1"
|
||||||
|
assert task.message_index == 2
|
||||||
|
assert task.interrupt_records[0].interrupt_id == "interrupt-1"
|
||||||
|
assert task.recovery_records[0].recovery_id == "recovery-1"
|
||||||
|
assert task.collaboration_budget.mode == "collaboration"
|
||||||
|
assert task.collaboration_budget.metadata == {"max_spawn_depth": 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_event_accepts_day1_fields():
|
||||||
|
event = AgentEvent(
|
||||||
|
event_id="evt-1",
|
||||||
|
event_type="agent.verify.completed",
|
||||||
|
conversation_id="conv-1",
|
||||||
|
agent_id="executor",
|
||||||
|
sub_commander_id="executor_tasks",
|
||||||
|
task_id="task-1",
|
||||||
|
payload={"status": "passed"},
|
||||||
|
severity="info",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_id == "evt-1"
|
||||||
|
assert event.event_type == "agent.verify.completed"
|
||||||
|
assert event.conversation_id == "conv-1"
|
||||||
|
assert event.payload == {"status": "passed"}
|
||||||
|
assert event.severity == "info"
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_event_accepts_day3_trace_fields():
|
||||||
|
event = AgentEvent(
|
||||||
|
event_id="evt-2",
|
||||||
|
event_type="agent.collaboration.budget.updated",
|
||||||
|
conversation_id="conv-1",
|
||||||
|
agent_id="coordinator",
|
||||||
|
task_id="task-2",
|
||||||
|
parent_task_id="task-1",
|
||||||
|
child_task_id="task-2a",
|
||||||
|
thread_id="thread-1",
|
||||||
|
message_id="msg-3",
|
||||||
|
interrupt_id="interrupt-1",
|
||||||
|
recovery_id="recovery-1",
|
||||||
|
payload={"remaining_parallel_tasks": 1},
|
||||||
|
severity="warning",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.parent_task_id == "task-1"
|
||||||
|
assert event.child_task_id == "task-2a"
|
||||||
|
assert event.thread_id == "thread-1"
|
||||||
|
assert event.message_id == "msg-3"
|
||||||
|
assert event.interrupt_id == "interrupt-1"
|
||||||
|
assert event.recovery_id == "recovery-1"
|
||||||
|
assert event.severity == "warning"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_result_supports_collaboration_result_fields():
|
||||||
|
result = TaskResult(
|
||||||
|
task_id="task-1",
|
||||||
|
status="completed",
|
||||||
|
summary="retrieval finished",
|
||||||
|
evidence=[{"type": "source"}],
|
||||||
|
owner_agent_id="librarian",
|
||||||
|
next_action="handoff_to_analyst",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert result.owner_agent_id == "librarian"
|
||||||
|
assert result.next_action == "handoff_to_analyst"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_result_supports_day3_thread_budget_and_recovery_fields():
|
||||||
|
result = TaskResult(
|
||||||
|
task_id="task-2",
|
||||||
|
status="blocked",
|
||||||
|
owner_agent_id="executor",
|
||||||
|
parent_task_id="task-1",
|
||||||
|
child_task_ids=["task-2a"],
|
||||||
|
thread_id="thread-1",
|
||||||
|
message_id="msg-4",
|
||||||
|
message_index=4,
|
||||||
|
interrupt_records=[{"interrupt_id": "interrupt-1", "reason": "budget exceeded"}],
|
||||||
|
recovery_records=[{"recovery_id": "recovery-1", "strategy": "resume_after_budget_reset"}],
|
||||||
|
budget_snapshot=CollaborationBudget(
|
||||||
|
mode="collaboration",
|
||||||
|
max_parallel_tasks=2,
|
||||||
|
remaining_parallel_tasks=0,
|
||||||
|
max_tool_calls=4,
|
||||||
|
remaining_tool_calls=0,
|
||||||
|
),
|
||||||
|
next_action="resume_after_budget_reset",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.parent_task_id == "task-1"
|
||||||
|
assert result.child_task_ids == ["task-2a"]
|
||||||
|
assert result.thread_id == "thread-1"
|
||||||
|
assert result.message_id == "msg-4"
|
||||||
|
assert result.message_index == 4
|
||||||
|
assert result.interrupt_records[0].interrupt_id == "interrupt-1"
|
||||||
|
assert result.recovery_records[0].recovery_id == "recovery-1"
|
||||||
|
assert result.budget_snapshot.mode == "collaboration"
|
||||||
|
assert result.budget_snapshot.remaining_parallel_tasks == 0
|
||||||
|
assert result.next_action == "resume_after_budget_reset"
|
||||||
File diff suppressed because it is too large
Load Diff
317
backend/tests/backend/app/agents/test_graph_system_messages.py
Normal file
317
backend/tests/backend/app/agents/test_graph_system_messages.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
import sys
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
sys.modules.setdefault("trafilatura", Mock())
|
||||||
|
|
||||||
|
from app.agents.graph import _build_system_messages, _run_sub_commander
|
||||||
|
from app.agents.state import AgentRole
|
||||||
|
|
||||||
|
|
||||||
|
def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
|
||||||
|
return {
|
||||||
|
"messages": [HumanMessage(content=message)],
|
||||||
|
"user_id": "u1",
|
||||||
|
"conversation_id": "c1",
|
||||||
|
"current_agent": AgentRole.MASTER,
|
||||||
|
"active_agents": [AgentRole.MASTER],
|
||||||
|
"current_sub_commander": None,
|
||||||
|
"active_sub_commanders": [],
|
||||||
|
"sub_commander_trace": [],
|
||||||
|
"pending_tasks": [],
|
||||||
|
"completed_tasks": [],
|
||||||
|
"tool_calls": [],
|
||||||
|
"last_tool_result": None,
|
||||||
|
"action_results": [],
|
||||||
|
"created_entities": [],
|
||||||
|
"tool_strategy_used": None,
|
||||||
|
"provider_capabilities": None,
|
||||||
|
"fallback_parse_error": None,
|
||||||
|
"knowledge_context": None,
|
||||||
|
"graph_context": None,
|
||||||
|
"schedule_context_summary": None,
|
||||||
|
"plan": None,
|
||||||
|
"plan_steps": [],
|
||||||
|
"analysis_report": None,
|
||||||
|
"final_response": None,
|
||||||
|
"should_respond": True,
|
||||||
|
"memory_context": "memory context",
|
||||||
|
"current_datetime_context": "CURRENT_TIME: 2026-03-28T12:00:00+08:00",
|
||||||
|
"current_datetime_reference": {
|
||||||
|
"current_time_iso": "2026-03-28T12:00:00+08:00",
|
||||||
|
"current_date_iso": "2026-03-28",
|
||||||
|
"timezone": "UTC",
|
||||||
|
},
|
||||||
|
"user_llm_config": user_llm_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTool:
|
||||||
|
def __init__(self, name: str, result: str):
|
||||||
|
self.name = name
|
||||||
|
self.result = result
|
||||||
|
self.invocations: list[dict] = []
|
||||||
|
|
||||||
|
def invoke(self, args: dict):
|
||||||
|
self.invocations.append(args)
|
||||||
|
return self.result
|
||||||
|
|
||||||
|
|
||||||
|
class SingleSystemMessageLLM:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = 0
|
||||||
|
self.system_message_counts: list[int] = []
|
||||||
|
self._jarvis_provider_capabilities = SimpleNamespace(
|
||||||
|
provider="minimax",
|
||||||
|
supports_native_tools=False,
|
||||||
|
preferred_tool_strategy="json_fallback",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ainvoke(self, messages):
|
||||||
|
self.calls += 1
|
||||||
|
self.system_message_counts.append(
|
||||||
|
sum(1 for message in messages if getattr(message, "type", None) == "system")
|
||||||
|
)
|
||||||
|
if self.system_message_counts[-1] != 1:
|
||||||
|
raise AssertionError(
|
||||||
|
f"expected exactly one system message, got {self.system_message_counts[-1]}"
|
||||||
|
)
|
||||||
|
if self.calls == 1:
|
||||||
|
return AIMessage(
|
||||||
|
content=(
|
||||||
|
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder",'
|
||||||
|
'"arguments":{"title":"blanket","reminder_at":"\\u660e\\u5929 09:00"}}]}'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return AIMessage(content="created reminder for blanket")
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_messages_includes_structured_continuity_summary():
|
||||||
|
state = _base_state("创建")
|
||||||
|
state["pending_action"] = {
|
||||||
|
"type": "schedule_creation",
|
||||||
|
"summary": "为周报安排明天下午提醒",
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
|
state["routing_decision"] = {
|
||||||
|
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
|
"reason": "continue_pending_action",
|
||||||
|
}
|
||||||
|
state["continuity_state"] = {"status": "fresh"}
|
||||||
|
|
||||||
|
messages = _build_system_messages(
|
||||||
|
state,
|
||||||
|
"manager prompt",
|
||||||
|
AgentRole.SCHEDULE_PLANNER,
|
||||||
|
"schedule_planning",
|
||||||
|
)
|
||||||
|
|
||||||
|
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||||
|
assert "pending_action" in system_text
|
||||||
|
assert "schedule_creation" in system_text
|
||||||
|
assert "continue_pending_action" in system_text
|
||||||
|
assert "为周报安排明天下午提醒" in system_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_messages_skips_structured_continuity_when_pending_action_is_not_pending():
|
||||||
|
state = _base_state("创建")
|
||||||
|
state["pending_action"] = {
|
||||||
|
"type": "schedule_creation",
|
||||||
|
"summary": "为周报安排明天下午提醒",
|
||||||
|
"status": "completed",
|
||||||
|
}
|
||||||
|
state["routing_decision"] = {
|
||||||
|
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
|
"reason": "continue_pending_action",
|
||||||
|
}
|
||||||
|
state["continuity_state"] = {"status": "fresh"}
|
||||||
|
|
||||||
|
messages = _build_system_messages(
|
||||||
|
state,
|
||||||
|
"manager prompt",
|
||||||
|
AgentRole.SCHEDULE_PLANNER,
|
||||||
|
"schedule_planning",
|
||||||
|
)
|
||||||
|
|
||||||
|
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||||
|
assert "structured_continuity" not in system_text
|
||||||
|
assert "continue_pending_action" not in system_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_messages_skips_structured_continuity_when_routing_reason_is_not_continuation():
|
||||||
|
state = _base_state("创建")
|
||||||
|
state["pending_action"] = {
|
||||||
|
"type": "schedule_creation",
|
||||||
|
"summary": "为周报安排明天下午提醒",
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
|
state["routing_decision"] = {
|
||||||
|
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
|
"reason": "initial_schedule_detection",
|
||||||
|
}
|
||||||
|
state["continuity_state"] = {"status": "fresh"}
|
||||||
|
|
||||||
|
messages = _build_system_messages(
|
||||||
|
state,
|
||||||
|
"manager prompt",
|
||||||
|
AgentRole.SCHEDULE_PLANNER,
|
||||||
|
"schedule_planning",
|
||||||
|
)
|
||||||
|
|
||||||
|
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||||
|
assert "structured_continuity" not in system_text
|
||||||
|
assert "continue_pending_action" not in system_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_messages_skips_structured_continuity_when_routing_decision_missing():
|
||||||
|
state = _base_state("创建")
|
||||||
|
state["pending_action"] = {
|
||||||
|
"type": "schedule_creation",
|
||||||
|
"summary": "为周报安排明天下午提醒",
|
||||||
|
}
|
||||||
|
state["routing_decision"] = None
|
||||||
|
|
||||||
|
messages = _build_system_messages(
|
||||||
|
state,
|
||||||
|
"manager prompt",
|
||||||
|
AgentRole.SCHEDULE_PLANNER,
|
||||||
|
"schedule_planning",
|
||||||
|
)
|
||||||
|
|
||||||
|
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||||
|
assert "pending_action" not in system_text
|
||||||
|
assert "schedule_creation" not in system_text
|
||||||
|
assert "为周报安排明天下午提醒" not in system_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_messages_skips_stale_structured_continuity_for_unrelated_new_request():
|
||||||
|
state = _base_state(
|
||||||
|
"帮我搜索 Rust 异步 trait 最佳实践",
|
||||||
|
{
|
||||||
|
"provider": "openai",
|
||||||
|
"model": "MiniMax-M2.7-highspeed",
|
||||||
|
"base_url": "https://api.minimaxi.com/v1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||||
|
state["pending_action"] = {
|
||||||
|
"type": "schedule_creation",
|
||||||
|
"summary": "为周报安排明天下午提醒",
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
|
state["routing_decision"] = {
|
||||||
|
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
|
"reason": "continue_pending_action",
|
||||||
|
}
|
||||||
|
state["continuity_state"] = {
|
||||||
|
"status": "stale",
|
||||||
|
"override_reason": "new_explicit_request",
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = _build_system_messages(
|
||||||
|
state,
|
||||||
|
"manager prompt",
|
||||||
|
AgentRole.SCHEDULE_PLANNER,
|
||||||
|
"schedule_planning",
|
||||||
|
)
|
||||||
|
|
||||||
|
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||||
|
assert "structured_continuity" not in system_text
|
||||||
|
assert "pending_action" not in system_text
|
||||||
|
assert "continue_pending_action" not in system_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_messages_uses_role_scoped_context_instead_of_raw_memory_blob():
|
||||||
|
state = _base_state("帮我搜索 Rust 异步 trait 最佳实践")
|
||||||
|
state["memory_context"] = "【用户记忆】\n- 用户喜欢燕麦拿铁。\n\n【之前对话摘要】\n[对话摘要1] 之前聊过提醒。\n\n【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||||
|
state["schedule_context_summary"] = "【用户记忆】\n- 用户喜欢燕麦拿铁。\n\n【之前对话摘要】\n[对话摘要1] 之前聊过提醒。"
|
||||||
|
state["knowledge_context"] = "【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||||
|
state["analysis_report"] = "【之前对话摘要】\n[对话摘要1] 之前聊过提醒。\n\n【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||||
|
|
||||||
|
messages = _build_system_messages(
|
||||||
|
state,
|
||||||
|
"manager prompt",
|
||||||
|
AgentRole.LIBRARIAN,
|
||||||
|
"librarian_retrieval",
|
||||||
|
)
|
||||||
|
|
||||||
|
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||||
|
assert "角色上下文" in system_text
|
||||||
|
assert "【知识大脑】" in system_text
|
||||||
|
assert "Rust Async" in system_text
|
||||||
|
assert "用户喜欢燕麦拿铁" not in system_text
|
||||||
|
assert "记忆上下文" not in system_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_messages_keeps_fresh_structured_continuity_for_matching_followup():
|
||||||
|
state = _base_state(
|
||||||
|
"创建",
|
||||||
|
{
|
||||||
|
"provider": "openai",
|
||||||
|
"model": "MiniMax-M2.7-highspeed",
|
||||||
|
"base_url": "https://api.minimaxi.com/v1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||||
|
state["pending_action"] = {
|
||||||
|
"type": "schedule_creation",
|
||||||
|
"summary": "为周报安排明天下午提醒",
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
|
state["routing_decision"] = {
|
||||||
|
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||||
|
"reason": "continue_pending_action",
|
||||||
|
}
|
||||||
|
state["continuity_state"] = {
|
||||||
|
"status": "fresh",
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = _build_system_messages(
|
||||||
|
state,
|
||||||
|
"manager prompt",
|
||||||
|
AgentRole.SCHEDULE_PLANNER,
|
||||||
|
"schedule_planning",
|
||||||
|
)
|
||||||
|
|
||||||
|
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||||
|
assert "pending_action" in system_text
|
||||||
|
assert "continue_pending_action" in system_text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_run_sub_commander_coalesces_system_messages_for_openai_compatible_provider(
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
fake_llm = SingleSystemMessageLLM()
|
||||||
|
fake_tool = FakeTool("create_reminder", "created reminder: blanket @ tomorrow 09:00")
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.agents.graph._get_llm_for_state", lambda state: fake_llm)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
__import__("app.agents.graph", fromlist=["SUB_COMMANDER_TOOLSETS"]).SUB_COMMANDER_TOOLSETS,
|
||||||
|
"schedule_planning",
|
||||||
|
[fake_tool],
|
||||||
|
)
|
||||||
|
|
||||||
|
state = _base_state(
|
||||||
|
"给我设置明天的提醒,提醒我收被子",
|
||||||
|
{
|
||||||
|
"provider": "openai",
|
||||||
|
"model": "MiniMax-M2.7-highspeed",
|
||||||
|
"base_url": "https://api.minimaxi.com/v1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||||
|
|
||||||
|
result = await _run_sub_commander(
|
||||||
|
state,
|
||||||
|
AgentRole.SCHEDULE_PLANNER,
|
||||||
|
"manager prompt",
|
||||||
|
"给我设置明天的提醒,提醒我收被子",
|
||||||
|
use_tools=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fake_llm.system_message_counts == [1, 1]
|
||||||
|
assert result["tool_strategy_used"] == "json_fallback"
|
||||||
|
assert fake_tool.invocations == [{"title": "blanket", "reminder_at": "2026-03-29T09:00:00"}]
|
||||||
|
assert result["final_response"] == "created reminder for blanket"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT
|
from app.agents.prompts import COORDINATOR_SYSTEM_PROMPT, MASTER_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
def test_master_prompt_forbids_subagent_rollcall_in_simple_greetings():
|
def test_master_prompt_forbids_subagent_rollcall_in_simple_greetings():
|
||||||
@@ -10,3 +10,10 @@ def test_master_prompt_does_not_include_full_canned_answers_for_greetings_or_ide
|
|||||||
assert 'Jarvis:您好。我在。' not in MASTER_SYSTEM_PROMPT
|
assert 'Jarvis:您好。我在。' not in MASTER_SYSTEM_PROMPT
|
||||||
assert 'Jarvis:我是 Jarvis。' not in MASTER_SYSTEM_PROMPT
|
assert 'Jarvis:我是 Jarvis。' not in MASTER_SYSTEM_PROMPT
|
||||||
assert 'Jarvis:主要做三件事。' not in MASTER_SYSTEM_PROMPT
|
assert 'Jarvis:主要做三件事。' not in MASTER_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
def test_coordinator_prompt_limits_collaboration_scope():
|
||||||
|
assert "2~4 个子任务" in COORDINATOR_SYSTEM_PROMPT
|
||||||
|
assert "禁止无限递归拆分" in COORDINATOR_SYSTEM_PROMPT
|
||||||
|
assert "schedule_planner" in COORDINATOR_SYSTEM_PROMPT
|
||||||
|
assert "librarian" in COORDINATOR_SYSTEM_PROMPT
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ from app.agents.prompts import (
|
|||||||
SUB_COMMANDER_PROMPTS_BY_KEY,
|
SUB_COMMANDER_PROMPTS_BY_KEY,
|
||||||
TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY,
|
TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY,
|
||||||
)
|
)
|
||||||
from app.agents.registry import build_registry_indexes, load_builtin_registry_bundle
|
from app.agents.registry import build_registry_indexes, load_builtin_registry_bundle, load_builtin_registry_indexes
|
||||||
from app.agents.registry.indexes import summarize_registry_indexes
|
from app.agents.registry.indexes import summarize_registry_indexes
|
||||||
from app.agents.registry.models import (
|
from app.agents.registry.models import (
|
||||||
AgentManifest,
|
AgentManifest,
|
||||||
CapabilityManifest,
|
CapabilityManifest,
|
||||||
|
PermissionClass,
|
||||||
|
SideEffectScope,
|
||||||
SpecialistTemplateManifest,
|
SpecialistTemplateManifest,
|
||||||
SubCommanderManifest,
|
SubCommanderManifest,
|
||||||
)
|
)
|
||||||
@@ -251,17 +253,34 @@ def test_builtin_capabilities_reference_actual_runtime_tool_names() -> None:
|
|||||||
assert manifest_tool_names == expected_tool_names
|
assert manifest_tool_names == expected_tool_names
|
||||||
|
|
||||||
|
|
||||||
def test_builtin_sub_commander_capabilities_match_runtime_toolsets() -> None:
|
def test_builtin_capability_metadata_distinguishes_read_and_write_surfaces() -> None:
|
||||||
capabilities_by_tool_name = {
|
capability_by_id = {manifest.capability_id: manifest for manifest in BUILTIN_CAPABILITY_MANIFESTS}
|
||||||
manifest.tool_name: manifest.capability_id for manifest in BUILTIN_CAPABILITY_MANIFESTS
|
|
||||||
}
|
|
||||||
|
|
||||||
for sub_commander in BUILTIN_SUB_COMMANDER_MANIFESTS:
|
assert capability_by_id["get_tasks"].permission_class == PermissionClass.READ
|
||||||
expected_capability_ids = {
|
assert capability_by_id["get_tasks"].side_effect_scope == SideEffectScope.NONE
|
||||||
capabilities_by_tool_name[tool.name]
|
assert capability_by_id["get_tasks"].supports_retry is True
|
||||||
for tool in SUB_COMMANDER_TOOLSETS[sub_commander.sub_commander_id]
|
assert capability_by_id["get_tasks"].idempotent is True
|
||||||
}
|
assert capability_by_id["get_tasks"].safe_for_parallel_use is True
|
||||||
assert set(sub_commander.capability_ids) == expected_capability_ids
|
assert capability_by_id["get_tasks"].requires_confirmation is False
|
||||||
|
|
||||||
|
assert capability_by_id["create_reminder"].permission_class == PermissionClass.WRITE
|
||||||
|
assert capability_by_id["create_reminder"].side_effect_scope == SideEffectScope.LOCAL_STATE
|
||||||
|
assert capability_by_id["create_reminder"].supports_retry is False
|
||||||
|
assert capability_by_id["create_reminder"].idempotent is False
|
||||||
|
assert capability_by_id["create_reminder"].safe_for_parallel_use is False
|
||||||
|
assert capability_by_id["create_reminder"].requires_confirmation is True
|
||||||
|
|
||||||
|
assert capability_by_id["web_search"].permission_class == PermissionClass.EXTERNAL
|
||||||
|
assert capability_by_id["web_search"].side_effect_scope == SideEffectScope.NETWORK
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_builtin_registry_indexes_is_cached_and_matches_bundle_indexes() -> None:
|
||||||
|
cached = load_builtin_registry_indexes()
|
||||||
|
rebuilt = build_registry_indexes(load_builtin_registry_bundle())
|
||||||
|
|
||||||
|
assert cached is load_builtin_registry_indexes()
|
||||||
|
assert cached.capability_id_by_tool_name == rebuilt.capability_id_by_tool_name
|
||||||
|
assert cached.capability_by_id["create_reminder"].requires_confirmation is True
|
||||||
|
|
||||||
|
|
||||||
def test_builtin_manifests_form_a_valid_registry_bundle() -> None:
|
def test_builtin_manifests_form_a_valid_registry_bundle() -> None:
|
||||||
@@ -288,6 +307,7 @@ def test_build_registry_indexes_exposes_manifest_lookups_by_id() -> None:
|
|||||||
indexes = build_registry_indexes(bundle)
|
indexes = build_registry_indexes(bundle)
|
||||||
|
|
||||||
assert indexes.agent_by_id
|
assert indexes.agent_by_id
|
||||||
|
assert indexes.agent_by_role_value
|
||||||
assert indexes.sub_commander_by_id
|
assert indexes.sub_commander_by_id
|
||||||
assert indexes.capability_by_id
|
assert indexes.capability_by_id
|
||||||
assert isinstance(indexes.specialist_template_by_id, Mapping)
|
assert isinstance(indexes.specialist_template_by_id, Mapping)
|
||||||
@@ -343,6 +363,14 @@ def test_build_registry_indexes_exposes_prompt_keys_skill_context_keys_and_capab
|
|||||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||||
for sub_commander in bundle.sub_commanders
|
for sub_commander in bundle.sub_commanders
|
||||||
}
|
}
|
||||||
|
assert indexes.agent_by_role_value == {
|
||||||
|
agent.role_value: agent for agent in bundle.agents
|
||||||
|
}
|
||||||
|
assert indexes.spawnable_role_values_by_agent_id == {
|
||||||
|
agent.agent_id: tuple(agent.allowed_spawn_role_values)
|
||||||
|
for agent in bundle.agents
|
||||||
|
if agent.can_spawn_children and agent.allowed_spawn_role_values
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_validate_registry_bundle_accepts_loaded_builtin_registry_bundle() -> None:
|
def test_validate_registry_bundle_accepts_loaded_builtin_registry_bundle() -> None:
|
||||||
|
|||||||
135
backend/tests/backend/app/agents/test_schema_verifier.py
Normal file
135
backend/tests/backend/app/agents/test_schema_verifier.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
from app.agents.schemas import AgentEvent, AgentTask, TaskResult
|
||||||
|
from app.agents.schemas.task import CollaborationBudget, InterruptRecord, RecoveryRecord
|
||||||
|
from app.agents.state import initial_state
|
||||||
|
from app.agents.verifier import apply_verification_verdict, normalize_task_result, verify_task_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_task_supports_day3_interrupt_recovery_and_budget_fields():
|
||||||
|
interrupt = InterruptRecord(interrupt_id="interrupt-1", reason="user_cancel")
|
||||||
|
recovery = RecoveryRecord(recovery_id="recovery-1", source_interrupt_id="interrupt-1", resumed_from_task_id="task-1")
|
||||||
|
budget = CollaborationBudget(
|
||||||
|
mode="collaboration",
|
||||||
|
max_parallel_tasks=3,
|
||||||
|
remaining_parallel_tasks=2,
|
||||||
|
max_tool_calls=6,
|
||||||
|
remaining_tool_calls=4,
|
||||||
|
metadata={"phase": "day3"},
|
||||||
|
)
|
||||||
|
|
||||||
|
task = AgentTask(
|
||||||
|
task_id="task-1",
|
||||||
|
title="Recover interrupted collaboration task",
|
||||||
|
owner_agent_id="analyst",
|
||||||
|
role="analyst",
|
||||||
|
parent_task_id="parent-1",
|
||||||
|
child_task_ids=["child-1"],
|
||||||
|
thread_id="thread-1",
|
||||||
|
message_id="message-1",
|
||||||
|
message_index=3,
|
||||||
|
interrupt_records=[interrupt],
|
||||||
|
recovery_records=[recovery],
|
||||||
|
collaboration_budget=budget,
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = task.model_dump(mode="json")
|
||||||
|
|
||||||
|
assert payload["parent_task_id"] == "parent-1"
|
||||||
|
assert payload["child_task_ids"] == ["child-1"]
|
||||||
|
assert payload["thread_id"] == "thread-1"
|
||||||
|
assert payload["message_id"] == "message-1"
|
||||||
|
assert payload["message_index"] == 3
|
||||||
|
assert payload["interrupt_records"][0]["interrupt_id"] == "interrupt-1"
|
||||||
|
assert payload["recovery_records"][0]["recovery_id"] == "recovery-1"
|
||||||
|
assert payload["collaboration_budget"]["mode"] == "collaboration"
|
||||||
|
assert payload["collaboration_budget"]["remaining_tool_calls"] == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_event_supports_day3_thread_interrupt_and_recovery_metadata():
|
||||||
|
event = AgentEvent(
|
||||||
|
event_id="evt-1",
|
||||||
|
event_type="agent.task.recovered",
|
||||||
|
conversation_id="conv-1",
|
||||||
|
agent_id="executor",
|
||||||
|
task_id="task-1",
|
||||||
|
parent_task_id="parent-1",
|
||||||
|
child_task_id="child-1",
|
||||||
|
thread_id="thread-1",
|
||||||
|
message_id="message-1",
|
||||||
|
interrupt_id="interrupt-1",
|
||||||
|
recovery_id="recovery-1",
|
||||||
|
severity="warning",
|
||||||
|
payload={"status": "resumed"},
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = event.model_dump(mode="json")
|
||||||
|
|
||||||
|
assert payload["event_type"] == "agent.task.recovered"
|
||||||
|
assert payload["parent_task_id"] == "parent-1"
|
||||||
|
assert payload["child_task_id"] == "child-1"
|
||||||
|
assert payload["thread_id"] == "thread-1"
|
||||||
|
assert payload["message_id"] == "message-1"
|
||||||
|
assert payload["interrupt_id"] == "interrupt-1"
|
||||||
|
assert payload["recovery_id"] == "recovery-1"
|
||||||
|
assert payload["severity"] == "warning"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_task_result_preserves_day3_metadata_fields():
|
||||||
|
result = normalize_task_result(
|
||||||
|
{
|
||||||
|
"task_id": "task-1",
|
||||||
|
"status": "completed",
|
||||||
|
"summary": "Recovered successfully.",
|
||||||
|
"owner_agent_id": "executor",
|
||||||
|
"parent_task_id": "parent-1",
|
||||||
|
"child_task_ids": ["child-1"],
|
||||||
|
"thread_id": "thread-1",
|
||||||
|
"message_id": "message-1",
|
||||||
|
"message_index": 2,
|
||||||
|
"interrupt_records": [{"interrupt_id": "interrupt-1", "reason": "user_pause"}],
|
||||||
|
"recovery_records": [{"recovery_id": "recovery-1", "source_interrupt_id": "interrupt-1"}],
|
||||||
|
"budget_snapshot": {"mode": "collaboration", "max_parallel_tasks": 4},
|
||||||
|
"next_action": "notify_user",
|
||||||
|
"output_data": {"ok": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.parent_task_id == "parent-1"
|
||||||
|
assert result.child_task_ids == ["child-1"]
|
||||||
|
assert result.thread_id == "thread-1"
|
||||||
|
assert result.message_id == "message-1"
|
||||||
|
assert result.message_index == 2
|
||||||
|
assert result.interrupt_records[0].interrupt_id == "interrupt-1"
|
||||||
|
assert result.recovery_records[0].recovery_id == "recovery-1"
|
||||||
|
assert result.budget_snapshot.mode == "collaboration"
|
||||||
|
assert result.budget_snapshot.max_parallel_tasks == 4
|
||||||
|
assert result.next_action == "notify_user"
|
||||||
|
assert result.output_data == {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_verification_verdict_updates_state_with_recovery_evidence():
|
||||||
|
state = initial_state("u1", "c1")
|
||||||
|
|
||||||
|
verdict = verify_task_result(
|
||||||
|
status="passed",
|
||||||
|
summary="Interrupt and recovery chain verified.",
|
||||||
|
evidence=[
|
||||||
|
{
|
||||||
|
"task_id": "task-1",
|
||||||
|
"thread_id": "thread-1",
|
||||||
|
"interrupt_id": "interrupt-1",
|
||||||
|
"recovery_id": "recovery-1",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
updated_state = apply_verification_verdict(state, verdict)
|
||||||
|
|
||||||
|
assert updated_state["verification_status"] == "passed"
|
||||||
|
assert updated_state["verification_summary"] == "Interrupt and recovery chain verified."
|
||||||
|
assert updated_state["verification_evidence"] == [
|
||||||
|
{
|
||||||
|
"task_id": "task-1",
|
||||||
|
"thread_id": "thread-1",
|
||||||
|
"interrupt_id": "interrupt-1",
|
||||||
|
"recovery_id": "recovery-1",
|
||||||
|
}
|
||||||
|
]
|
||||||
@@ -47,3 +47,27 @@ def test_web_search_tool_returns_stable_message_when_unavailable(monkeypatch):
|
|||||||
result = web_search.func('Jarvis')
|
result = web_search.func('Jarvis')
|
||||||
|
|
||||||
assert result == '网页搜索不可用: 网页搜索未启用或未配置'
|
assert result == '网页搜索不可用: 网页搜索未启用或未配置'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_web_search_tool_runs_from_active_event_loop(monkeypatch):
|
||||||
|
class FakeService:
|
||||||
|
async def search(self, query: str, limit: int | None = None):
|
||||||
|
assert query == 'Jarvis 最新更新'
|
||||||
|
assert limit == 1
|
||||||
|
return [
|
||||||
|
FakeResult(
|
||||||
|
title='Jarvis release notes',
|
||||||
|
url='https://example.com/jarvis-release',
|
||||||
|
snippet='Latest Jarvis changes.',
|
||||||
|
source='duckduckgo',
|
||||||
|
published_at='2026-03-29',
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||||
|
|
||||||
|
result = web_search.func('Jarvis 最新更新', top_k=1)
|
||||||
|
|
||||||
|
assert '[1] Jarvis release notes' in result
|
||||||
|
assert '链接: https://example.com/jarvis-release' in result
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import pytest
|
|||||||
|
|
||||||
from app.agents.tools import forum as forum_tools
|
from app.agents.tools import forum as forum_tools
|
||||||
from app.agents.tools import schedule as schedule_tools
|
from app.agents.tools import schedule as schedule_tools
|
||||||
|
from app.agents.tools import search as search_tools
|
||||||
from app.agents.tools import task as task_tools
|
from app.agents.tools import task as task_tools
|
||||||
|
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ from app.agents.tools import task as task_tools
|
|||||||
(task_tools, "task"),
|
(task_tools, "task"),
|
||||||
(schedule_tools, "schedule"),
|
(schedule_tools, "schedule"),
|
||||||
(forum_tools, "forum"),
|
(forum_tools, "forum"),
|
||||||
|
(search_tools, "search"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_run_async_bridge_works_inside_running_event_loop(module, label):
|
async def test_run_async_bridge_works_inside_running_event_loop(module, label):
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user