Compare commits
48 Commits
phase1-reg
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 145c43f09c | |||
| 847d9f96db | |||
| 7f5b133fad | |||
| 21c869db62 | |||
| 1ca8855751 | |||
| d8f8b0c177 | |||
| 7e6eb6a7b3 | |||
| c70e7e7253 | |||
| 39a9058de1 | |||
| ac49c13965 | |||
| 3e39b40a50 | |||
| 8c7cf0732b | |||
| aa12c92a5a | |||
| 51e38e039b | |||
| e637c8ca2f | |||
| 52fb619084 | |||
| dc9051debc | |||
| 74fdfc2652 | |||
| 36c93a764f | |||
| 72a60c698a | |||
| 4ef7549efe | |||
| de08165e07 | |||
| 4702cc8ed2 | |||
| 62bf414ff2 | |||
| 536c541a5b | |||
| 7aef898bf5 | |||
| 721ddbeef9 | |||
| 3bff9b3b93 | |||
| 3cf8762b96 | |||
| 712d9e1652 | |||
| ff042cd932 | |||
| 472528e708 | |||
| e24092f3ab | |||
| f0658201e5 | |||
| f033fb5879 | |||
| 5667190abe | |||
| 11160ec4d2 | |||
| 9bfa0dcc11 | |||
| bfe3b6bb9d | |||
| 10d9340c53 | |||
| fca7a7cf3d | |||
| d18167826e | |||
| 88955ed550 | |||
| a3fe4d24fc | |||
| e5bd492d74 | |||
| a7b6b5eb90 | |||
| aa0ef0fbea | |||
| 4972b4e6b1 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -40,6 +40,9 @@ logs/
|
||||
.claude/
|
||||
.worktrees/
|
||||
|
||||
# Demo (excluded from version control)
|
||||
demo/
|
||||
|
||||
# Lock files (use in development, commit in production)
|
||||
# uv.lock - uncomment if you want to commit lock file
|
||||
# package-lock.json - uncomment if you want to commit lock file
|
||||
|
||||
220
backend/app/agents/background/executor.py
Normal file
220
backend/app/agents/background/executor.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Background task executor - Phase 10.4"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from .manager import (
|
||||
BackgroundTask,
|
||||
BackgroundTaskManager,
|
||||
BackgroundTaskStatus,
|
||||
get_background_task_manager,
|
||||
)
|
||||
|
||||
|
||||
class BackgroundExecutor:
|
||||
"""Executes background tasks with error handling and result storage.
|
||||
|
||||
Provides methods to execute tasks synchronously or asynchronously,
|
||||
with full integration into BackgroundTaskManager for tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, task_manager: BackgroundTaskManager | None = None):
|
||||
"""Initialize the executor.
|
||||
|
||||
Args:
|
||||
task_manager: Optional BackgroundTaskManager instance.
|
||||
If not provided, uses the global singleton.
|
||||
"""
|
||||
self._task_manager = task_manager or get_background_task_manager()
|
||||
self._executors: dict[str, asyncio.Task] = {}
|
||||
|
||||
async def execute_task(
|
||||
self,
|
||||
task_id: str,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> BackgroundTask:
|
||||
"""Execute a specific task by ID.
|
||||
|
||||
Args:
|
||||
task_id: Unique task identifier
|
||||
func: Async function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
The BackgroundTask with result or error
|
||||
"""
|
||||
# Get or create task record
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task is None:
|
||||
# Create a new task record if one doesn't exist
|
||||
task = BackgroundTask(
|
||||
id=task_id,
|
||||
name=f"executor_task_{task_id}",
|
||||
status=BackgroundTaskStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
self._task_manager._tasks[task_id] = task
|
||||
|
||||
# Update status to running
|
||||
task.status = BackgroundTaskStatus.RUNNING
|
||||
task.started_at = datetime.now()
|
||||
|
||||
try:
|
||||
# Execute the async function
|
||||
result = await func(*args, **kwargs)
|
||||
task.status = BackgroundTaskStatus.COMPLETED
|
||||
task.result = result
|
||||
except Exception as e:
|
||||
task.status = BackgroundTaskStatus.FAILED
|
||||
task.error = f"{type(e).__name__}: {str(e)}"
|
||||
task.result = None
|
||||
finally:
|
||||
task.completed_at = datetime.now()
|
||||
# Clean up executor reference
|
||||
if task_id in self._executors:
|
||||
del self._executors[task_id]
|
||||
|
||||
return task
|
||||
|
||||
async def execute_async(
|
||||
self,
|
||||
task_id: str,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Execute a task asynchronously in the background.
|
||||
|
||||
Args:
|
||||
task_id: Unique task identifier
|
||||
func: Async function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
The task ID
|
||||
"""
|
||||
# Create task record if it doesn't exist
|
||||
if self._task_manager.get_task_status(task_id) is None:
|
||||
self._task_manager._tasks[task_id] = BackgroundTask(
|
||||
id=task_id,
|
||||
name=f"async_task_{task_id}",
|
||||
status=BackgroundTaskStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
# Create and store the asyncio task
|
||||
async_task = asyncio.create_task(self.execute_task(task_id, func, *args, **kwargs))
|
||||
self._executors[task_id] = async_task
|
||||
|
||||
return task_id
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""Cancel a running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to cancel
|
||||
|
||||
Returns:
|
||||
True if cancelled, False if not found or not running
|
||||
"""
|
||||
if task_id not in self._executors:
|
||||
return False
|
||||
|
||||
self._executors[task_id].cancel()
|
||||
del self._executors[task_id]
|
||||
|
||||
# Update task status
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task:
|
||||
task.status = BackgroundTaskStatus.CANCELLED
|
||||
task.completed_at = datetime.now()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_task_result(self, task_id: str) -> Any:
|
||||
"""Get the result of a completed task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID
|
||||
|
||||
Returns:
|
||||
The task result or None if not found/not completed
|
||||
"""
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task and task.status == BackgroundTaskStatus.COMPLETED:
|
||||
return task.result
|
||||
return None
|
||||
|
||||
def get_task_error(self, task_id: str) -> str | None:
|
||||
"""Get the error of a failed task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID
|
||||
|
||||
Returns:
|
||||
The error message or None if not found/not failed
|
||||
"""
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task and task.status == BackgroundTaskStatus.FAILED:
|
||||
return task.error
|
||||
return None
|
||||
|
||||
def is_task_running(self, task_id: str) -> bool:
|
||||
"""Check if a task is currently running.
|
||||
|
||||
Args:
|
||||
task_id: The task ID
|
||||
|
||||
Returns:
|
||||
True if running, False otherwise
|
||||
"""
|
||||
return task_id in self._executors
|
||||
|
||||
def wait_for_task(self, task_id: str, timeout: float | None = None) -> BackgroundTask:
|
||||
"""Wait for a task to complete.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to wait for
|
||||
timeout: Optional timeout in seconds
|
||||
|
||||
Returns:
|
||||
The completed BackgroundTask
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If task doesn't complete within timeout
|
||||
asyncio.CancelledError: If task is cancelled
|
||||
"""
|
||||
if task_id not in self._executors:
|
||||
task = self._task_manager.get_task_status(task_id)
|
||||
if task:
|
||||
return task
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
async def wait_task() -> BackgroundTask:
|
||||
await self._executors[task_id]
|
||||
return self._task_manager.get_task_status(task_id)
|
||||
|
||||
return asyncio.run_until_complete(asyncio.wait_for(wait_task(), timeout=timeout))
|
||||
|
||||
@property
|
||||
def task_manager(self) -> BackgroundTaskManager:
|
||||
"""Get the underlying task manager."""
|
||||
return self._task_manager
|
||||
|
||||
|
||||
# Global executor instance
|
||||
_executor: BackgroundExecutor | None = None
|
||||
|
||||
|
||||
def get_background_executor() -> BackgroundExecutor:
|
||||
"""Get the global BackgroundExecutor instance."""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = BackgroundExecutor()
|
||||
return _executor
|
||||
119
backend/app/agents/background/manager.py
Normal file
119
backend/app/agents/background/manager.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""后台任务系统 - Phase 10.4"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BackgroundTaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackgroundTask:
|
||||
"""后台任务"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
status: BackgroundTaskStatus
|
||||
created_at: datetime
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class BackgroundTaskManager:
|
||||
"""后台任务管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._tasks: dict[str, BackgroundTask] = {}
|
||||
self._.coroutines: dict[str, asyncio.Task] = {}
|
||||
|
||||
def submit_task(self, name: str, coro: Any, *args, **kwargs) -> str:
|
||||
"""提交后台任务
|
||||
|
||||
Args:
|
||||
name: 任务名称
|
||||
coro: 协程函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
任务 ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 创建任务记录
|
||||
self._tasks[task_id] = BackgroundTask(
|
||||
id=task_id,
|
||||
name=name,
|
||||
status=BackgroundTaskStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
# 创建 asyncio task
|
||||
async def run_task():
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.RUNNING
|
||||
self._tasks[task_id].started_at = datetime.now()
|
||||
try:
|
||||
result = await coro(*args, **kwargs)
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.COMPLETED
|
||||
self._tasks[task_id].result = result
|
||||
except Exception as e:
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.FAILED
|
||||
self._tasks[task_id].error = str(e)
|
||||
finally:
|
||||
self._tasks[task_id].completed_at = datetime.now()
|
||||
if task_id in self._coroutines:
|
||||
del self._coroutines[task_id]
|
||||
|
||||
self._coroutines[task_id] = asyncio.create_task(run_task())
|
||||
return task_id
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""取消任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
是否成功取消
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
if task_id in self._coroutines:
|
||||
self._coroutines[task_id].cancel()
|
||||
del self._coroutines[task_id]
|
||||
|
||||
self._tasks[task_id].status = BackgroundTaskStatus.CANCELLED
|
||||
self._tasks[task_id].completed_at = datetime.now()
|
||||
return True
|
||||
|
||||
def get_task_status(self, task_id: str) -> BackgroundTask | None:
|
||||
"""获取任务状态"""
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def list_tasks(self) -> list[BackgroundTask]:
|
||||
"""列出所有任务"""
|
||||
return list(self._tasks.values())
|
||||
|
||||
|
||||
# 全局单例
|
||||
_manager: BackgroundTaskManager | None = None
|
||||
|
||||
|
||||
def get_background_task_manager() -> BackgroundTaskManager:
|
||||
"""获取全局后台任务管理器"""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
_manager = BackgroundTaskManager()
|
||||
return _manager
|
||||
146
backend/app/agents/background/scheduler.py
Normal file
146
backend/app/agents/background/scheduler.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Background task scheduler - Phase 10.4"""
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.base import BaseTrigger
|
||||
|
||||
from .manager import BackgroundTaskManager, get_background_task_manager
|
||||
|
||||
|
||||
class BackgroundScheduler:
|
||||
"""Background task scheduler using APScheduler.
|
||||
|
||||
Integrates with BackgroundTaskManager for task tracking and execution.
|
||||
"""
|
||||
|
||||
def __init__(self, task_manager: BackgroundTaskManager | None = None):
|
||||
"""Initialize the scheduler.
|
||||
|
||||
Args:
|
||||
task_manager: Optional BackgroundTaskManager instance.
|
||||
If not provided, uses the global singleton.
|
||||
"""
|
||||
self._scheduler = AsyncIOScheduler()
|
||||
self._task_manager = task_manager or get_background_task_manager()
|
||||
self._job_tasks: dict[str, str] = {} # Maps APScheduler job_id to task_id
|
||||
|
||||
def add_job(
|
||||
self,
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
trigger: BaseTrigger,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
**apscheduler_kwargs: Any,
|
||||
) -> str:
|
||||
"""Add a job to the scheduler.
|
||||
|
||||
Args:
|
||||
func: Async function to execute
|
||||
trigger: APScheduler trigger (date, interval, cron, etc.)
|
||||
args: Positional arguments for the function
|
||||
kwargs: Keyword arguments for the function
|
||||
id: Unique job ID (auto-generated if not provided)
|
||||
name: Job name for display purposes
|
||||
**apscheduler_kwargs: Additional APScheduler options
|
||||
|
||||
Returns:
|
||||
The job ID
|
||||
"""
|
||||
job_id = id or f"job_{len(self._job_tasks)}"
|
||||
task_name = name or f"scheduled_task_{job_id}"
|
||||
|
||||
# Wrap the async function to integrate with BackgroundTaskManager
|
||||
async def wrapped_func() -> None:
|
||||
coro = func(*(args or ()), **(kwargs or {}))
|
||||
task_id = self._task_manager.submit_task(task_name, coro)
|
||||
self._job_tasks[job_id] = task_id
|
||||
|
||||
self._scheduler.add_job(
|
||||
wrapped_func,
|
||||
trigger=trigger,
|
||||
id=job_id,
|
||||
name=task_name,
|
||||
**apscheduler_kwargs,
|
||||
)
|
||||
|
||||
return job_id
|
||||
|
||||
def remove_job(self, job_id: str) -> bool:
|
||||
"""Remove a job from the scheduler.
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job to remove
|
||||
|
||||
Returns:
|
||||
True if job was removed, False if job didn't exist
|
||||
"""
|
||||
try:
|
||||
self._scheduler.remove_job(job_id)
|
||||
# Clean up task mapping if exists
|
||||
if job_id in self._job_tasks:
|
||||
task_id = self._job_tasks.pop(job_id)
|
||||
# Cancel the background task if still running
|
||||
self._task_manager.cancel_task(task_id)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def list_jobs(self) -> list[dict[str, Any]]:
|
||||
"""List all scheduled jobs.
|
||||
|
||||
Returns:
|
||||
List of job information dictionaries
|
||||
"""
|
||||
jobs = self._scheduler.get_jobs()
|
||||
return [
|
||||
{
|
||||
"id": job.id,
|
||||
"name": job.name,
|
||||
"next_run_time": job.next_run_time,
|
||||
"trigger": str(job.trigger),
|
||||
}
|
||||
for job in jobs
|
||||
]
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler."""
|
||||
if not self._scheduler.running:
|
||||
self._scheduler.start()
|
||||
|
||||
def shutdown(self, wait: bool = True) -> None:
|
||||
"""Shutdown the scheduler.
|
||||
|
||||
Args:
|
||||
wait: Whether to wait for running jobs to complete
|
||||
"""
|
||||
if self._scheduler.running:
|
||||
self._scheduler.shutdown(wait=wait)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the scheduler."""
|
||||
self._scheduler.pause()
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume the scheduler."""
|
||||
self._scheduler.resume()
|
||||
|
||||
@property
|
||||
def task_manager(self) -> BackgroundTaskManager:
|
||||
"""Get the underlying task manager."""
|
||||
return self._task_manager
|
||||
|
||||
|
||||
# Global scheduler instance
|
||||
_scheduler: BackgroundScheduler | None = None
|
||||
|
||||
|
||||
def get_background_scheduler() -> BackgroundScheduler:
|
||||
"""Get the global BackgroundScheduler instance."""
|
||||
global _scheduler
|
||||
if _scheduler is None:
|
||||
_scheduler = BackgroundScheduler()
|
||||
return _scheduler
|
||||
508
backend/app/agents/coordinator.py
Normal file
508
backend/app/agents/coordinator.py
Normal file
@@ -0,0 +1,508 @@
|
||||
"""Agent 协调整器 - Phase 10.5
|
||||
|
||||
统一协调所有 Agent 组件:TeamLeader, RemoteTransport, BackgroundTaskManager, SessionManager
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.background.manager import BackgroundTaskManager, get_background_task_manager
|
||||
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
|
||||
from app.agents.team.leader import TeamLeader
|
||||
from app.agents.transport.remote import RemoteTransport
|
||||
|
||||
|
||||
class AgentCoordinator:
|
||||
"""Agent 协调整器
|
||||
|
||||
统一协调所有 Agent 组件,提供单一入口处理各类 Agent 操作。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
background_manager: BackgroundTaskManager | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
background_manager: 后台任务管理器,None 则使用全局单例
|
||||
"""
|
||||
self._team_leaders: dict[str, TeamLeader] = {}
|
||||
self._remote_transport = RemoteTransport()
|
||||
self._background_manager = background_manager or get_background_task_manager()
|
||||
self._sessions: dict[str, AgentSession] = {}
|
||||
|
||||
# === Team 协作方法 ===
|
||||
|
||||
def create_team(self, team_id: str, members: list[str]) -> dict[str, Any]:
|
||||
"""创建团队
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
members: 成员 ID 列表
|
||||
|
||||
Returns:
|
||||
团队创建结果
|
||||
"""
|
||||
if team_id in self._team_leaders:
|
||||
return {"status": "error", "message": f"Team '{team_id}' already exists"}
|
||||
|
||||
leader = TeamLeader(team_id=team_id, members=members)
|
||||
self._team_leaders[team_id] = leader
|
||||
return {
|
||||
"status": "created",
|
||||
"team_id": team_id,
|
||||
"members": members,
|
||||
}
|
||||
|
||||
def get_team(self, team_id: str) -> TeamLeader | None:
|
||||
"""获取团队
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
|
||||
Returns:
|
||||
TeamLeader 或 None
|
||||
"""
|
||||
return self._team_leaders.get(team_id)
|
||||
|
||||
def assign_task(self, team_id: str, description: str, member: str) -> dict[str, Any]:
|
||||
"""创建并分配任务
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
description: 任务描述
|
||||
member: 成员 ID
|
||||
|
||||
Returns:
|
||||
分配结果
|
||||
"""
|
||||
leader = self._team_leaders.get(team_id)
|
||||
if not leader:
|
||||
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||
|
||||
task_id = leader.create_task(description)
|
||||
success = leader.assign_task(task_id, member)
|
||||
return {
|
||||
"status": "assigned" if success else "error",
|
||||
"task_id": task_id,
|
||||
"assignee": member,
|
||||
}
|
||||
|
||||
def broadcast_task(self, team_id: str, description: str) -> dict[str, Any]:
|
||||
"""广播任务给所有成员
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
广播结果
|
||||
"""
|
||||
leader = self._team_leaders.get(team_id)
|
||||
if not leader:
|
||||
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||
|
||||
task_ids = leader.broadcast_task(description)
|
||||
return {
|
||||
"status": "broadcast",
|
||||
"team_id": team_id,
|
||||
"task_ids": task_ids,
|
||||
"member_count": len(leader.members),
|
||||
}
|
||||
|
||||
def collect_team_results(self, team_id: str) -> dict[str, Any]:
|
||||
"""收集团队任务结果
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
|
||||
Returns:
|
||||
收集结果
|
||||
"""
|
||||
leader = self._team_leaders.get(team_id)
|
||||
if not leader:
|
||||
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||
|
||||
results = leader.collect_results()
|
||||
status = leader.get_team_status()
|
||||
return {
|
||||
"status": "collected",
|
||||
"team_id": team_id,
|
||||
"results": results,
|
||||
"completed": status["completed"],
|
||||
"failed": status["failed"],
|
||||
}
|
||||
|
||||
def get_team_status(self, team_id: str) -> dict[str, Any]:
|
||||
"""获取团队状态
|
||||
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
|
||||
Returns:
|
||||
团队状态
|
||||
"""
|
||||
leader = self._team_leaders.get(team_id)
|
||||
if not leader:
|
||||
return {"status": "error", "message": f"Team '{team_id}' not found"}
|
||||
|
||||
return leader.get_team_status()
|
||||
|
||||
# === 后台任务方法 ===
|
||||
|
||||
def submit_background_task(
|
||||
self,
|
||||
name: str,
|
||||
coro: Any,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""提交后台任务
|
||||
|
||||
Args:
|
||||
name: 任务名称
|
||||
coro: 协程函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
提交结果
|
||||
"""
|
||||
task_id = self._background_manager.submit_task(name, coro, *args, **kwargs)
|
||||
return {
|
||||
"status": "submitted",
|
||||
"task_id": task_id,
|
||||
"name": name,
|
||||
}
|
||||
|
||||
def cancel_background_task(self, task_id: str) -> dict[str, Any]:
|
||||
"""取消后台任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
取消结果
|
||||
"""
|
||||
success = self._background_manager.cancel_task(task_id)
|
||||
return {
|
||||
"status": "cancelled" if success else "error",
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
def get_background_task_status(self, task_id: str) -> dict[str, Any]:
|
||||
"""获取后台任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
任务状态
|
||||
"""
|
||||
task = self._background_manager.get_task_status(task_id)
|
||||
if not task:
|
||||
return {"status": "error", "message": f"Task '{task_id}' not found"}
|
||||
|
||||
return {
|
||||
"status": "found",
|
||||
"task_id": task.id,
|
||||
"name": task.name,
|
||||
"task_status": task.status.value,
|
||||
"result": task.result,
|
||||
"error": task.error,
|
||||
}
|
||||
|
||||
def list_background_tasks(self) -> dict[str, Any]:
|
||||
"""列出所有后台任务
|
||||
|
||||
Returns:
|
||||
任务列表
|
||||
"""
|
||||
tasks = self._background_manager.list_tasks()
|
||||
return {
|
||||
"status": "list",
|
||||
"count": len(tasks),
|
||||
"tasks": [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"status": t.status.value,
|
||||
}
|
||||
for t in tasks
|
||||
],
|
||||
}
|
||||
|
||||
# === 会话方法 ===
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""创建会话
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
parent_session_id: 父会话 ID
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
session = create_agent_session(
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
self._sessions[session.session_id] = session
|
||||
return {
|
||||
"status": "created",
|
||||
"session_id": session.session_id,
|
||||
"user_id": user_id,
|
||||
"parent_session_id": parent_session_id,
|
||||
}
|
||||
|
||||
def get_session(self, session_id: str) -> AgentSession | None:
|
||||
"""获取会话
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
AgentSession 或 None
|
||||
"""
|
||||
return self._sessions.get(session_id) or get_agent_session(session_id)
|
||||
|
||||
async def process_session_message(
|
||||
self,
|
||||
session_id: str,
|
||||
message: str,
|
||||
response: str,
|
||||
) -> dict[str, Any]:
|
||||
"""处理会话消息
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
message: 用户消息
|
||||
response: 助手响应
|
||||
|
||||
Returns:
|
||||
处理结果
|
||||
"""
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||
|
||||
await session.process_message(message, response)
|
||||
return {
|
||||
"status": "processed",
|
||||
"session_id": session_id,
|
||||
"message_count": session.context.message_count,
|
||||
}
|
||||
|
||||
async def spawn_child_session(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""创建子会话
|
||||
|
||||
Args:
|
||||
session_id: 父会话 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||
|
||||
child = await session.spawn_child_session(user_id=user_id)
|
||||
self._sessions[child.session_id] = child
|
||||
return {
|
||||
"status": "spawned",
|
||||
"parent_session_id": session_id,
|
||||
"child_session_id": child.session_id,
|
||||
"depth": child.context.depth,
|
||||
}
|
||||
|
||||
def get_session_summary(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话摘要
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
会话摘要
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
return {"status": "error", "message": f"Session '{session_id}' not found"}
|
||||
|
||||
# get_session_summary is async, so we need to run it
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# Create a future
|
||||
future = asyncio.ensure_future(session.get_session_summary())
|
||||
return {"status": "found", "summary": future}
|
||||
else:
|
||||
return {
|
||||
"status": "found",
|
||||
"summary": loop.run_until_complete(session.get_session_summary()),
|
||||
}
|
||||
except RuntimeError:
|
||||
# No event loop, create one
|
||||
return {"status": "found", "summary": asyncio.run(session.get_session_summary())}
|
||||
|
||||
# === 远程传输方法 ===
|
||||
|
||||
def register_remote_handler(self, event_type: str, handler: Any) -> None:
|
||||
"""注册远程消息处理器
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
handler: 处理函数
|
||||
"""
|
||||
self._remote_transport.register_handler(event_type, handler)
|
||||
|
||||
async def send_remote_response(
|
||||
self,
|
||||
session_id: str,
|
||||
response: dict[str, Any],
|
||||
) -> bool:
|
||||
"""发送远程响应
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
response: 响应数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
return await self._remote_transport.send_response(session_id, response)
|
||||
|
||||
async def send_remote_event(
|
||||
self,
|
||||
session_id: str,
|
||||
event: dict[str, Any],
|
||||
) -> bool:
|
||||
"""发送远程事件
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
event: 事件数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
return await self._remote_transport.send_event(session_id, event)
|
||||
|
||||
async def send_remote_tool_call(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_call: dict[str, Any],
|
||||
) -> bool:
|
||||
"""发送远程工具调用
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
tool_call: 工具调用数据
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
return await self._remote_transport.send_tool_call(session_id, tool_call)
|
||||
|
||||
# === 统一协调入口 ===
|
||||
|
||||
async def coordinate(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""统一协调入口
|
||||
|
||||
根据请求类型协调各类 Agent 操作。
|
||||
|
||||
Args:
|
||||
request: 请求数据,包含:
|
||||
- action: 操作类型 (team_create, team_assign, task_submit, session_create, etc.)
|
||||
- 其他参数根据 action 不同而不同
|
||||
|
||||
Returns:
|
||||
协调结果
|
||||
"""
|
||||
action = request.get("action")
|
||||
|
||||
if action == "team_create":
|
||||
return self.create_team(
|
||||
team_id=request["team_id"],
|
||||
members=request["members"],
|
||||
)
|
||||
|
||||
elif action == "team_assign":
|
||||
return self.assign_task(
|
||||
team_id=request["team_id"],
|
||||
description=request["description"],
|
||||
member=request["member"],
|
||||
)
|
||||
|
||||
elif action == "team_broadcast":
|
||||
return self.broadcast_task(
|
||||
team_id=request["team_id"],
|
||||
description=request["description"],
|
||||
)
|
||||
|
||||
elif action == "team_collect":
|
||||
return self.collect_team_results(team_id=request["team_id"])
|
||||
|
||||
elif action == "team_status":
|
||||
return self.get_team_status(team_id=request["team_id"])
|
||||
|
||||
elif action == "task_submit":
|
||||
return self.submit_background_task(
|
||||
name=request["name"],
|
||||
coro=request["coro"],
|
||||
*request.get("args", []),
|
||||
**request.get("kwargs", {}),
|
||||
)
|
||||
|
||||
elif action == "task_cancel":
|
||||
return self.cancel_background_task(task_id=request["task_id"])
|
||||
|
||||
elif action == "task_status":
|
||||
return self.get_background_task_status(task_id=request["task_id"])
|
||||
|
||||
elif action == "session_create":
|
||||
return self.create_session(
|
||||
user_id=request.get("user_id"),
|
||||
parent_session_id=request.get("parent_session_id"),
|
||||
)
|
||||
|
||||
elif action == "session_message":
|
||||
return await self.process_session_message(
|
||||
session_id=request["session_id"],
|
||||
message=request["message"],
|
||||
response=request["response"],
|
||||
)
|
||||
|
||||
elif action == "session_spawn":
|
||||
return await self.spawn_child_session(
|
||||
session_id=request["session_id"],
|
||||
user_id=request.get("user_id"),
|
||||
)
|
||||
|
||||
elif action == "session_summary":
|
||||
return self.get_session_summary(session_id=request["session_id"])
|
||||
|
||||
else:
|
||||
return {"status": "error", "message": f"Unknown action: {action}"}
|
||||
|
||||
|
||||
# 全局单例
|
||||
_coordinator: AgentCoordinator | None = None
|
||||
|
||||
|
||||
def get_agent_coordinator() -> AgentCoordinator:
|
||||
"""获取全局 Agent 协调整器"""
|
||||
global _coordinator
|
||||
if _coordinator is None:
|
||||
_coordinator = AgentCoordinator()
|
||||
return _coordinator
|
||||
File diff suppressed because it is too large
Load Diff
14
backend/app/agents/isolation/__init__.py
Normal file
14
backend/app/agents/isolation/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from app.agents.isolation.session_isolation import prepare_session_isolation
|
||||
from app.agents.isolation.strategy_selector import IsolationDecision, select_isolation_strategy
|
||||
from app.agents.isolation.worktree_isolation import (
|
||||
WorktreeIsolationError,
|
||||
prepare_worktree_isolation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IsolationDecision",
|
||||
"WorktreeIsolationError",
|
||||
"prepare_session_isolation",
|
||||
"prepare_worktree_isolation",
|
||||
"select_isolation_strategy",
|
||||
]
|
||||
31
backend/app/agents/isolation/session_isolation.py
Normal file
31
backend/app/agents/isolation/session_isolation.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from app.agents.isolation.strategy_selector import IsolationDecision
|
||||
|
||||
|
||||
def prepare_session_isolation(
|
||||
*,
|
||||
state: dict[str, Any],
|
||||
decision: IsolationDecision,
|
||||
role_value: str,
|
||||
sub_commander: str,
|
||||
) -> dict[str, Any]:
|
||||
isolation_id = f"session-{uuid4().hex[:8]}"
|
||||
return {
|
||||
"mode": "session",
|
||||
"isolation_id": isolation_id,
|
||||
"workspace_path": None,
|
||||
"parent_conversation_id": str(state.get("conversation_id") or "") or None,
|
||||
"metadata": {
|
||||
**dict(decision.metadata or {}),
|
||||
"reason": decision.reason,
|
||||
"role": role_value,
|
||||
"sub_commander": sub_commander,
|
||||
"tool_names": list(decision.tool_names),
|
||||
"capability_ids": list(decision.capability_ids),
|
||||
"status": "active",
|
||||
},
|
||||
}
|
||||
147
backend/app/agents/isolation/strategy_selector.py
Normal file
147
backend/app/agents/isolation/strategy_selector.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.agents.registry import load_builtin_registry_indexes
|
||||
from app.agents.registry.models import CapabilityManifest, PermissionClass, SideEffectScope
|
||||
|
||||
|
||||
IsolationMode = Literal["none", "session", "worktree"]
|
||||
|
||||
_WORKTREE_QUERY_MARKERS = (
|
||||
"code",
|
||||
"repo",
|
||||
"repository",
|
||||
"git",
|
||||
"worktree",
|
||||
"branch",
|
||||
"patch",
|
||||
"diff",
|
||||
"refactor",
|
||||
"build",
|
||||
"test",
|
||||
"fix",
|
||||
"file",
|
||||
"files",
|
||||
"python",
|
||||
"typescript",
|
||||
"javascript",
|
||||
"代码",
|
||||
"仓库",
|
||||
"分支",
|
||||
"补丁",
|
||||
"重构",
|
||||
"构建",
|
||||
"测试",
|
||||
"修复",
|
||||
"文件",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IsolationDecision:
|
||||
mode: IsolationMode
|
||||
reason: str
|
||||
tool_names: tuple[str, ...] = ()
|
||||
capability_ids: tuple[str, ...] = ()
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _capability_metadata(capability: CapabilityManifest | None) -> dict[str, Any]:
|
||||
if capability is None:
|
||||
return {}
|
||||
return {
|
||||
"capability_id": capability.capability_id,
|
||||
"tool_name": capability.tool_name,
|
||||
"permission_class": capability.permission_class.value,
|
||||
"side_effect_scope": capability.side_effect_scope.value,
|
||||
"supports_retry": capability.supports_retry,
|
||||
"idempotent": capability.idempotent,
|
||||
"safe_for_parallel_use": capability.safe_for_parallel_use,
|
||||
"requires_confirmation": capability.requires_confirmation,
|
||||
}
|
||||
|
||||
|
||||
def select_isolation_strategy(
|
||||
*,
|
||||
user_query: str,
|
||||
tool_names: list[str] | tuple[str, ...],
|
||||
role_value: str,
|
||||
execution_mode: str | None,
|
||||
) -> IsolationDecision:
|
||||
indexes = load_builtin_registry_indexes()
|
||||
capabilities: list[CapabilityManifest] = []
|
||||
capability_ids: list[str] = []
|
||||
|
||||
for tool_name in tool_names:
|
||||
capability_id = indexes.capability_id_by_tool_name.get(tool_name)
|
||||
capability = indexes.capability_by_id.get(capability_id) if capability_id else None
|
||||
if capability is not None:
|
||||
capabilities.append(capability)
|
||||
capability_ids.append(capability.capability_id)
|
||||
|
||||
normalized_query = (user_query or "").strip().lower()
|
||||
has_worktree_query_signal = any(marker in normalized_query for marker in _WORKTREE_QUERY_MARKERS)
|
||||
has_write_capability = any(cap.permission_class == PermissionClass.WRITE for cap in capabilities)
|
||||
has_external_capability = any(cap.permission_class == PermissionClass.EXTERNAL for cap in capabilities)
|
||||
has_non_parallel_capability = any(not cap.safe_for_parallel_use for cap in capabilities)
|
||||
has_stateful_side_effect = any(
|
||||
cap.side_effect_scope in {SideEffectScope.LOCAL_STATE, SideEffectScope.DB_WRITE}
|
||||
for cap in capabilities
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"role": role_value,
|
||||
"execution_mode": execution_mode,
|
||||
"capabilities": [_capability_metadata(capability) for capability in capabilities],
|
||||
"workspace_strategy": "inline",
|
||||
"risk_level": "low",
|
||||
}
|
||||
|
||||
if has_worktree_query_signal:
|
||||
return IsolationDecision(
|
||||
mode="worktree",
|
||||
reason="workspace_mutation_signals_detected",
|
||||
tool_names=tuple(tool_names),
|
||||
capability_ids=tuple(capability_ids),
|
||||
metadata={
|
||||
**metadata,
|
||||
"workspace_strategy": "ephemeral_worktree",
|
||||
"risk_level": "high",
|
||||
},
|
||||
)
|
||||
|
||||
if has_write_capability or has_stateful_side_effect or has_non_parallel_capability:
|
||||
return IsolationDecision(
|
||||
mode="session",
|
||||
reason="stateful_or_non_parallel_tooling",
|
||||
tool_names=tuple(tool_names),
|
||||
capability_ids=tuple(capability_ids),
|
||||
metadata={
|
||||
**metadata,
|
||||
"workspace_strategy": "isolated_session",
|
||||
"risk_level": "medium",
|
||||
},
|
||||
)
|
||||
|
||||
if execution_mode == "collaboration" or role_value in {"analyst", "librarian"} or has_external_capability:
|
||||
return IsolationDecision(
|
||||
mode="session",
|
||||
reason="context_heavy_or_external_retrieval",
|
||||
tool_names=tuple(tool_names),
|
||||
capability_ids=tuple(capability_ids),
|
||||
metadata={
|
||||
**metadata,
|
||||
"workspace_strategy": "isolated_session",
|
||||
"risk_level": "medium",
|
||||
},
|
||||
)
|
||||
|
||||
return IsolationDecision(
|
||||
mode="none",
|
||||
reason="inline_execution_is_sufficient",
|
||||
tool_names=tuple(tool_names),
|
||||
capability_ids=tuple(capability_ids),
|
||||
metadata=metadata,
|
||||
)
|
||||
83
backend/app/agents/isolation/worktree_isolation.py
Normal file
83
backend/app/agents/isolation/worktree_isolation.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from app.agents.isolation.strategy_selector import IsolationDecision
|
||||
|
||||
|
||||
class WorktreeIsolationError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _slugify(value: str, *, fallback: str) -> str:
|
||||
slug = re.sub(r"[^a-zA-Z0-9._-]+", "-", (value or "").strip()).strip("-").lower()
|
||||
return slug or fallback
|
||||
|
||||
|
||||
def _resolve_git_root() -> Path:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--show-toplevel"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise WorktreeIsolationError(exc.stderr.strip() or exc.stdout.strip() or "git_root_unavailable") from exc
|
||||
git_root = Path(result.stdout.strip())
|
||||
if not git_root.exists():
|
||||
raise WorktreeIsolationError("git_root_not_found")
|
||||
return git_root
|
||||
|
||||
|
||||
def prepare_worktree_isolation(
|
||||
*,
|
||||
state: dict[str, Any],
|
||||
decision: IsolationDecision,
|
||||
role_value: str,
|
||||
sub_commander: str,
|
||||
create_workspace: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
isolation_id = f"worktree-{uuid4().hex[:8]}"
|
||||
conversation_slug = _slugify(str(state.get("conversation_id") or "conversation"), fallback="conversation")
|
||||
role_slug = _slugify(role_value, fallback="agent")
|
||||
git_root = _resolve_git_root()
|
||||
workspace_root = git_root / ".worktrees" / "jarvis" / conversation_slug
|
||||
workspace_path = workspace_root / f"{role_slug}-{isolation_id}"
|
||||
branch = f"jarvis/{conversation_slug}/{role_slug}-{isolation_id}"
|
||||
|
||||
if create_workspace and not workspace_path.exists():
|
||||
workspace_root.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "-C", str(git_root), "worktree", "add", "-b", branch, str(workspace_path), "HEAD"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise WorktreeIsolationError(exc.stderr.strip() or exc.stdout.strip() or "worktree_add_failed") from exc
|
||||
|
||||
return {
|
||||
"mode": "worktree",
|
||||
"isolation_id": isolation_id,
|
||||
"workspace_path": str(workspace_path),
|
||||
"parent_conversation_id": str(state.get("conversation_id") or "") or None,
|
||||
"metadata": {
|
||||
**dict(decision.metadata or {}),
|
||||
"reason": decision.reason,
|
||||
"role": role_value,
|
||||
"sub_commander": sub_commander,
|
||||
"tool_names": list(decision.tool_names),
|
||||
"capability_ids": list(decision.capability_ids),
|
||||
"repo_root": str(git_root),
|
||||
"branch": branch,
|
||||
"workspace_strategy": "ephemeral_worktree",
|
||||
"cleanup_status": "pending",
|
||||
"materialized": workspace_path.exists(),
|
||||
},
|
||||
}
|
||||
19
backend/app/agents/learning/__init__.py
Normal file
19
backend/app/agents/learning/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from app.agents.learning.jobs import persist_retrospective, schedule_retrospective_job
|
||||
from app.agents.learning.pattern_miner import LearningPatternMiner
|
||||
from app.agents.learning.retrospector import build_session_retrospective
|
||||
from app.agents.learning.session_search import SessionRetrospectiveSearch
|
||||
from app.agents.learning.signal_extractor import RetrospectiveSignalExtractor
|
||||
from app.agents.learning.skill_candidate_builder import SkillCandidateBuilder
|
||||
from app.agents.learning.store import LearningArtifactStore, SessionRetrospectiveStore
|
||||
|
||||
__all__ = [
|
||||
"build_session_retrospective",
|
||||
"LearningArtifactStore",
|
||||
"LearningPatternMiner",
|
||||
"persist_retrospective",
|
||||
"RetrospectiveSignalExtractor",
|
||||
"schedule_retrospective_job",
|
||||
"SessionRetrospectiveSearch",
|
||||
"SessionRetrospectiveStore",
|
||||
"SkillCandidateBuilder",
|
||||
]
|
||||
16
backend/app/agents/learning/audit.py
Normal file
16
backend/app/agents/learning/audit.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.schemas.learning import LearningDecision, SessionRetrospective
|
||||
|
||||
|
||||
def build_learning_audit_entry(retrospective: SessionRetrospective) -> dict[str, object]:
|
||||
decision = retrospective.learning_decision
|
||||
return {
|
||||
"retrospective_id": retrospective.retrospective_id,
|
||||
"decision": decision.decision if isinstance(decision, LearningDecision) else None,
|
||||
"explanation": decision.explanation if isinstance(decision, LearningDecision) else None,
|
||||
"signal_count": len(retrospective.learning_signals),
|
||||
"pattern_count": len(retrospective.pattern_candidates),
|
||||
"skill_candidate_count": len(retrospective.skill_candidates),
|
||||
"outcome": retrospective.outcome,
|
||||
}
|
||||
45
backend/app/agents/learning/bridge.py
Normal file
45
backend/app/agents/learning/bridge.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.schemas.learning import LearningDecision, LearningSignal
|
||||
|
||||
|
||||
def route_learning_signal(signal: LearningSignal) -> str:
|
||||
if signal.signal_type == "preference":
|
||||
return "memory"
|
||||
if signal.signal_type in {"workflow", "decomposition", "tool_success"}:
|
||||
return "skill"
|
||||
if signal.signal_type == "correction":
|
||||
return "audit"
|
||||
return "memory"
|
||||
|
||||
|
||||
def build_learning_bridge_summary(signals: list[LearningSignal]) -> dict[str, object]:
|
||||
memory_count = 0
|
||||
skill_count = 0
|
||||
audit_count = 0
|
||||
|
||||
for signal in signals:
|
||||
route = route_learning_signal(signal)
|
||||
if route == "memory":
|
||||
memory_count += 1
|
||||
elif route == "skill":
|
||||
skill_count += 1
|
||||
else:
|
||||
audit_count += 1
|
||||
|
||||
return {
|
||||
"memory_signal_count": memory_count,
|
||||
"skill_signal_count": skill_count,
|
||||
"audit_signal_count": audit_count,
|
||||
}
|
||||
|
||||
|
||||
def update_learning_decision_with_bridge(
|
||||
decision: LearningDecision,
|
||||
signals: list[LearningSignal],
|
||||
) -> LearningDecision:
|
||||
bridge_summary = build_learning_bridge_summary(signals)
|
||||
metadata = dict(decision.metadata or {})
|
||||
metadata["bridge"] = bridge_summary
|
||||
decision.metadata = metadata
|
||||
return decision
|
||||
222
backend/app/agents/learning/jobs.py
Normal file
222
backend/app/agents/learning/jobs.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.database import async_session
|
||||
from app.agents.learning.bridge import update_learning_decision_with_bridge
|
||||
from app.agents.learning.pattern_miner import LearningPatternMiner
|
||||
from app.agents.learning.audit import build_learning_audit_entry
|
||||
from app.agents.learning.retrospector import build_session_retrospective
|
||||
from app.agents.learning.signal_extractor import RetrospectiveSignalExtractor
|
||||
from app.agents.learning.skill_candidate_builder import SkillCandidateBuilder
|
||||
from app.agents.learning.store import LearningArtifactStore, SessionRetrospectiveStore
|
||||
from app.agents.schemas.learning import LearningDecision, SessionRetrospective
|
||||
from app.agents.skills.evaluator import SkillPromotionEvaluator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _enrich_retrospective(retrospective: SessionRetrospective) -> SessionRetrospective:
|
||||
signals = RetrospectiveSignalExtractor().extract(retrospective)
|
||||
patterns = LearningPatternMiner().mine(signals)
|
||||
skill_candidates = SkillCandidateBuilder().build(patterns)
|
||||
|
||||
decision = LearningDecision(
|
||||
decision="create_candidate" if skill_candidates else ("reinforce_memory" if signals else "defer"),
|
||||
explanation=(
|
||||
"Retrospective produced reusable candidate skills."
|
||||
if skill_candidates
|
||||
else "Retrospective only reinforces memory-like evidence."
|
||||
if signals
|
||||
else "No stable signal was extracted from this retrospective."
|
||||
),
|
||||
evidence_refs=(skill_candidates[0].evidence_refs if skill_candidates else retrospective.evidence_refs[:3]),
|
||||
metadata={
|
||||
"signal_count": len(signals),
|
||||
"pattern_count": len(patterns),
|
||||
"skill_candidate_count": len(skill_candidates),
|
||||
},
|
||||
)
|
||||
|
||||
retrospective.learning_signals = signals
|
||||
retrospective.pattern_candidates = patterns
|
||||
retrospective.skill_candidates = skill_candidates
|
||||
retrospective.learning_decision = update_learning_decision_with_bridge(decision, signals)
|
||||
return retrospective
|
||||
|
||||
|
||||
def _build_learning_artifacts(retrospective: SessionRetrospective) -> list[dict[str, object]]:
|
||||
artifacts: list[dict[str, object]] = []
|
||||
for signal in retrospective.learning_signals:
|
||||
artifacts.append(
|
||||
{
|
||||
"artifact_type": "signal",
|
||||
"artifact_key": signal.signal_type,
|
||||
"summary_text": signal.explanation or signal.signal_type,
|
||||
"payload": signal.model_dump(mode="json"),
|
||||
}
|
||||
)
|
||||
for pattern in retrospective.pattern_candidates:
|
||||
artifacts.append(
|
||||
{
|
||||
"artifact_type": "pattern_candidate",
|
||||
"artifact_key": pattern.pattern_type,
|
||||
"summary_text": pattern.description,
|
||||
"payload": pattern.model_dump(mode="json"),
|
||||
}
|
||||
)
|
||||
for candidate in retrospective.skill_candidates:
|
||||
artifacts.append(
|
||||
{
|
||||
"artifact_type": "skill_candidate",
|
||||
"artifact_key": candidate.name,
|
||||
"summary_text": candidate.summary,
|
||||
"payload": candidate.model_dump(mode="json"),
|
||||
}
|
||||
)
|
||||
if retrospective.learning_decision is not None:
|
||||
artifacts.append(
|
||||
{
|
||||
"artifact_type": "learning_decision",
|
||||
"artifact_key": retrospective.learning_decision.decision,
|
||||
"summary_text": retrospective.learning_decision.explanation,
|
||||
"payload": retrospective.learning_decision.model_dump(mode="json"),
|
||||
}
|
||||
)
|
||||
artifacts.append(
|
||||
{
|
||||
"artifact_type": "learning_audit",
|
||||
"artifact_key": retrospective.retrospective_id or "retrospective",
|
||||
"summary_text": retrospective.learning_decision.explanation,
|
||||
"payload": build_learning_audit_entry(retrospective),
|
||||
}
|
||||
)
|
||||
return artifacts
|
||||
|
||||
|
||||
def _build_lifecycle_artifacts(decisions: list) -> list[dict[str, object]]:
|
||||
artifacts: list[dict[str, object]] = []
|
||||
for decision in decisions:
|
||||
artifacts.append(
|
||||
{
|
||||
"artifact_type": "skill_lifecycle_decision",
|
||||
"artifact_key": getattr(decision, "skill_name", None) or "skill",
|
||||
"summary_text": getattr(decision, "reason", ""),
|
||||
"payload": decision.model_dump(mode="json"),
|
||||
}
|
||||
)
|
||||
return artifacts
|
||||
|
||||
|
||||
async def persist_retrospective(
|
||||
*,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
request_message_id: str | None,
|
||||
response_message_id: str | None,
|
||||
query_text: str,
|
||||
final_response: str | None,
|
||||
state: dict[str, Any] | None,
|
||||
) -> None:
|
||||
retrospective = build_session_retrospective(
|
||||
request_id=response_message_id or request_message_id or conversation_id,
|
||||
session_id=conversation_id,
|
||||
user_query=query_text,
|
||||
state=state,
|
||||
runtime_context={"user_id": user_id},
|
||||
)
|
||||
retrospective = _enrich_retrospective(retrospective)
|
||||
|
||||
async with async_session() as session:
|
||||
saved = await SessionRetrospectiveStore(session).save(retrospective)
|
||||
lifecycle_decisions = []
|
||||
if settings.ENABLE_SKILL_PROMOTION:
|
||||
lifecycle_decisions = await SkillPromotionEvaluator(session).sync_retrospective(
|
||||
user_id=user_id,
|
||||
retrospective=retrospective,
|
||||
)
|
||||
if settings.ENABLE_LEARNING_SIGNALS:
|
||||
await LearningArtifactStore(session).save_batch(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
retrospective_id=saved.id,
|
||||
artifacts=[
|
||||
*_build_learning_artifacts(retrospective),
|
||||
*_build_lifecycle_artifacts(lifecycle_decisions),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def schedule_retrospective_job(**kwargs) -> asyncio.Task[None] | None:
|
||||
if not settings.ENABLE_RETROSPECTIVE:
|
||||
return None
|
||||
try:
|
||||
task = asyncio.create_task(persist_retrospective(**kwargs))
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
def _handle_completion(done_task: asyncio.Task[None]) -> None:
|
||||
try:
|
||||
done_task.result()
|
||||
except Exception:
|
||||
logger.exception("retrospective_job_failed")
|
||||
|
||||
task.add_done_callback(_handle_completion)
|
||||
return task
|
||||
|
||||
|
||||
def schedule_retrospective_learning_event(
|
||||
*,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
retrospective: SessionRetrospective,
|
||||
session_factory=async_session,
|
||||
) -> asyncio.Task[None] | None:
|
||||
if not settings.ENABLE_RETROSPECTIVE:
|
||||
return None
|
||||
|
||||
async def _persist_existing() -> None:
|
||||
async with session_factory() as session:
|
||||
enriched = _enrich_retrospective(retrospective)
|
||||
saved = await SessionRetrospectiveStore(session).save(enriched)
|
||||
lifecycle_decisions = []
|
||||
if settings.ENABLE_SKILL_PROMOTION:
|
||||
lifecycle_decisions = await SkillPromotionEvaluator(session).sync_retrospective(
|
||||
user_id=user_id,
|
||||
retrospective=enriched,
|
||||
)
|
||||
if settings.ENABLE_LEARNING_SIGNALS:
|
||||
await LearningArtifactStore(session).save_batch(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
retrospective_id=saved.id,
|
||||
artifacts=[
|
||||
*_build_learning_artifacts(enriched),
|
||||
*_build_lifecycle_artifacts(lifecycle_decisions),
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
task = asyncio.create_task(_persist_existing())
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
def _handle_completion(done_task: asyncio.Task[None]) -> None:
|
||||
try:
|
||||
done_task.result()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"retrospective_learning_event_failed",
|
||||
extra={
|
||||
"details": {
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
task.add_done_callback(_handle_completion)
|
||||
return task
|
||||
42
backend/app/agents/learning/pattern_miner.py
Normal file
42
backend/app/agents/learning/pattern_miner.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from app.agents.schemas.learning import LearningSignal, PatternCandidate
|
||||
|
||||
|
||||
class LearningPatternMiner:
|
||||
def mine(self, signals: list[LearningSignal]) -> list[PatternCandidate]:
|
||||
patterns: list[PatternCandidate] = []
|
||||
|
||||
for signal in signals:
|
||||
if signal.signal_type not in {"workflow", "decomposition", "preference"}:
|
||||
continue
|
||||
|
||||
description = self._build_description(signal)
|
||||
patterns.append(
|
||||
PatternCandidate(
|
||||
pattern_id=f"pattern-{uuid4().hex[:10]}",
|
||||
pattern_type=signal.signal_type,
|
||||
description=description,
|
||||
confidence=signal.confidence,
|
||||
evidence_refs=signal.evidence_refs[:4],
|
||||
)
|
||||
)
|
||||
|
||||
return patterns
|
||||
|
||||
@staticmethod
|
||||
def _build_description(signal: LearningSignal) -> str:
|
||||
payload = signal.payload or {}
|
||||
if signal.signal_type == "workflow":
|
||||
task_type = payload.get("task_type") or "general"
|
||||
execution_mode = payload.get("execution_mode") or "direct"
|
||||
return f"Completed {task_type} requests worked under {execution_mode} execution."
|
||||
if signal.signal_type == "decomposition":
|
||||
task_count = payload.get("task_count") or 0
|
||||
return f"Requests with {task_count} concrete task refs benefit from structured decomposition."
|
||||
if signal.signal_type == "preference":
|
||||
preference = payload.get("preference") or "structured response"
|
||||
return f"User preference repeatedly points to {preference}."
|
||||
return signal.explanation or signal.signal_type
|
||||
115
backend/app/agents/learning/retrospector.py
Normal file
115
backend/app/agents/learning/retrospector.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.schemas.learning import SessionRetrospective
|
||||
|
||||
|
||||
def _classify_task_type(query_text: str) -> str:
|
||||
normalized = (query_text or "").lower()
|
||||
if any(token in normalized for token in ("总结", "分析", "对比", "report", "analyze")):
|
||||
return "analysis"
|
||||
if any(token in normalized for token in ("安排", "提醒", "日程", "todo", "task")):
|
||||
return "planning_or_execution"
|
||||
if any(token in normalized for token in ("文档", "资料", "年报", "search", "查")):
|
||||
return "retrieval"
|
||||
return "general"
|
||||
|
||||
|
||||
def build_session_retrospective(
|
||||
*,
|
||||
request_id: str,
|
||||
session_id: str,
|
||||
user_query: str,
|
||||
state: dict[str, Any] | None,
|
||||
runtime_context: dict[str, Any] | None = None,
|
||||
) -> SessionRetrospective:
|
||||
state = state or {}
|
||||
if hasattr(runtime_context, "model_dump"):
|
||||
runtime_context = runtime_context.model_dump(mode="json")
|
||||
runtime_context = runtime_context or {}
|
||||
skill_shortlist = state.get("skill_shortlist") or []
|
||||
used_skill_names = [
|
||||
item.get("skill_name")
|
||||
for item in skill_shortlist
|
||||
if isinstance(item, dict) and item.get("skill_name")
|
||||
]
|
||||
|
||||
task_refs = []
|
||||
for task in (state.get("completed_tasks") or [])[:4]:
|
||||
if isinstance(task, dict):
|
||||
task_refs.append(
|
||||
{
|
||||
"task_id": task.get("task_id"),
|
||||
"title": task.get("title"),
|
||||
"status": task.get("status"),
|
||||
}
|
||||
)
|
||||
|
||||
event_refs = []
|
||||
for event in (state.get("event_trace") or [])[:8]:
|
||||
if isinstance(event, dict):
|
||||
event_refs.append(
|
||||
{
|
||||
"event_type": event.get("event_type"),
|
||||
"task_id": event.get("task_id"),
|
||||
"agent_id": event.get("agent_id"),
|
||||
}
|
||||
)
|
||||
|
||||
verification_evidence = []
|
||||
for evidence in (state.get("verification_evidence") or [])[:6]:
|
||||
if isinstance(evidence, dict):
|
||||
verification_evidence.append(evidence)
|
||||
|
||||
verification_status = state.get("verification_status")
|
||||
execution_mode = state.get("execution_mode")
|
||||
primary_agent = state.get("current_agent") or "master"
|
||||
retrospective_shortlist = state.get("retrospective_shortlist") or []
|
||||
|
||||
summary_parts = [
|
||||
f"本轮请求按 {execution_mode or 'unknown'} 模式处理",
|
||||
f"主要负责 agent 为 {primary_agent}",
|
||||
]
|
||||
if verification_status:
|
||||
summary_parts.append(f"验证结果为 {verification_status}")
|
||||
if used_skill_names:
|
||||
summary_parts.append(f"命中技能候选 {', '.join(used_skill_names[:3])}")
|
||||
if retrospective_shortlist:
|
||||
summary_parts.append(f"参考了 {len(retrospective_shortlist)} 条历史复盘")
|
||||
|
||||
final_response = state.get("final_response")
|
||||
outcome = "completed" if final_response else "failed"
|
||||
if not final_response and verification_status == "passed":
|
||||
outcome = "completed"
|
||||
if final_response and verification_status == "skipped":
|
||||
outcome = "partial"
|
||||
|
||||
return SessionRetrospective(
|
||||
retrospective_id=request_id,
|
||||
user_id=str(runtime_context.get("user_id") or ""),
|
||||
conversation_id=session_id,
|
||||
response_message_id=request_id,
|
||||
query_text=user_query,
|
||||
final_response=final_response,
|
||||
summary=";".join(summary_parts) + "。",
|
||||
task_type=_classify_task_type(user_query),
|
||||
execution_mode=execution_mode,
|
||||
primary_agent=primary_agent,
|
||||
verification_status=verification_status,
|
||||
verification_summary=state.get("verification_summary"),
|
||||
used_skill_names=used_skill_names,
|
||||
evidence_refs=verification_evidence,
|
||||
task_refs=task_refs,
|
||||
event_refs=event_refs,
|
||||
context_snapshot={
|
||||
"runtime_request_context": runtime_context,
|
||||
"recommended_runtime_mode": runtime_context.get("recommended_runtime_mode"),
|
||||
"parallel_worthiness": state.get("parallel_worthiness"),
|
||||
"retrospective_shortlist_count": len(retrospective_shortlist),
|
||||
"scheduled_subtask_count": len(state.get("scheduled_subtasks") or []),
|
||||
"merge_report": dict(state.get("merge_report") or {}),
|
||||
"verification_report": dict(state.get("verification_report") or {}),
|
||||
},
|
||||
outcome=outcome,
|
||||
)
|
||||
95
backend/app/agents/learning/session_search.py
Normal file
95
backend/app/agents/learning/session_search.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.schemas.learning import SessionRetrospective
|
||||
from app.agents.skills.matcher import score_text_match
|
||||
from app.agents.learning.store import SessionRetrospectiveStore
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class SessionRetrospectiveSearch:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
|
||||
async def shortlist(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
query_text: str,
|
||||
conversation_id: str | None = None,
|
||||
task_type: str | None = None,
|
||||
skill_name: str | None = None,
|
||||
limit: int = 3,
|
||||
) -> list[SessionRetrospective]:
|
||||
records = await SessionRetrospectiveStore(self.db).list_recent(user_id=user_id, limit=25)
|
||||
scored: list[tuple[float, SessionRetrospective]] = []
|
||||
|
||||
for record in records:
|
||||
if task_type and record.task_type != task_type:
|
||||
continue
|
||||
if skill_name and skill_name not in (record.skill_names or []):
|
||||
continue
|
||||
score, _matched_terms = score_text_match(
|
||||
query_text,
|
||||
record.query_text,
|
||||
record.summary_text,
|
||||
" ".join(record.skill_names or []),
|
||||
)
|
||||
if conversation_id and record.conversation_id == conversation_id:
|
||||
score = min(1.0, score + 0.1)
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
payload = dict(record.payload or {})
|
||||
payload["retrospective_id"] = record.id
|
||||
retrospective = SessionRetrospective.model_validate(payload)
|
||||
scored.append((score, retrospective))
|
||||
|
||||
scored.sort(key=lambda item: item[0], reverse=True)
|
||||
return [item for _score, item in scored[:limit]]
|
||||
|
||||
|
||||
async def search_recent_retrospectives(
|
||||
db,
|
||||
*,
|
||||
user_id: str,
|
||||
query: str,
|
||||
conversation_id: str | None = None,
|
||||
task_type: str | None = None,
|
||||
skill_name: str | None = None,
|
||||
limit: int = 3,
|
||||
) -> list[SessionRetrospective]:
|
||||
if not settings.ENABLE_SESSION_RETROSPECTIVE_SEARCH:
|
||||
return []
|
||||
return await SessionRetrospectiveSearch(db).shortlist(
|
||||
user_id=user_id,
|
||||
query_text=query,
|
||||
conversation_id=conversation_id,
|
||||
task_type=task_type,
|
||||
skill_name=skill_name,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
def summarize_retrospective(retrospective: SessionRetrospective) -> dict[str, object]:
|
||||
verification_status = retrospective.verification_status or retrospective.outcome
|
||||
success_score = 1.0 if verification_status == "passed" else 0.6 if verification_status == "skipped" else 0.2
|
||||
reusable_patterns = []
|
||||
if retrospective.used_skill_names:
|
||||
reusable_patterns.append("skill_shortlist_hit")
|
||||
if retrospective.execution_mode:
|
||||
reusable_patterns.append(f"mode:{retrospective.execution_mode}")
|
||||
|
||||
avoid_patterns = []
|
||||
if retrospective.outcome == "failed":
|
||||
avoid_patterns.append("failed_outcome")
|
||||
|
||||
return {
|
||||
"retrospective_id": retrospective.retrospective_id,
|
||||
"task_type": retrospective.task_type,
|
||||
"request_summary": retrospective.query_text[:120],
|
||||
"summary": retrospective.summary,
|
||||
"execution_mode": retrospective.execution_mode,
|
||||
"success_score": round(success_score, 2),
|
||||
"reusable_patterns": reusable_patterns,
|
||||
"avoid_patterns": avoid_patterns,
|
||||
}
|
||||
72
backend/app/agents/learning/signal_extractor.py
Normal file
72
backend/app/agents/learning/signal_extractor.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.schemas.learning import LearningSignal, SessionRetrospective
|
||||
|
||||
|
||||
class RetrospectiveSignalExtractor:
|
||||
def extract(self, retrospective: SessionRetrospective) -> list[LearningSignal]:
|
||||
signals: list[LearningSignal] = []
|
||||
|
||||
if retrospective.outcome == "completed":
|
||||
signals.append(
|
||||
LearningSignal(
|
||||
signal_type="workflow",
|
||||
confidence=0.8,
|
||||
evidence_refs=retrospective.evidence_refs[:3],
|
||||
explanation="Completed runs can be mined as workflow hints later.",
|
||||
payload={
|
||||
"task_type": retrospective.task_type,
|
||||
"execution_mode": retrospective.execution_mode,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if len(retrospective.task_refs) > 1:
|
||||
context_snapshot = retrospective.context_snapshot or {}
|
||||
merge_report = dict(context_snapshot.get("merge_report") or {})
|
||||
verification_report = dict(context_snapshot.get("verification_report") or {})
|
||||
effectiveness_score = 1.0
|
||||
if merge_report.get("status") == "conflicted":
|
||||
effectiveness_score = 0.45
|
||||
elif merge_report.get("status") == "fallback":
|
||||
effectiveness_score = 0.25
|
||||
elif verification_report.get("status") == "failed":
|
||||
effectiveness_score = 0.3
|
||||
signals.append(
|
||||
LearningSignal(
|
||||
signal_type="decomposition",
|
||||
confidence=0.7,
|
||||
evidence_refs=retrospective.task_refs[:3],
|
||||
explanation="Multiple completed task refs indicate a decomposition pattern.",
|
||||
payload={
|
||||
"task_count": len(retrospective.task_refs),
|
||||
"scheduled_subtask_count": context_snapshot.get("scheduled_subtask_count", 0),
|
||||
"effectiveness_score": effectiveness_score,
|
||||
"merge_status": merge_report.get("status"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if retrospective.used_skill_names:
|
||||
signals.append(
|
||||
LearningSignal(
|
||||
signal_type="tool_success",
|
||||
confidence=0.65 if retrospective.outcome == "completed" else 0.35,
|
||||
evidence_refs=retrospective.evidence_refs[:2],
|
||||
explanation="Task-scoped skill shortlist was available during this run.",
|
||||
payload={"skills": retrospective.used_skill_names[:3]},
|
||||
)
|
||||
)
|
||||
|
||||
if retrospective.outcome == "failed":
|
||||
signals.append(
|
||||
LearningSignal(
|
||||
signal_type="correction",
|
||||
confidence=0.75,
|
||||
evidence_refs=retrospective.evidence_refs[:2],
|
||||
explanation="Failed retrospectives should remain auditable before any promotion.",
|
||||
payload={"verification_status": retrospective.verification_status},
|
||||
)
|
||||
)
|
||||
|
||||
return signals
|
||||
54
backend/app/agents/learning/skill_candidate_builder.py
Normal file
54
backend/app/agents/learning/skill_candidate_builder.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
|
||||
from app.agents.schemas.learning import PatternCandidate, SkillCandidate
|
||||
|
||||
|
||||
class SkillCandidateBuilder:
|
||||
def build(self, patterns: list[PatternCandidate]) -> list[SkillCandidate]:
|
||||
candidates: list[SkillCandidate] = []
|
||||
|
||||
for pattern in patterns:
|
||||
if pattern.confidence < 0.55:
|
||||
continue
|
||||
|
||||
name = self._build_name(pattern)
|
||||
candidates.append(
|
||||
SkillCandidate(
|
||||
candidate_id=f"candidate-{self._stable_suffix(pattern)}",
|
||||
name=name,
|
||||
summary=pattern.description,
|
||||
candidate_type=self._map_candidate_type(pattern.pattern_type),
|
||||
source_pattern_ids=[pattern.pattern_id],
|
||||
confidence=pattern.confidence,
|
||||
evidence_refs=pattern.evidence_refs[:4],
|
||||
recommended_status="candidate",
|
||||
)
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
@staticmethod
|
||||
def _build_name(pattern: PatternCandidate) -> str:
|
||||
prefix = {
|
||||
"workflow": "workflow",
|
||||
"decomposition": "decomposition",
|
||||
"preference": "preference",
|
||||
}.get(pattern.pattern_type, "learned")
|
||||
stable_suffix = SkillCandidateBuilder._stable_suffix(pattern)
|
||||
return f"{prefix}-{stable_suffix}"
|
||||
|
||||
@staticmethod
|
||||
def _map_candidate_type(pattern_type: str) -> str:
|
||||
mapping = {
|
||||
"workflow": "workflow_skill",
|
||||
"decomposition": "decomposition_skill",
|
||||
"preference": "preference_skill",
|
||||
}
|
||||
return mapping.get(pattern_type, "workflow_skill")
|
||||
|
||||
@staticmethod
|
||||
def _stable_suffix(pattern: PatternCandidate) -> str:
|
||||
raw = f"{pattern.pattern_type}:{pattern.description}".encode("utf-8")
|
||||
return hashlib.sha1(raw).hexdigest()[:10]
|
||||
129
backend/app/agents/learning/store.py
Normal file
129
backend/app/agents/learning/store.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.schemas.learning import SessionRetrospective
|
||||
from app.models.learning import LearningArtifactRecord, SessionRetrospectiveRecord
|
||||
|
||||
|
||||
class SessionRetrospectiveStore:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def save(self, retrospective: SessionRetrospective) -> SessionRetrospectiveRecord:
|
||||
payload = retrospective.model_dump(mode="json")
|
||||
record = SessionRetrospectiveRecord(
|
||||
user_id=retrospective.user_id,
|
||||
conversation_id=retrospective.conversation_id,
|
||||
request_message_id=retrospective.request_message_id,
|
||||
response_message_id=retrospective.response_message_id,
|
||||
query_text=retrospective.query_text,
|
||||
final_response=retrospective.final_response,
|
||||
summary_text=retrospective.summary,
|
||||
task_type=retrospective.task_type,
|
||||
execution_mode=retrospective.execution_mode,
|
||||
primary_agent=retrospective.primary_agent,
|
||||
verification_status=retrospective.verification_status,
|
||||
verification_summary=retrospective.verification_summary,
|
||||
skill_names=retrospective.used_skill_names,
|
||||
evidence=retrospective.evidence_refs,
|
||||
task_refs=retrospective.task_refs,
|
||||
payload=payload,
|
||||
)
|
||||
self.db.add(record)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(record)
|
||||
return record
|
||||
|
||||
async def list_recent(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
) -> list[SessionRetrospectiveRecord]:
|
||||
result = await self.db.execute(
|
||||
select(SessionRetrospectiveRecord)
|
||||
.where(SessionRetrospectiveRecord.user_id == user_id)
|
||||
.order_by(desc(SessionRetrospectiveRecord.recorded_at), desc(SessionRetrospectiveRecord.created_at))
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
class LearningArtifactStore:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def save_batch(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
retrospective_id: str | None,
|
||||
artifacts: list[dict[str, object]],
|
||||
) -> list[LearningArtifactRecord]:
|
||||
records: list[LearningArtifactRecord] = []
|
||||
for artifact in artifacts:
|
||||
record = LearningArtifactRecord(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
retrospective_id=retrospective_id,
|
||||
artifact_type=str(artifact.get("artifact_type") or "unknown"),
|
||||
artifact_key=str(artifact.get("artifact_key") or "") or None,
|
||||
summary_text=str(artifact.get("summary_text") or ""),
|
||||
payload=dict(artifact.get("payload") or {}),
|
||||
)
|
||||
self.db.add(record)
|
||||
records.append(record)
|
||||
|
||||
await self.db.commit()
|
||||
for record in records:
|
||||
await self.db.refresh(record)
|
||||
return records
|
||||
|
||||
async def list_recent(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
artifact_type: str | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[LearningArtifactRecord]:
|
||||
query = select(LearningArtifactRecord).where(LearningArtifactRecord.user_id == user_id)
|
||||
if artifact_type:
|
||||
query = query.where(LearningArtifactRecord.artifact_type == artifact_type)
|
||||
result = await self.db.execute(
|
||||
query.order_by(
|
||||
desc(LearningArtifactRecord.recorded_at),
|
||||
desc(LearningArtifactRecord.created_at),
|
||||
).limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def aggregate_counts_by_key(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
artifact_type: str,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
records = await self.list_recent(user_id=user_id, artifact_type=artifact_type, limit=limit)
|
||||
counts: dict[str, int] = {}
|
||||
for record in records:
|
||||
key = record.artifact_key or "unknown"
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
return counts
|
||||
|
||||
|
||||
def append_retrospective_attachment(
|
||||
attachments: list[dict] | None,
|
||||
retrospective: SessionRetrospective,
|
||||
) -> list[dict]:
|
||||
next_attachments = list(attachments or [])
|
||||
next_attachments.append(
|
||||
{
|
||||
"kind": "session_retrospective",
|
||||
"payload": retrospective.model_dump(mode="json"),
|
||||
}
|
||||
)
|
||||
return next_attachments
|
||||
37
backend/app/agents/orchestration/__init__.py
Normal file
37
backend/app/agents/orchestration/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""高级编排系统 - Phase 10"""
|
||||
|
||||
from app.agents.orchestration.budget import build_subtask_budget
|
||||
from app.agents.orchestration.result_merge import merge_task_results
|
||||
from app.agents.orchestration.scheduler import (
|
||||
ParallelExecutionScheduler,
|
||||
build_subtask_specs,
|
||||
ensure_child_links,
|
||||
)
|
||||
from app.agents.orchestration.subagent_runtime import subtask_spec_to_agent_task
|
||||
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
|
||||
from app.agents.transport.remote import RemoteTransport, StructuredMessage
|
||||
from app.agents.orchestration.task_graph import build_bounded_task_graph, render_task_graph_summary
|
||||
from app.agents.background.manager import (
|
||||
BackgroundTaskManager,
|
||||
BackgroundTask,
|
||||
get_background_task_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TeamLeader",
|
||||
"TeamTask",
|
||||
"TaskStatus",
|
||||
"RemoteTransport",
|
||||
"StructuredMessage",
|
||||
"ParallelExecutionScheduler",
|
||||
"build_bounded_task_graph",
|
||||
"build_subtask_budget",
|
||||
"build_subtask_specs",
|
||||
"BackgroundTaskManager",
|
||||
"BackgroundTask",
|
||||
"ensure_child_links",
|
||||
"get_background_task_manager",
|
||||
"merge_task_results",
|
||||
"render_task_graph_summary",
|
||||
"subtask_spec_to_agent_task",
|
||||
]
|
||||
24
backend/app/agents/orchestration/budget.py
Normal file
24
backend/app/agents/orchestration/budget.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.schemas.task import CollaborationBudget
|
||||
|
||||
|
||||
def build_subtask_budget(
|
||||
*,
|
||||
execution_mode: str,
|
||||
max_parallel_tasks: int,
|
||||
max_tool_calls: int = 2,
|
||||
max_iterations: int = 2,
|
||||
metadata: dict | None = None,
|
||||
) -> CollaborationBudget:
|
||||
return CollaborationBudget(
|
||||
mode="collaboration" if execution_mode != "direct" else "direct",
|
||||
max_parallel_tasks=max_parallel_tasks,
|
||||
remaining_parallel_tasks=max_parallel_tasks,
|
||||
max_tool_calls=max_tool_calls,
|
||||
remaining_tool_calls=max_tool_calls,
|
||||
max_iterations=max_iterations,
|
||||
remaining_iterations=max_iterations,
|
||||
escalation_threshold=1,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
31
backend/app/agents/orchestration/monitor.py
Normal file
31
backend/app/agents/orchestration/monitor.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_parallel_runtime_metrics(
|
||||
*,
|
||||
task_graph: dict[str, Any] | None,
|
||||
scheduled_subtasks: list[dict[str, Any]] | None,
|
||||
task_results: list[dict[str, Any]] | None,
|
||||
merge_report: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
task_graph = task_graph or {}
|
||||
scheduled_subtasks = list(scheduled_subtasks or [])
|
||||
task_results = list(task_results or [])
|
||||
merge_report = merge_report or {}
|
||||
|
||||
completed = sum(1 for item in task_results if item.get("status") == "completed")
|
||||
failed = sum(1 for item in task_results if item.get("status") == "failed")
|
||||
blocked = sum(1 for item in task_results if item.get("status") == "blocked")
|
||||
|
||||
return {
|
||||
"task_graph_node_count": len(task_graph.get("nodes") or []),
|
||||
"scheduled_subtask_count": len(scheduled_subtasks),
|
||||
"completed_subtask_count": completed,
|
||||
"failed_subtask_count": failed,
|
||||
"blocked_subtask_count": blocked,
|
||||
"merge_status": merge_report.get("status"),
|
||||
"merge_conflict_count": len(merge_report.get("conflict_flags") or []),
|
||||
"fallback_used": bool(merge_report.get("fallback_used") or False),
|
||||
}
|
||||
69
backend/app/agents/orchestration/result_merge.py
Normal file
69
backend/app/agents/orchestration/result_merge.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.schemas.orchestration import MergeReport
|
||||
from app.agents.verifier import normalize_task_result
|
||||
|
||||
|
||||
def merge_task_results(task_results: list[dict] | list[object]) -> MergeReport:
|
||||
normalized = [normalize_task_result(item) for item in (task_results or [])]
|
||||
completed = [item for item in normalized if item.status == "completed"]
|
||||
failed_or_blocked = [item for item in normalized if item.status in {"failed", "blocked"}]
|
||||
|
||||
evidence_union: list[dict] = []
|
||||
summaries = []
|
||||
for item in normalized:
|
||||
evidence_union.extend(list(item.evidence or []))
|
||||
if item.summary:
|
||||
summaries.append(item.summary.strip())
|
||||
|
||||
unique_summaries = list(dict.fromkeys(summary for summary in summaries if summary))
|
||||
conflict_flags: list[str] = []
|
||||
status = "merged"
|
||||
fallback_used = False
|
||||
|
||||
if failed_or_blocked:
|
||||
status = "fallback"
|
||||
fallback_used = True
|
||||
conflict_flags.append(
|
||||
"failed_or_blocked_tasks:" + ",".join(item.task_id for item in failed_or_blocked)
|
||||
)
|
||||
resolution_strategy = "serial_recovery"
|
||||
resolved_summary = (
|
||||
completed[-1].summary
|
||||
if completed and completed[-1].summary
|
||||
else None
|
||||
)
|
||||
elif len(unique_summaries) > 1 and len(completed) > 1:
|
||||
status = "conflicted"
|
||||
conflict_flags.append("multiple_distinct_completed_summaries")
|
||||
resolution_strategy = "rank_by_evidence_count"
|
||||
ranked = sorted(
|
||||
completed,
|
||||
key=lambda item: (len(item.evidence or []), bool(item.summary)),
|
||||
reverse=True,
|
||||
)
|
||||
resolved_summary = ranked[0].summary if ranked and ranked[0].summary else None
|
||||
else:
|
||||
resolution_strategy = "evidence_union"
|
||||
resolved_summary = unique_summaries[-1] if unique_summaries else None
|
||||
|
||||
if status == "merged":
|
||||
summary = (
|
||||
unique_summaries[-1]
|
||||
if unique_summaries
|
||||
else f"已收敛 {len(normalized)} 个子任务结果。"
|
||||
)
|
||||
elif status == "conflicted":
|
||||
summary = "并行子任务摘要存在冲突,需要 verifier 或串行收敛。"
|
||||
else:
|
||||
summary = "存在失败或阻塞子任务,需要回退到更保守的收敛路径。"
|
||||
|
||||
return MergeReport(
|
||||
status=status,
|
||||
summary=summary,
|
||||
evidence_union=evidence_union,
|
||||
conflict_flags=conflict_flags,
|
||||
resolution_strategy=resolution_strategy,
|
||||
resolved_summary=resolved_summary,
|
||||
fallback_used=fallback_used,
|
||||
)
|
||||
93
backend/app/agents/orchestration/scheduler.py
Normal file
93
backend/app/agents/orchestration/scheduler.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from uuid import uuid4
|
||||
|
||||
from app.agents.orchestration.budget import build_subtask_budget
|
||||
from app.agents.schemas.orchestration import SubTaskSpec, TaskGraph, TaskNode
|
||||
|
||||
|
||||
class ParallelExecutionScheduler:
|
||||
def plan(self, task_graph: TaskGraph, *, query_text: str) -> list[SubTaskSpec]:
|
||||
ordered_nodes = _topological_nodes(task_graph)
|
||||
specs: list[SubTaskSpec] = []
|
||||
for node in ordered_nodes:
|
||||
budget = build_subtask_budget(
|
||||
execution_mode=node.execution_mode,
|
||||
max_parallel_tasks=max(1, task_graph.max_parallelism),
|
||||
metadata={
|
||||
"task_graph_id": task_graph.graph_id,
|
||||
"depends_on": node.depends_on,
|
||||
},
|
||||
)
|
||||
specs.append(
|
||||
SubTaskSpec(
|
||||
subtask_id=node.node_id,
|
||||
parent_run_id=task_graph.graph_id,
|
||||
title=node.title,
|
||||
role=node.role or "master",
|
||||
goal=node.goal or query_text,
|
||||
context_slice=_build_context_slice(node, query_text),
|
||||
allowed_tools=[],
|
||||
budget_tokens=1200,
|
||||
budget_tool_calls=budget.max_tool_calls or 2,
|
||||
expected_output_schema={
|
||||
"summary": "string",
|
||||
"evidence": "list",
|
||||
"status": "completed|failed|blocked",
|
||||
},
|
||||
expected_evidence=node.expected_evidence,
|
||||
dependencies=node.depends_on,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
|
||||
def build_subtask_specs(task_graph: TaskGraph, *, query_text: str) -> list[SubTaskSpec]:
|
||||
return ParallelExecutionScheduler().plan(task_graph, query_text=query_text)
|
||||
|
||||
|
||||
def _build_context_slice(node: TaskNode, query_text: str) -> dict[str, object]:
|
||||
return {
|
||||
"query": query_text,
|
||||
"role": node.role,
|
||||
"title": node.title,
|
||||
"goal": node.goal,
|
||||
"depends_on": node.depends_on,
|
||||
}
|
||||
|
||||
|
||||
def _topological_nodes(task_graph: TaskGraph) -> list[TaskNode]:
|
||||
by_id = {node.node_id: node for node in task_graph.nodes}
|
||||
indegree = {node.node_id: 0 for node in task_graph.nodes}
|
||||
edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for node in task_graph.nodes:
|
||||
for dep in node.depends_on:
|
||||
if dep not in by_id:
|
||||
continue
|
||||
edges[dep].append(node.node_id)
|
||||
indegree[node.node_id] += 1
|
||||
|
||||
ready = deque(node_id for node_id, count in indegree.items() if count == 0)
|
||||
ordered: list[TaskNode] = []
|
||||
|
||||
while ready:
|
||||
node_id = ready.popleft()
|
||||
ordered.append(by_id[node_id])
|
||||
for target in edges.get(node_id, []):
|
||||
indegree[target] -= 1
|
||||
if indegree[target] == 0:
|
||||
ready.append(target)
|
||||
|
||||
if len(ordered) != len(task_graph.nodes):
|
||||
return list(task_graph.nodes)
|
||||
return ordered
|
||||
|
||||
|
||||
def ensure_child_links(specs: list[SubTaskSpec]) -> dict[str, list[str]]:
|
||||
graph: dict[str, list[str]] = defaultdict(list)
|
||||
for spec in specs:
|
||||
for dep in spec.dependencies:
|
||||
graph[dep].append(spec.subtask_id)
|
||||
return dict(graph)
|
||||
17
backend/app/agents/orchestration/subagent_runtime.py
Normal file
17
backend/app/agents/orchestration/subagent_runtime.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.schemas.orchestration import SubTaskSpec
|
||||
from app.agents.schemas.task import AgentTask
|
||||
|
||||
|
||||
def subtask_spec_to_agent_task(spec: SubTaskSpec) -> AgentTask:
|
||||
return AgentTask(
|
||||
task_id=spec.subtask_id,
|
||||
title=spec.title,
|
||||
owner_agent_id=spec.role,
|
||||
role=spec.role,
|
||||
goal=spec.goal,
|
||||
parent_task_id=spec.parent_run_id,
|
||||
child_task_ids=[],
|
||||
expected_evidence=spec.expected_evidence,
|
||||
)
|
||||
128
backend/app/agents/orchestration/task_graph.py
Normal file
128
backend/app/agents/orchestration/task_graph.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from app.agents.schemas.orchestration import ParallelWorthiness, TaskGraph, TaskNode
|
||||
|
||||
|
||||
ROLE_KEYWORDS: list[tuple[str, tuple[str, ...]]] = [
|
||||
("librarian", ("查", "检索", "资料", "文档", "知识库", "年报", "forum", "search")),
|
||||
("analyst", ("分析", "判断", "风险", "总结", "对比", "洞察", "结论")),
|
||||
("schedule_planner", ("计划", "安排", "下周", "日程", "提醒", "优先级")),
|
||||
("executor", ("执行", "创建", "更新", "落库", "提交", "发帖")),
|
||||
]
|
||||
|
||||
|
||||
def build_bounded_task_graph(
|
||||
*,
|
||||
query_text: str,
|
||||
parallel_worthiness: ParallelWorthiness,
|
||||
max_nodes: int = 4,
|
||||
) -> TaskGraph | None:
|
||||
roles = _infer_roles(query_text)
|
||||
if not roles:
|
||||
return None
|
||||
|
||||
independent_roles = roles[: min(max_nodes - 1, max(1, parallel_worthiness.estimated_subtasks))]
|
||||
nodes: list[TaskNode] = []
|
||||
|
||||
for index, role in enumerate(independent_roles, start=1):
|
||||
node_id = f"task-{index}-{uuid4().hex[:6]}"
|
||||
nodes.append(
|
||||
TaskNode(
|
||||
node_id=node_id,
|
||||
title=_build_title(role),
|
||||
role=role,
|
||||
goal=_build_goal(role, query_text),
|
||||
depends_on=[],
|
||||
execution_mode=(
|
||||
"parallel"
|
||||
if parallel_worthiness.preferred_mode in {"collaboration", "parallel"}
|
||||
and len(independent_roles) > 1
|
||||
else "serial"
|
||||
),
|
||||
expected_evidence=_build_expected_evidence(role),
|
||||
)
|
||||
)
|
||||
|
||||
if len(nodes) > 1:
|
||||
merge_id = f"merge-{uuid4().hex[:6]}"
|
||||
nodes.append(
|
||||
TaskNode(
|
||||
node_id=merge_id,
|
||||
title="汇总并收敛最终结论",
|
||||
role="master",
|
||||
goal="汇总前置子任务结果,形成统一可验证的输出。",
|
||||
depends_on=[node.node_id for node in nodes],
|
||||
execution_mode="serial",
|
||||
expected_evidence=[{"type": "merge", "detail": "merged summary and conflict notes"}],
|
||||
)
|
||||
)
|
||||
|
||||
return TaskGraph(
|
||||
nodes=nodes,
|
||||
entry_node_ids=[node.node_id for node in nodes if not node.depends_on],
|
||||
max_parallelism=max(1, len(independent_roles)),
|
||||
rationale=_build_rationale(parallel_worthiness, independent_roles),
|
||||
)
|
||||
|
||||
|
||||
def render_task_graph_summary(task_graph: TaskGraph | None) -> str | None:
|
||||
if task_graph is None or not task_graph.nodes:
|
||||
return None
|
||||
|
||||
lines = ["- 任务图:"]
|
||||
for node in task_graph.nodes[:4]:
|
||||
deps = f" deps={','.join(node.depends_on)}" if node.depends_on else ""
|
||||
lines.append(f" - [{node.execution_mode}] {node.title} ({node.role}){deps}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _infer_roles(query_text: str) -> list[str]:
|
||||
selected: list[str] = []
|
||||
text = query_text or ""
|
||||
for role, keywords in ROLE_KEYWORDS:
|
||||
if any(keyword in text for keyword in keywords):
|
||||
selected.append(role)
|
||||
|
||||
if not selected:
|
||||
return ["analyst"]
|
||||
return selected
|
||||
|
||||
|
||||
def _build_title(role: str) -> str:
|
||||
mapping = {
|
||||
"librarian": "收集事实与外部/内部证据",
|
||||
"analyst": "形成判断与风险分析",
|
||||
"schedule_planner": "整理计划和优先级",
|
||||
"executor": "执行必要操作并回收结果",
|
||||
}
|
||||
return mapping.get(role, "处理子任务")
|
||||
|
||||
|
||||
def _build_goal(role: str, query_text: str) -> str:
|
||||
mapping = {
|
||||
"librarian": f"围绕请求收集支持结论的事实和资料:{query_text}",
|
||||
"analyst": f"基于当前请求输出结构化判断:{query_text}",
|
||||
"schedule_planner": f"把当前请求收束为计划、安排或优先级:{query_text}",
|
||||
"executor": f"基于请求执行必要动作并返回结果:{query_text}",
|
||||
}
|
||||
return mapping.get(role, query_text)
|
||||
|
||||
|
||||
def _build_expected_evidence(role: str) -> list[dict[str, str]]:
|
||||
mapping = {
|
||||
"librarian": [{"type": "evidence", "detail": "retrieval findings"}],
|
||||
"analyst": [{"type": "analysis", "detail": "structured judgment"}],
|
||||
"schedule_planner": [{"type": "plan", "detail": "explicit schedule or priorities"}],
|
||||
"executor": [{"type": "execution", "detail": "tool output or mutation result"}],
|
||||
}
|
||||
return mapping.get(role, [{"type": "summary", "detail": "task summary"}])
|
||||
|
||||
|
||||
def _build_rationale(parallel_worthiness: ParallelWorthiness, roles: list[str]) -> str:
|
||||
return (
|
||||
f"preferred_mode={parallel_worthiness.preferred_mode}; "
|
||||
f"score={parallel_worthiness.score:.2f}; "
|
||||
f"roles={','.join(roles)}"
|
||||
)
|
||||
12
backend/app/agents/plugins/__init__.py
Normal file
12
backend/app/agents/plugins/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""插件系统 - Phase 8"""
|
||||
|
||||
from app.agents.plugins.manager import PluginManager, get_plugin_manager
|
||||
from app.agents.plugins.manifest import PluginManifest
|
||||
from app.agents.plugins.sandbox import PluginSandbox
|
||||
|
||||
__all__ = [
|
||||
"PluginManager",
|
||||
"PluginManifest",
|
||||
"PluginSandbox",
|
||||
"get_plugin_manager",
|
||||
]
|
||||
19
backend/app/agents/plugins/builtins/code_helper/__init__.py
Normal file
19
backend/app/agents/plugins/builtins/code_helper/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Code Helper Plugin - Linting, formatting, and code explanation tools"""
|
||||
|
||||
|
||||
def lint_file(file_path: str) -> dict:
|
||||
"""Lint a source file and return issues found."""
|
||||
return {"status": "ok", "tool": "lint_file", "result": f"Linting {file_path}"}
|
||||
|
||||
|
||||
def format_file(file_path: str) -> dict:
|
||||
"""Format a source file and return the result."""
|
||||
return {"status": "ok", "tool": "format_file", "result": f"Formatting {file_path}"}
|
||||
|
||||
|
||||
def explain_code(code_snippet: str) -> dict:
|
||||
"""Explain a code snippet and return the explanation."""
|
||||
return {"status": "ok", "tool": "explain_code", "result": f"Explaining code snippet"}
|
||||
|
||||
|
||||
tools = [lint_file, format_file, explain_code]
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"id": "code_helper",
|
||||
"name": "Code Helper",
|
||||
"version": "1.0.0",
|
||||
"description": "Code linting, formatting, and explanation tools",
|
||||
"author": "",
|
||||
"homepage": "",
|
||||
"license": "MIT",
|
||||
"plugin_type": "tool",
|
||||
"main": "__init__.py",
|
||||
"hooks": [],
|
||||
"tools": ["lint_file", "format_file", "explain_code"],
|
||||
"skills": [],
|
||||
"dependencies": {},
|
||||
"peer_dependencies": {},
|
||||
"permissions": [],
|
||||
"allowed_paths": [],
|
||||
"denied_paths": [],
|
||||
"network_allowed": false,
|
||||
"allowed_hosts": [],
|
||||
"config_schema": {}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
"""File Organizer Plugin - File organization and duplicate detection tools"""
|
||||
|
||||
|
||||
def organize_by_type(directory: str) -> dict:
|
||||
"""Organize files in a directory by file type."""
|
||||
return {"status": "ok", "tool": "organize_by_type", "result": f"Organizing {directory} by type"}
|
||||
|
||||
|
||||
def find_duplicates(directory: str) -> dict:
|
||||
"""Find duplicate files in a directory."""
|
||||
return {
|
||||
"status": "ok",
|
||||
"tool": "find_duplicates",
|
||||
"result": f"Finding duplicates in {directory}",
|
||||
}
|
||||
|
||||
|
||||
tools = [organize_by_type, find_duplicates]
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"id": "file_organizer",
|
||||
"name": "File Organizer",
|
||||
"version": "1.0.0",
|
||||
"description": "File organization and duplicate detection tools",
|
||||
"author": "",
|
||||
"homepage": "",
|
||||
"license": "MIT",
|
||||
"plugin_type": "tool",
|
||||
"main": "__init__.py",
|
||||
"hooks": [],
|
||||
"tools": ["organize_by_type", "find_duplicates"],
|
||||
"skills": [],
|
||||
"dependencies": {},
|
||||
"peer_dependencies": {},
|
||||
"permissions": [],
|
||||
"allowed_paths": [],
|
||||
"denied_paths": [],
|
||||
"network_allowed": false,
|
||||
"allowed_hosts": [],
|
||||
"config_schema": {}
|
||||
}
|
||||
23
backend/app/agents/plugins/builtins/git_helper/__init__.py
Normal file
23
backend/app/agents/plugins/builtins/git_helper/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Git Helper Plugin - Git status, log, and diff summary tools"""
|
||||
|
||||
|
||||
def git_status_summary() -> dict:
|
||||
"""Get a summary of git status."""
|
||||
return {"status": "ok", "tool": "git_status_summary", "result": "Git status summary"}
|
||||
|
||||
|
||||
def git_log_summary(limit: int = 10) -> dict:
|
||||
"""Get a summary of recent git commits."""
|
||||
return {"status": "ok", "tool": "git_log_summary", "result": f"Git log summary (limit={limit})"}
|
||||
|
||||
|
||||
def git_diff_summary(ref1: str = "HEAD", ref2: str = "HEAD~1") -> dict:
|
||||
"""Get a summary of changes between two refs."""
|
||||
return {
|
||||
"status": "ok",
|
||||
"tool": "git_diff_summary",
|
||||
"result": f"Git diff summary ({ref1}..{ref2})",
|
||||
}
|
||||
|
||||
|
||||
tools = [git_status_summary, git_log_summary, git_diff_summary]
|
||||
22
backend/app/agents/plugins/builtins/git_helper/manifest.json
Normal file
22
backend/app/agents/plugins/builtins/git_helper/manifest.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"id": "git_helper",
|
||||
"name": "Git Helper",
|
||||
"version": "1.0.0",
|
||||
"description": "Git status, log, and diff summary tools",
|
||||
"author": "",
|
||||
"homepage": "",
|
||||
"license": "MIT",
|
||||
"plugin_type": "tool",
|
||||
"main": "__init__.py",
|
||||
"hooks": [],
|
||||
"tools": ["git_status_summary", "git_log_summary", "git_diff_summary"],
|
||||
"skills": [],
|
||||
"dependencies": {},
|
||||
"peer_dependencies": {},
|
||||
"permissions": [],
|
||||
"allowed_paths": [],
|
||||
"denied_paths": [],
|
||||
"network_allowed": false,
|
||||
"allowed_hosts": [],
|
||||
"config_schema": {}
|
||||
}
|
||||
14
backend/app/agents/plugins/builtins/web_helper/__init__.py
Normal file
14
backend/app/agents/plugins/builtins/web_helper/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Web Helper Plugin - Web fetching and HTML parsing tools"""
|
||||
|
||||
|
||||
def fetch_url_content(url: str) -> dict:
|
||||
"""Fetch content from a URL."""
|
||||
return {"status": "ok", "tool": "fetch_url_content", "result": f"Fetching {url}"}
|
||||
|
||||
|
||||
def parse_html_links(html_content: str) -> dict:
|
||||
"""Parse HTML content and extract links."""
|
||||
return {"status": "ok", "tool": "parse_html_links", "result": "Extracted links from HTML"}
|
||||
|
||||
|
||||
tools = [fetch_url_content, parse_html_links]
|
||||
22
backend/app/agents/plugins/builtins/web_helper/manifest.json
Normal file
22
backend/app/agents/plugins/builtins/web_helper/manifest.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"id": "web_helper",
|
||||
"name": "Web Helper",
|
||||
"version": "1.0.0",
|
||||
"description": "Web fetching and HTML parsing tools",
|
||||
"author": "",
|
||||
"homepage": "",
|
||||
"license": "MIT",
|
||||
"plugin_type": "tool",
|
||||
"main": "__init__.py",
|
||||
"hooks": [],
|
||||
"tools": ["fetch_url_content", "parse_html_links"],
|
||||
"skills": [],
|
||||
"dependencies": {},
|
||||
"peer_dependencies": {},
|
||||
"permissions": [],
|
||||
"allowed_paths": [],
|
||||
"denied_paths": [],
|
||||
"network_allowed": true,
|
||||
"allowed_hosts": [],
|
||||
"config_schema": {}
|
||||
}
|
||||
207
backend/app/agents/plugins/manager.py
Normal file
207
backend/app/agents/plugins/manager.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""插件管理器 - Phase 8.2"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from app.agents.plugins.manifest import PluginManifest
|
||||
from app.agents.plugins.sandbox import PluginSandbox
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""插件管理器
|
||||
|
||||
负责插件的安装、卸载、启用、禁用和生命周期管理。
|
||||
"""
|
||||
|
||||
def __init__(self, plugins_dir: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
plugins_dir: 插件目录,None 则使用默认目录
|
||||
"""
|
||||
if plugins_dir is None:
|
||||
plugins_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "plugins")
|
||||
self.plugins_dir = plugins_dir
|
||||
self._plugins: dict[str, PluginManifest] = {}
|
||||
self._enabled: dict[str, bool] = {}
|
||||
self._modules: dict[str, Any] = {}
|
||||
self._sandbox = PluginSandbox()
|
||||
|
||||
def install(self, plugin_path: str) -> bool:
|
||||
"""安装插件
|
||||
|
||||
Args:
|
||||
plugin_path: 插件目录路径或 manifest.json 所在目录
|
||||
|
||||
Returns:
|
||||
是否安装成功
|
||||
"""
|
||||
try:
|
||||
manifest_path = os.path.join(plugin_path, "manifest.json")
|
||||
|
||||
if not os.path.exists(manifest_path):
|
||||
return False
|
||||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
import json
|
||||
|
||||
data = json.load(f)
|
||||
|
||||
manifest = PluginManifest.from_dict(data)
|
||||
|
||||
# 验证 manifest
|
||||
if not self._validate_manifest(manifest, plugin_path):
|
||||
return False
|
||||
|
||||
# 复制插件到 plugins_dir
|
||||
target_dir = os.path.join(self.plugins_dir, manifest.id)
|
||||
os.makedirs(os.path.dirname(target_dir), exist_ok=True)
|
||||
|
||||
# 保存 manifest
|
||||
with open(os.path.join(target_dir, "manifest.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(manifest.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 注册插件
|
||||
self._plugins[manifest.id] = manifest
|
||||
self._enabled[manifest.id] = True
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def uninstall(self, plugin_id: str) -> bool:
|
||||
"""卸载插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否卸载成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
# 禁用插件
|
||||
self.disable(plugin_id)
|
||||
|
||||
# 移除模块
|
||||
if plugin_id in self._modules:
|
||||
del self._modules[plugin_id]
|
||||
|
||||
# 移除插件
|
||||
del self._plugins[plugin_id]
|
||||
del self._enabled[plugin_id]
|
||||
|
||||
# 删除目录
|
||||
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
|
||||
if os.path.exists(plugin_dir):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(plugin_dir)
|
||||
|
||||
return True
|
||||
|
||||
def enable(self, plugin_id: str) -> bool:
|
||||
"""启用插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否启用成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
self._enabled[plugin_id] = True
|
||||
return True
|
||||
|
||||
def disable(self, plugin_id: str) -> bool:
|
||||
"""禁用插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否禁用成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
self._enabled[plugin_id] = False
|
||||
return True
|
||||
|
||||
def reload(self, plugin_id: str) -> bool:
|
||||
"""重新加载插件
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID
|
||||
|
||||
Returns:
|
||||
是否重新加载成功
|
||||
"""
|
||||
if plugin_id not in self._plugins:
|
||||
return False
|
||||
|
||||
# 卸载模块
|
||||
if plugin_id in self._modules:
|
||||
del self._modules[plugin_id]
|
||||
|
||||
# 重新加载
|
||||
return self._load_plugin_module(plugin_id)
|
||||
|
||||
def list_plugins(self) -> list[PluginManifest]:
|
||||
"""列出所有插件"""
|
||||
return list(self._plugins.values())
|
||||
|
||||
def get_plugin(self, plugin_id: str) -> PluginManifest | None:
|
||||
"""获取插件清单"""
|
||||
return self._plugins.get(plugin_id)
|
||||
|
||||
def is_enabled(self, plugin_id: str) -> bool:
|
||||
"""检查插件是否启用"""
|
||||
return self._enabled.get(plugin_id, False)
|
||||
|
||||
def _validate_manifest(self, manifest: PluginManifest, plugin_path: str) -> bool:
|
||||
"""验证 manifest"""
|
||||
# 检查主入口文件是否存在
|
||||
main_path = os.path.join(plugin_path, manifest.main)
|
||||
if not os.path.exists(main_path):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _load_plugin_module(self, plugin_id: str) -> bool:
|
||||
"""加载插件模块"""
|
||||
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
|
||||
manifest = self._plugins.get(plugin_id)
|
||||
if not manifest:
|
||||
return False
|
||||
|
||||
try:
|
||||
main_path = os.path.join(plugin_dir, manifest.main)
|
||||
spec = importlib.util.spec_from_file_location(plugin_id, main_path)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[plugin_id] = module
|
||||
spec.loader.exec_module(module)
|
||||
self._modules[plugin_id] = module
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# 全局单例
|
||||
_manager: PluginManager | None = None
|
||||
|
||||
|
||||
def get_plugin_manager() -> PluginManager:
|
||||
"""获取全局插件管理器"""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
_manager = PluginManager()
|
||||
return _manager
|
||||
73
backend/app/agents/plugins/manifest.py
Normal file
73
backend/app/agents/plugins/manifest.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""插件清单定义 - Phase 8.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginManifest:
|
||||
"""插件清单
|
||||
|
||||
定义插件的元数据和接口。
|
||||
"""
|
||||
|
||||
id: str # 唯一标识
|
||||
name: str # 显示名称
|
||||
version: str # 版本号
|
||||
description: str # 描述
|
||||
author: str = "" # 作者
|
||||
homepage: str = "" # 主页
|
||||
license: str = "MIT" # 许可证
|
||||
|
||||
# 插件类型
|
||||
plugin_type: str = "tool" # tool, hook, skill, all
|
||||
|
||||
# 入口点
|
||||
main: str = "index.py" # 主入口文件
|
||||
hooks: list[str] = field(default_factory=list) # 提供的 Hook 列表
|
||||
tools: list[str] = field(default_factory=list) # 提供的工具列表
|
||||
skills: list[str] = field(default_factory=list) # 提供的 Skills 列表
|
||||
|
||||
# 依赖
|
||||
dependencies: dict[str, str] = field(default_factory=dict) # pip 依赖
|
||||
peer_dependencies: dict[str, str] = field(default_factory=dict) # 对等依赖
|
||||
|
||||
# 权限要求
|
||||
permissions: list[str] = field(default_factory=list) # 需要的权限
|
||||
allowed_paths: list[str] = field(default_factory=list) # 允许访问的路径
|
||||
denied_paths: list[str] = field(default_factory=list) # 禁止访问的路径
|
||||
|
||||
# 网络权限
|
||||
network_allowed: bool = False # 是否允许网络访问
|
||||
allowed_hosts: list[str] = field(default_factory=list) # 允许访问的 host
|
||||
|
||||
# 配置
|
||||
config_schema: dict[str, Any] = field(default_factory=dict) # 配置 schema
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"version": self.version,
|
||||
"description": self.description,
|
||||
"author": self.author,
|
||||
"homepage": self.homepage,
|
||||
"license": self.license,
|
||||
"plugin_type": self.plugin_type,
|
||||
"main": self.main,
|
||||
"hooks": self.hooks,
|
||||
"tools": self.tools,
|
||||
"skills": self.skills,
|
||||
"dependencies": self.dependencies,
|
||||
"peer_dependencies": self.peer_dependencies,
|
||||
"permissions": self.permissions,
|
||||
"allowed_paths": self.allowed_paths,
|
||||
"denied_paths": self.denied_paths,
|
||||
"network_allowed": self.network_allowed,
|
||||
"allowed_hosts": self.allowed_hosts,
|
||||
"config_schema": self.config_schema,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PluginManifest":
|
||||
return cls(**data)
|
||||
111
backend/app/agents/plugins/sandbox.py
Normal file
111
backend/app/agents/plugins/sandbox.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""插件沙箱隔离 - Phase 8.3"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PluginSandbox:
|
||||
"""插件沙箱
|
||||
|
||||
提供插件执行隔离环境。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._allowed_paths: set[str] = set()
|
||||
self._denied_paths: set[str] = set()
|
||||
self._network_allowed: bool = False
|
||||
self._allowed_hosts: set[str] = set()
|
||||
|
||||
def set_file_permissions(
|
||||
self,
|
||||
allowed_paths: list[str] | None = None,
|
||||
denied_paths: list[str] | None = None,
|
||||
) -> None:
|
||||
"""设置文件访问权限
|
||||
|
||||
Args:
|
||||
allowed_paths: 允许访问的路径列表
|
||||
denied_paths: 禁止访问的路径列表
|
||||
"""
|
||||
self._allowed_paths = set(allowed_paths or [])
|
||||
self._denied_paths = set(denied_paths or [])
|
||||
|
||||
def set_network_permissions(
|
||||
self, allowed: bool, allowed_hosts: list[str] | None = None
|
||||
) -> None:
|
||||
"""设置网络访问权限
|
||||
|
||||
Args:
|
||||
allowed: 是否允许网络访问
|
||||
allowed_hosts: 允许访问的 host 列表
|
||||
"""
|
||||
self._network_allowed = allowed
|
||||
self._allowed_hosts = set(allowed_hosts or [])
|
||||
|
||||
def check_file_access(self, path: str) -> bool:
|
||||
"""检查文件访问权限
|
||||
|
||||
Args:
|
||||
path: 文件路径
|
||||
|
||||
Returns:
|
||||
是否允许访问
|
||||
"""
|
||||
# 如果有允许列表,只允许访问列表中的路径
|
||||
if self._allowed_paths:
|
||||
return path in self._allowed_paths or any(
|
||||
path.startswith(allowed) for allowed in self._allowed_paths
|
||||
)
|
||||
|
||||
# 如果有禁止列表,禁止访问列表中的路径
|
||||
if self._denied_paths:
|
||||
return not any(path.startswith(denied) for denied in self._denied_paths)
|
||||
|
||||
# 没有限制
|
||||
return True
|
||||
|
||||
def check_network_access(self, host: str) -> bool:
|
||||
"""检查网络访问权限
|
||||
|
||||
Args:
|
||||
host: 主机地址
|
||||
|
||||
Returns:
|
||||
是否允许访问
|
||||
"""
|
||||
if not self._network_allowed:
|
||||
return False
|
||||
|
||||
if self._allowed_hosts:
|
||||
return host in self._allowed_hosts or any(
|
||||
host.endswith(allowed) for allowed in self._allowed_hosts
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def execute_in_sandbox(self, func: Any, *args, **kwargs) -> Any:
|
||||
"""在沙箱中执行函数
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
函数返回值
|
||||
"""
|
||||
# 保存当前状态
|
||||
old_allowed_paths = self._allowed_paths.copy()
|
||||
old_denied_paths = self._denied_paths.copy()
|
||||
old_network_allowed = self._network_allowed
|
||||
old_allowed_hosts = self._allowed_hosts.copy()
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
# 恢复状态
|
||||
self._allowed_paths = old_allowed_paths
|
||||
self._denied_paths = old_denied_paths
|
||||
self._network_allowed = old_network_allowed
|
||||
self._allowed_hosts = old_allowed_hosts
|
||||
@@ -309,14 +309,14 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 analyst 体系下的洞察建议官,负责从任务、论坛和知识线索里提炼趋势、风险与建议。
|
||||
|
||||
## 允许使用的工具:
|
||||
## 你的允许使用的工具:
|
||||
- get_tasks
|
||||
- get_forum_posts
|
||||
- search_knowledge
|
||||
- hybrid_search
|
||||
- web_search
|
||||
|
||||
## 要求:
|
||||
## 你的要求:
|
||||
- 先给结论与判断
|
||||
- 再说明依据与建议
|
||||
- 当需要外部/最新信息时,可使用 `web_search`
|
||||
@@ -324,6 +324,70 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
"""
|
||||
|
||||
|
||||
CODE_COMMANDER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是代码指挥官,负责协调 AI 写代码助手。
|
||||
|
||||
## 你的职责:
|
||||
1. 接收用户选择的 AI 提供商(Claude/Gemini/Codex/OpenCode)
|
||||
2. 接收用户的写代码需求
|
||||
3. 进行安全分级判定
|
||||
4. 路由到合适的执行器
|
||||
|
||||
## 安全分级规则:
|
||||
- 低风险:demo、示例、贪食蛇游戏等独立项目
|
||||
- 高风险:修改现有项目、涉及 Jarvis 项目、路径操作等
|
||||
|
||||
## 执行模式:
|
||||
- 直接执行:低风险任务,直接运行
|
||||
- 沙盒执行:高风险任务,在临时目录隔离执行
|
||||
|
||||
## 你的输出:
|
||||
- 简洁汇报执行结果
|
||||
- 如果需要用户交互(如确认 "y"),明确提示
|
||||
"""
|
||||
|
||||
|
||||
SANDBOX_EXECUTION_PROMPT = """将在隔离的临时目录中执行任务。
|
||||
任务完成后,工作目录会被保留供下载。"""
|
||||
|
||||
|
||||
DIRECT_EXECUTION_PROMPT = """将直接执行任务。
|
||||
如果需要交互,请等待用户输入。"""
|
||||
|
||||
|
||||
COORDINATOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 Jarvis 的协作协调官,负责把复杂请求收束成最小受控协作,而不是放任系统进入自由 swarm。
|
||||
|
||||
## 你的职责:
|
||||
- 先判断当前请求是否真的需要拆解;不需要时应明确建议继续走 direct
|
||||
- 只有在明显多步骤、跨领域、需要多角色配合时,才拆成 2~4 个子任务
|
||||
- 每个子任务必须清晰写出 `title`、`role`、`goal`、`expected_evidence`
|
||||
- 角色建议只能来自现有 top-level agent:`schedule_planner`、`librarian`、`analyst`、`executor`
|
||||
- 汇总时基于子任务结果回收,不依赖单点硬编码拼接
|
||||
|
||||
## 边界:
|
||||
- 禁止无限递归拆分
|
||||
- 禁止创建新的 runtime agent / worker
|
||||
- 禁止把一个简单请求硬拆成多个空泛步骤
|
||||
- 如果证据不足、子任务未闭环,必须把风险明确暴露出来
|
||||
"""
|
||||
|
||||
|
||||
VERIFIER_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 Jarvis 的验证官,负责对执行结果做最小但明确的核验。
|
||||
|
||||
## 你的职责:
|
||||
- 只输出 passed、failed、skipped 三种验证结论之一
|
||||
- 用一句话总结验证判断
|
||||
- 如有证据,保留关键证据点
|
||||
- 当信息不足以证明成功或失败时,优先判定为 skipped
|
||||
- 不重写执行方案,不扩展无关建议
|
||||
"""
|
||||
|
||||
|
||||
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
|
||||
|
||||
你的输出必须满足以下规则:
|
||||
@@ -350,6 +414,7 @@ TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY = {
|
||||
"executor": EXECUTOR_SYSTEM_PROMPT,
|
||||
"librarian": LIBRARIAN_SYSTEM_PROMPT,
|
||||
"analyst": ANALYST_SYSTEM_PROMPT,
|
||||
"code_commander": CODE_COMMANDER_SYSTEM_PROMPT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
"""Registry manifest models and validation helpers."""
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from app.agents.registry.indexes import RegistryIndexes, build_registry_indexes
|
||||
from app.agents.registry.loader import RegistryBundle, load_builtin_registry_bundle
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def load_builtin_registry_indexes() -> RegistryIndexes:
|
||||
return build_registry_indexes(load_builtin_registry_bundle())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RegistryBundle",
|
||||
"RegistryIndexes",
|
||||
"build_registry_indexes",
|
||||
"load_builtin_registry_bundle",
|
||||
"load_builtin_registry_indexes",
|
||||
]
|
||||
|
||||
@@ -2,6 +2,8 @@ from app.agents.prompts import SUB_COMMANDER_PROMPTS_BY_KEY
|
||||
from app.agents.registry.models import (
|
||||
AgentManifest,
|
||||
CapabilityManifest,
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
SpecialistTemplateManifest,
|
||||
SubCommanderManifest,
|
||||
)
|
||||
@@ -27,6 +29,7 @@ TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS: dict[str, tuple[str, ...]] = {
|
||||
"analyst_progress",
|
||||
"analyst_insights",
|
||||
),
|
||||
AgentRole.CODE_COMMANDER.value: (),
|
||||
}
|
||||
|
||||
TOP_LEVEL_AGENT_DISPLAY_NAMES: dict[str, str] = {
|
||||
@@ -35,6 +38,7 @@ TOP_LEVEL_AGENT_DISPLAY_NAMES: dict[str, str] = {
|
||||
AgentRole.EXECUTOR.value: "Executor",
|
||||
AgentRole.LIBRARIAN.value: "Librarian",
|
||||
AgentRole.ANALYST.value: "Analyst",
|
||||
AgentRole.CODE_COMMANDER.value: "Code Commander",
|
||||
}
|
||||
|
||||
TOP_LEVEL_AGENT_ROUTING_HINTS: dict[str, tuple[str, ...]] = {
|
||||
@@ -53,6 +57,24 @@ TOP_LEVEL_AGENT_ROUTING_HINTS: dict[str, tuple[str, ...]] = {
|
||||
AgentRole.ANALYST.value: (
|
||||
"Handle reporting and insight requests using analyst sub-commanders.",
|
||||
),
|
||||
AgentRole.CODE_COMMANDER.value: (
|
||||
"Handle code writing and execution tasks using AI CLI adapters.",
|
||||
),
|
||||
}
|
||||
|
||||
TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES: dict[str, tuple[str, ...]] = {
|
||||
AgentRole.MASTER.value: (
|
||||
AgentRole.SCHEDULE_PLANNER.value,
|
||||
AgentRole.EXECUTOR.value,
|
||||
AgentRole.LIBRARIAN.value,
|
||||
AgentRole.ANALYST.value,
|
||||
AgentRole.CODE_COMMANDER.value,
|
||||
),
|
||||
AgentRole.SCHEDULE_PLANNER.value: (AgentRole.SCHEDULE_PLANNER.value,),
|
||||
AgentRole.EXECUTOR.value: (AgentRole.EXECUTOR.value,),
|
||||
AgentRole.LIBRARIAN.value: (AgentRole.LIBRARIAN.value,),
|
||||
AgentRole.ANALYST.value: (AgentRole.ANALYST.value,),
|
||||
AgentRole.CODE_COMMANDER.value: (),
|
||||
}
|
||||
|
||||
SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = {
|
||||
@@ -75,6 +97,8 @@ BUILTIN_AGENT_MANIFESTS: tuple[AgentManifest, ...] = tuple(
|
||||
system_prompt_key=role.value,
|
||||
routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]),
|
||||
default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]),
|
||||
can_spawn_children=bool(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
|
||||
allowed_spawn_role_values=list(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
|
||||
skill_context_key=role.value.replace("agent_", ""),
|
||||
)
|
||||
for role in AgentRole
|
||||
@@ -82,17 +106,153 @@ BUILTIN_AGENT_MANIFESTS: tuple[AgentManifest, ...] = tuple(
|
||||
|
||||
|
||||
_capability_tool_names = tuple(
|
||||
dict.fromkeys(
|
||||
tool.name
|
||||
for tools in SUB_COMMANDER_TOOLSETS.values()
|
||||
for tool in tools
|
||||
)
|
||||
dict.fromkeys(tool.name for tools in SUB_COMMANDER_TOOLSETS.values() for tool in tools)
|
||||
)
|
||||
|
||||
_CAPABILITY_METADATA_BY_TOOL_NAME: dict[str, dict[str, object]] = {
|
||||
"get_tasks": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"get_schedule_day": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"resolve_time_expression": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"search_knowledge": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"hybrid_search": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"get_knowledge_graph_context": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"get_forum_posts": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"scan_forum_for_instructions": {
|
||||
"permission_class": PermissionClass.READ,
|
||||
"side_effect_scope": SideEffectScope.NONE,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"web_search": {
|
||||
"permission_class": PermissionClass.EXTERNAL,
|
||||
"side_effect_scope": SideEffectScope.NETWORK,
|
||||
"supports_retry": True,
|
||||
"idempotent": True,
|
||||
"safe_for_parallel_use": True,
|
||||
"requires_confirmation": False,
|
||||
},
|
||||
"create_task": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"update_task_status": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_todo": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_schedule_task": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_reminder": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_goal": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"create_forum_post": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
"build_knowledge_graph": {
|
||||
"permission_class": PermissionClass.WRITE,
|
||||
"side_effect_scope": SideEffectScope.LOCAL_STATE,
|
||||
"supports_retry": False,
|
||||
"idempotent": False,
|
||||
"safe_for_parallel_use": False,
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
}
|
||||
|
||||
BUILTIN_CAPABILITY_MANIFESTS: tuple[CapabilityManifest, ...] = tuple(
|
||||
CapabilityManifest(
|
||||
capability_id=tool_name,
|
||||
tool_name=tool_name,
|
||||
**dict(_CAPABILITY_METADATA_BY_TOOL_NAME.get(tool_name, {})),
|
||||
)
|
||||
for tool_name in _capability_tool_names
|
||||
)
|
||||
@@ -103,9 +263,7 @@ BUILTIN_SUB_COMMANDER_MANIFESTS: tuple[SubCommanderManifest, ...] = tuple(
|
||||
sub_commander_id=sub_commander_id,
|
||||
parent_agent_id=SUB_COMMANDER_PARENT_AGENT_IDS[sub_commander_id],
|
||||
prompt_text=SUB_COMMANDER_PROMPTS_BY_KEY[sub_commander_id],
|
||||
capability_ids=list(
|
||||
dict.fromkeys(tool.name for tool in tools)
|
||||
),
|
||||
capability_ids=list(dict.fromkeys(tool.name for tool in tools)),
|
||||
)
|
||||
for sub_commander_id, tools in SUB_COMMANDER_TOOLSETS.items()
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.agents.registry.models import (
|
||||
@dataclass(frozen=True)
|
||||
class RegistryIndexes:
|
||||
agent_by_id: Mapping[str, AgentManifest]
|
||||
agent_by_role_value: Mapping[str, AgentManifest]
|
||||
sub_commander_by_id: Mapping[str, SubCommanderManifest]
|
||||
capability_by_id: Mapping[str, CapabilityManifest]
|
||||
specialist_template_by_id: Mapping[str, SpecialistTemplateManifest]
|
||||
@@ -24,6 +25,7 @@ class RegistryIndexes:
|
||||
skill_context_key_by_agent_id: Mapping[str, str]
|
||||
capability_id_by_tool_name: Mapping[str, str]
|
||||
capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]]
|
||||
spawnable_role_values_by_agent_id: Mapping[str, tuple[str, ...]]
|
||||
|
||||
|
||||
def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]:
|
||||
@@ -50,6 +52,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
||||
|
||||
return RegistryIndexes(
|
||||
agent_by_id=MappingProxyType(agent_by_id),
|
||||
agent_by_role_value=MappingProxyType({
|
||||
agent.role_value: agent for agent in bundle.agents
|
||||
}),
|
||||
sub_commander_by_id=MappingProxyType(sub_commander_by_id),
|
||||
capability_by_id=MappingProxyType(capability_by_id),
|
||||
specialist_template_by_id=MappingProxyType(specialist_template_by_id),
|
||||
@@ -73,4 +78,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||
for sub_commander in bundle.sub_commanders
|
||||
}),
|
||||
spawnable_role_values_by_agent_id=MappingProxyType({
|
||||
agent.agent_id: tuple(agent.allowed_spawn_role_values)
|
||||
for agent in bundle.agents
|
||||
if agent.can_spawn_children and agent.allowed_spawn_role_values
|
||||
}),
|
||||
)
|
||||
|
||||
@@ -1,4 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PermissionClass(str, Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXTERNAL = "external"
|
||||
|
||||
|
||||
class SideEffectScope(str, Enum):
|
||||
NONE = "none"
|
||||
LOCAL_STATE = "local_state"
|
||||
DB_WRITE = "db_write"
|
||||
NETWORK = "network"
|
||||
|
||||
|
||||
class AgentManifest(BaseModel):
|
||||
@@ -8,6 +23,8 @@ class AgentManifest(BaseModel):
|
||||
system_prompt_key: str
|
||||
routing_hints: list[str]
|
||||
default_sub_commanders: list[str]
|
||||
can_spawn_children: bool = False
|
||||
allowed_spawn_role_values: list[str] = Field(default_factory=list)
|
||||
skill_context_key: str | None = None
|
||||
continuity_policy: str | None = None
|
||||
clarification_policy: str | None = None
|
||||
@@ -23,6 +40,12 @@ class SubCommanderManifest(BaseModel):
|
||||
class CapabilityManifest(BaseModel):
|
||||
capability_id: str
|
||||
tool_name: str
|
||||
permission_class: PermissionClass = PermissionClass.READ
|
||||
side_effect_scope: SideEffectScope = SideEffectScope.NONE
|
||||
supports_retry: bool = False
|
||||
idempotent: bool = False
|
||||
safe_for_parallel_use: bool = False
|
||||
requires_confirmation: bool = False
|
||||
|
||||
|
||||
class SpecialistTemplateManifest(BaseModel):
|
||||
|
||||
86
backend/app/agents/runtime_metrics.py
Normal file
86
backend/app/agents/runtime_metrics.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
INPUT_TOKEN_USD_RATE = 0.000003
|
||||
OUTPUT_TOKEN_USD_RATE = 0.000015
|
||||
DEFAULT_COST_THRESHOLDS = {
|
||||
"total_tokens": 4000,
|
||||
"estimated_cost": 0.02,
|
||||
}
|
||||
|
||||
|
||||
def estimate_token_cost(input_tokens: int, output_tokens: int) -> float | None:
|
||||
total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
|
||||
if total_tokens <= 0:
|
||||
return None
|
||||
return round(
|
||||
(max(input_tokens, 0) * INPUT_TOKEN_USD_RATE)
|
||||
+ (max(output_tokens, 0) * OUTPUT_TOKEN_USD_RATE),
|
||||
6,
|
||||
)
|
||||
|
||||
|
||||
def extract_token_usage(response: Any) -> tuple[int, int]:
|
||||
usage_metadata = getattr(response, "usage_metadata", None) or {}
|
||||
if isinstance(usage_metadata, dict):
|
||||
input_tokens = int(
|
||||
usage_metadata.get("input_tokens")
|
||||
or usage_metadata.get("prompt_tokens")
|
||||
or 0
|
||||
)
|
||||
output_tokens = int(
|
||||
usage_metadata.get("output_tokens")
|
||||
or usage_metadata.get("completion_tokens")
|
||||
or 0
|
||||
)
|
||||
if input_tokens or output_tokens:
|
||||
return input_tokens, output_tokens
|
||||
|
||||
response_metadata = getattr(response, "response_metadata", None) or {}
|
||||
token_usage = {}
|
||||
if isinstance(response_metadata, dict):
|
||||
token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") or {}
|
||||
if isinstance(token_usage, dict):
|
||||
input_tokens = int(
|
||||
token_usage.get("prompt_tokens")
|
||||
or token_usage.get("input_tokens")
|
||||
or 0
|
||||
)
|
||||
output_tokens = int(
|
||||
token_usage.get("completion_tokens")
|
||||
or token_usage.get("output_tokens")
|
||||
or 0
|
||||
)
|
||||
if input_tokens or output_tokens:
|
||||
return input_tokens, output_tokens
|
||||
|
||||
return 0, 0
|
||||
|
||||
|
||||
def coerce_cost_thresholds(raw_thresholds: Any) -> dict[str, float]:
|
||||
thresholds: dict[str, float] = dict(DEFAULT_COST_THRESHOLDS)
|
||||
if not isinstance(raw_thresholds, dict):
|
||||
return thresholds
|
||||
for key in DEFAULT_COST_THRESHOLDS:
|
||||
value = raw_thresholds.get(key)
|
||||
if isinstance(value, (int, float)) and value > 0:
|
||||
thresholds[key] = float(value)
|
||||
return thresholds
|
||||
|
||||
|
||||
def is_cost_budget_warning(
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
estimated_cost: float | None,
|
||||
thresholds: dict[str, float] | None = None,
|
||||
) -> bool:
|
||||
effective_thresholds = thresholds or DEFAULT_COST_THRESHOLDS
|
||||
total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
|
||||
token_threshold = float(effective_thresholds.get("total_tokens") or 0)
|
||||
cost_threshold = float(effective_thresholds.get("estimated_cost") or 0)
|
||||
return (
|
||||
(token_threshold > 0 and total_tokens >= token_threshold)
|
||||
or (cost_threshold > 0 and estimated_cost is not None and estimated_cost >= cost_threshold)
|
||||
)
|
||||
60
backend/app/agents/schemas/__init__.py
Normal file
60
backend/app/agents/schemas/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from app.agents.schemas.event import AgentEvent
|
||||
from app.agents.schemas.learning import (
|
||||
LearningDecision,
|
||||
LearningSignal,
|
||||
PatternCandidate,
|
||||
SessionRetrospective,
|
||||
SkillCandidate,
|
||||
)
|
||||
from app.agents.schemas.message import AgentMessage
|
||||
from app.agents.schemas.orchestration import (
|
||||
ExecutionDecision,
|
||||
MergeReport,
|
||||
ParallelWorthiness,
|
||||
RuntimeRequestContext,
|
||||
SubTaskResult,
|
||||
SubTaskSpec,
|
||||
TaskGraph,
|
||||
TaskNode,
|
||||
VerificationReport,
|
||||
)
|
||||
from app.agents.schemas.skills import SkillActivationRecord, SkillShortlistEntry
|
||||
from app.agents.schemas.task import (
|
||||
AgentTask,
|
||||
CollaborationBudget,
|
||||
InterruptRecord,
|
||||
RecoveryRecord,
|
||||
TaskLifecycleStatus,
|
||||
TaskResult,
|
||||
TaskResultStatus,
|
||||
VerificationStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentEvent",
|
||||
"AgentMessage",
|
||||
"ExecutionDecision",
|
||||
"AgentTask",
|
||||
"CollaborationBudget",
|
||||
"InterruptRecord",
|
||||
"LearningDecision",
|
||||
"LearningSignal",
|
||||
"MergeReport",
|
||||
"ParallelWorthiness",
|
||||
"PatternCandidate",
|
||||
"RecoveryRecord",
|
||||
"RuntimeRequestContext",
|
||||
"SessionRetrospective",
|
||||
"SkillActivationRecord",
|
||||
"SkillCandidate",
|
||||
"SkillShortlistEntry",
|
||||
"SubTaskResult",
|
||||
"SubTaskSpec",
|
||||
"TaskGraph",
|
||||
"TaskNode",
|
||||
"TaskLifecycleStatus",
|
||||
"TaskResult",
|
||||
"TaskResultStatus",
|
||||
"VerificationReport",
|
||||
"VerificationStatus",
|
||||
]
|
||||
63
backend/app/agents/schemas/event.py
Normal file
63
backend/app/agents/schemas/event.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
AgentEventType = Literal[
|
||||
"agent.execution.decided",
|
||||
"agent.parallel.assessed",
|
||||
"agent.skill.shortlisted",
|
||||
"agent.task_graph.built",
|
||||
"agent.subtask.started",
|
||||
"agent.subtask.completed",
|
||||
"agent.merge.completed",
|
||||
"agent.tool.start",
|
||||
"agent.tool.result",
|
||||
"agent.verify.started",
|
||||
"agent.verify.completed",
|
||||
"agent.retrospective.created",
|
||||
"agent.learning.decision",
|
||||
"agent.skill.lifecycle.changed",
|
||||
"agent.rollback.triggered",
|
||||
"agent.created",
|
||||
"agent.spawn.blocked",
|
||||
"agent.message.sent",
|
||||
"agent.message.received",
|
||||
"agent.interrupt.requested",
|
||||
"agent.interrupt.completed",
|
||||
"agent.recovery.started",
|
||||
"agent.recovery.completed",
|
||||
"agent.task.interrupted",
|
||||
"agent.task.recovered",
|
||||
"agent.task.reassigned",
|
||||
"agent.collaboration.budget.updated",
|
||||
"agent.isolation.selected",
|
||||
"agent.isolation.fallback",
|
||||
"agent.cost.updated",
|
||||
"agent.cost.warning",
|
||||
"agent.phase.changed",
|
||||
"agent.checkpoint.recorded",
|
||||
"agent.error",
|
||||
]
|
||||
AgentEventSeverity = Literal["info", "warning", "error"]
|
||||
|
||||
|
||||
class AgentEvent(BaseModel):
|
||||
event_id: str
|
||||
event_type: AgentEventType
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
conversation_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
sub_commander_id: str | None = None
|
||||
task_id: str | None = None
|
||||
parent_task_id: str | None = None
|
||||
child_task_id: str | None = None
|
||||
thread_id: str | None = None
|
||||
message_id: str | None = None
|
||||
interrupt_id: str | None = None
|
||||
recovery_id: str | None = None
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
severity: AgentEventSeverity = "info"
|
||||
76
backend/app/agents/schemas/learning.py
Normal file
76
backend/app/agents/schemas/learning.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
LearningSignalType = Literal[
|
||||
"preference",
|
||||
"workflow",
|
||||
"decomposition",
|
||||
"tool_success",
|
||||
"correction",
|
||||
]
|
||||
|
||||
|
||||
class SessionRetrospective(BaseModel):
|
||||
retrospective_id: str | None = None
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
request_message_id: str | None = None
|
||||
response_message_id: str | None = None
|
||||
query_text: str
|
||||
final_response: str | None = None
|
||||
summary: str
|
||||
task_type: str | None = None
|
||||
execution_mode: str | None = None
|
||||
primary_agent: str | None = None
|
||||
verification_status: str | None = None
|
||||
verification_summary: str | None = None
|
||||
used_skill_names: list[str] = Field(default_factory=list)
|
||||
evidence_refs: list[dict[str, Any]] = Field(default_factory=list)
|
||||
task_refs: list[dict[str, Any]] = Field(default_factory=list)
|
||||
event_refs: list[dict[str, Any]] = Field(default_factory=list)
|
||||
context_snapshot: dict[str, Any] = Field(default_factory=dict)
|
||||
learning_signals: list["LearningSignal"] = Field(default_factory=list)
|
||||
pattern_candidates: list["PatternCandidate"] = Field(default_factory=list)
|
||||
skill_candidates: list["SkillCandidate"] = Field(default_factory=list)
|
||||
learning_decision: "LearningDecision | None" = None
|
||||
outcome: Literal["completed", "partial", "failed"] = "completed"
|
||||
captured_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class LearningSignal(BaseModel):
|
||||
signal_type: LearningSignalType
|
||||
confidence: float = 0.0
|
||||
evidence_refs: list[dict[str, Any]] = Field(default_factory=list)
|
||||
explanation: str | None = None
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class PatternCandidate(BaseModel):
|
||||
pattern_id: str
|
||||
pattern_type: str
|
||||
description: str
|
||||
confidence: float = 0.0
|
||||
evidence_refs: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SkillCandidate(BaseModel):
|
||||
candidate_id: str
|
||||
name: str
|
||||
summary: str
|
||||
candidate_type: Literal["workflow_skill", "preference_skill", "decomposition_skill"] = "workflow_skill"
|
||||
source_pattern_ids: list[str] = Field(default_factory=list)
|
||||
confidence: float = 0.0
|
||||
evidence_refs: list[dict[str, Any]] = Field(default_factory=list)
|
||||
recommended_status: Literal["candidate", "shadow"] = "candidate"
|
||||
|
||||
|
||||
class LearningDecision(BaseModel):
|
||||
decision: Literal["reinforce_memory", "create_candidate", "promote_skill", "defer", "reject"]
|
||||
explanation: str
|
||||
evidence_refs: list[dict[str, Any]] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
29
backend/app/agents/schemas/message.py
Normal file
29
backend/app/agents/schemas/message.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
AgentMessageType = Literal[
|
||||
"task_request",
|
||||
"task_update",
|
||||
"handoff",
|
||||
"verification_request",
|
||||
"verification_feedback",
|
||||
"interrupt_notice",
|
||||
]
|
||||
|
||||
|
||||
class AgentMessage(BaseModel):
|
||||
message_id: str
|
||||
thread_id: str
|
||||
from_agent_id: str
|
||||
to_agent_id: str
|
||||
task_id: str | None = None
|
||||
reply_to_message_id: str | None = None
|
||||
message_type: AgentMessageType = "task_update"
|
||||
content_summary: str
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
211
backend/app/agents/schemas/orchestration.py
Normal file
211
backend/app/agents/schemas/orchestration.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agents.schemas.skills import SkillShortlistEntry
|
||||
|
||||
|
||||
ExecutionMode = Literal["direct", "collaboration", "parallel", "delegated"]
|
||||
ParallelPreference = Literal["direct", "collaboration", "parallel"]
|
||||
|
||||
|
||||
class ParallelWorthiness(BaseModel):
|
||||
should_parallelize: bool = False
|
||||
score: float = 0.0
|
||||
estimated_subtasks: int = 1
|
||||
preferred_mode: ParallelPreference = "direct"
|
||||
reasons: list[str] = Field(default_factory=list)
|
||||
risk_flags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TaskNode(BaseModel):
|
||||
node_id: str
|
||||
title: str
|
||||
role: str | None = None
|
||||
goal: str | None = None
|
||||
depends_on: list[str] = Field(default_factory=list)
|
||||
execution_mode: Literal["serial", "parallel"] = "serial"
|
||||
expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TaskGraph(BaseModel):
|
||||
graph_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
nodes: list[TaskNode] = Field(default_factory=list)
|
||||
entry_node_ids: list[str] = Field(default_factory=list)
|
||||
max_parallelism: int = 1
|
||||
rationale: str | None = None
|
||||
|
||||
|
||||
class SubTaskSpec(BaseModel):
|
||||
subtask_id: str
|
||||
parent_run_id: str
|
||||
title: str
|
||||
role: str
|
||||
goal: str
|
||||
context_slice: dict[str, Any] = Field(default_factory=dict)
|
||||
allowed_tools: list[str] = Field(default_factory=list)
|
||||
budget_tokens: int = 1200
|
||||
budget_tool_calls: int = 2
|
||||
expected_output_schema: dict[str, Any] = Field(default_factory=dict)
|
||||
expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
dependencies: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SubTaskResult(BaseModel):
|
||||
subtask_id: str
|
||||
status: Literal["completed", "failed", "blocked"]
|
||||
summary: str | None = None
|
||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
output: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class MergeReport(BaseModel):
|
||||
merge_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
status: Literal["merged", "conflicted", "fallback"]
|
||||
summary: str | None = None
|
||||
evidence_union: list[dict[str, Any]] = Field(default_factory=list)
|
||||
conflict_flags: list[str] = Field(default_factory=list)
|
||||
resolution_strategy: str | None = None
|
||||
resolved_summary: str | None = None
|
||||
fallback_used: bool = False
|
||||
|
||||
|
||||
class VerificationReport(BaseModel):
|
||||
status: Literal["passed", "failed", "skipped"]
|
||||
summary: str | None = None
|
||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExecutionDecision(BaseModel):
|
||||
request_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
mode: ExecutionMode = "direct"
|
||||
reason: str
|
||||
complexity_score: float = 0.0
|
||||
parallel_worthiness_score: float | None = None
|
||||
selected_roles: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RuntimeRequestContext(BaseModel):
|
||||
request_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
session_id: str | None = None
|
||||
user_id: str
|
||||
conversation_id: str | None = None
|
||||
query_text: str | None = None
|
||||
raw_user_query: str | None = None
|
||||
recalled_memories: list[str] = Field(default_factory=list)
|
||||
retrospective_shortlist: list[dict[str, Any]] = Field(default_factory=list)
|
||||
recalled_retrospectives: list[dict[str, Any]] = Field(default_factory=list)
|
||||
skill_shortlist: list[SkillShortlistEntry] = Field(default_factory=list)
|
||||
shortlisted_skills: list[str] = Field(default_factory=list)
|
||||
parallel_worthiness: ParallelWorthiness = Field(default_factory=ParallelWorthiness)
|
||||
task_graph: TaskGraph | None = None
|
||||
recommended_runtime_mode: Literal["direct", "collaboration"] = "direct"
|
||||
execution_mode: Literal["direct", "collaboration"] | None = None
|
||||
current_agent_role: str | None = None
|
||||
conversation_state_ref: str | None = None
|
||||
assembly_metrics: dict[str, float] = Field(default_factory=dict)
|
||||
assembled_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
def assess_parallel_worthiness(
|
||||
query_text: str,
|
||||
*,
|
||||
retrospective_count: int = 0,
|
||||
skill_count: int = 0,
|
||||
) -> ParallelWorthiness:
|
||||
normalized = (query_text or "").strip().lower()
|
||||
reasons: list[str] = []
|
||||
score = 0.0
|
||||
|
||||
multi_step_markers = ("然后", "接着", "同时", "并且", "最后", "汇总", "对比", "分析", "research")
|
||||
artifact_markers = ("文档", "代码", "文件", "数据库", "论坛", "知识库", "计划")
|
||||
|
||||
if any(marker in normalized for marker in multi_step_markers):
|
||||
score += 0.35
|
||||
reasons.append("multi_step_request")
|
||||
|
||||
if sum(1 for marker in artifact_markers if marker in normalized) >= 2:
|
||||
score += 0.25
|
||||
reasons.append("multi_source_context")
|
||||
|
||||
if len(re.findall(r"[,,、;;]", query_text or "")) >= 2:
|
||||
score += 0.15
|
||||
reasons.append("compound_instruction")
|
||||
|
||||
if retrospective_count > 0:
|
||||
score += 0.1
|
||||
reasons.append("historical_support")
|
||||
|
||||
if skill_count > 0:
|
||||
score += 0.1
|
||||
reasons.append("skill_candidates_available")
|
||||
|
||||
score = min(score, 1.0)
|
||||
should_parallelize = score >= 0.55
|
||||
preferred_mode: ParallelPreference = "parallel" if should_parallelize else "direct"
|
||||
if not should_parallelize and score >= 0.3:
|
||||
preferred_mode = "collaboration"
|
||||
|
||||
estimated_subtasks = 1
|
||||
if preferred_mode == "parallel":
|
||||
estimated_subtasks = 3 if score >= 0.8 else 2
|
||||
elif preferred_mode == "collaboration":
|
||||
estimated_subtasks = 2
|
||||
|
||||
return ParallelWorthiness(
|
||||
should_parallelize=should_parallelize,
|
||||
score=round(score, 3),
|
||||
estimated_subtasks=estimated_subtasks,
|
||||
preferred_mode=preferred_mode,
|
||||
reasons=reasons,
|
||||
)
|
||||
|
||||
|
||||
def render_runtime_request_context_summary(context: RuntimeRequestContext) -> str:
|
||||
lines = ["【Runtime Request Context】"]
|
||||
lines.append(f"- 推荐运行模式: {context.recommended_runtime_mode}")
|
||||
lines.append(
|
||||
f"- 并行潜力: score={context.parallel_worthiness.score:.2f}, "
|
||||
f"preferred={context.parallel_worthiness.preferred_mode}, "
|
||||
f"estimated_subtasks={context.parallel_worthiness.estimated_subtasks}"
|
||||
)
|
||||
|
||||
if context.parallel_worthiness.reasons:
|
||||
lines.append(f"- 并行判断依据: {', '.join(context.parallel_worthiness.reasons)}")
|
||||
if context.assembly_metrics:
|
||||
total_ms = context.assembly_metrics.get("total_ms")
|
||||
if total_ms is not None:
|
||||
lines.append(f"- 上下文装配耗时: {total_ms:.1f} ms")
|
||||
|
||||
if context.task_graph and context.task_graph.nodes:
|
||||
lines.append(
|
||||
f"- 任务图: nodes={len(context.task_graph.nodes)}, max_parallelism={context.task_graph.max_parallelism}"
|
||||
)
|
||||
for node in context.task_graph.nodes[:4]:
|
||||
deps = f", deps={len(node.depends_on)}" if node.depends_on else ""
|
||||
lines.append(f" - [{node.execution_mode}] {node.title} ({node.role}{deps})")
|
||||
|
||||
if context.retrospective_shortlist:
|
||||
lines.append("- 历史复盘命中:")
|
||||
for item in context.retrospective_shortlist[:3]:
|
||||
summary = (item.get("summary") or item.get("summary_text") or "").strip()
|
||||
task_type = item.get("task_type") or "unknown"
|
||||
lines.append(f" - [{task_type}] {summary[:160]}")
|
||||
|
||||
if context.skill_shortlist:
|
||||
lines.append("- 技能候选:")
|
||||
for item in context.skill_shortlist[:3]:
|
||||
lines.append(
|
||||
f" - {item.skill_name} ({item.injection_mode}, score={item.score:.2f})"
|
||||
+ (f": {item.rationale}" if item.rationale else "")
|
||||
)
|
||||
|
||||
if context.recalled_memories:
|
||||
lines.append("- 记忆上下文已装配,可在回答中按需引用。")
|
||||
|
||||
return "\n".join(lines)
|
||||
38
backend/app/agents/schemas/skills.py
Normal file
38
backend/app/agents/schemas/skills.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
SkillStatus = Literal["candidate", "shadow", "active", "deprecated", "retired"]
|
||||
SkillInjectionMode = Literal["metadata_only", "summary", "full"]
|
||||
|
||||
|
||||
class SkillShortlistEntry(BaseModel):
|
||||
skill_name: str
|
||||
source: str = "runtime"
|
||||
source_id: str | None = None
|
||||
status: SkillStatus = "active"
|
||||
scope: list[str] = Field(default_factory=list)
|
||||
effectiveness: float | None = None
|
||||
score: float = 0.0
|
||||
rationale: str | None = None
|
||||
summary: str | None = None
|
||||
matched_terms: list[str] = Field(default_factory=list)
|
||||
injection_mode: SkillInjectionMode = "metadata_only"
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SkillActivationRecord(BaseModel):
|
||||
skill_name: str
|
||||
source: str = "runtime"
|
||||
source_id: str | None = None
|
||||
status: SkillStatus = "active"
|
||||
injection_mode: SkillInjectionMode = "metadata_only"
|
||||
matched_terms: list[str] = Field(default_factory=list)
|
||||
rationale: str | None = None
|
||||
activated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
outcome: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
133
backend/app/agents/schemas/task.py
Normal file
133
backend/app/agents/schemas/task.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
TaskLifecycleStatus = Literal["pending", "in_progress", "completed", "failed", "blocked"]
|
||||
VerificationStatus = Literal["passed", "failed", "skipped"]
|
||||
TaskResultStatus = Literal["completed", "failed", "blocked", "passed", "skipped"]
|
||||
InterruptStatus = Literal["requested", "acknowledged", "resolved"]
|
||||
BudgetMode = Literal["direct", "collaboration"]
|
||||
|
||||
|
||||
class CodeProviderType(str, Enum):
|
||||
CLAUDE = "claude"
|
||||
GEMINI = "gemini"
|
||||
CODEX = "codex"
|
||||
OPENCODE = "opencode"
|
||||
|
||||
|
||||
class RiskLevelType(str, Enum):
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
|
||||
|
||||
class InterruptRecord(BaseModel):
|
||||
interrupt_id: str
|
||||
reason: str
|
||||
status: InterruptStatus = "requested"
|
||||
requested_by: str | None = None
|
||||
source_event_id: str | None = None
|
||||
requested_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RecoveryRecord(BaseModel):
|
||||
recovery_id: str
|
||||
source_interrupt_id: str | None = None
|
||||
strategy: str | None = None
|
||||
resumed_from_task_id: str | None = None
|
||||
resumed_from_thread_id: str | None = None
|
||||
recovered_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CollaborationBudget(BaseModel):
|
||||
mode: BudgetMode = "direct"
|
||||
max_parallel_tasks: int | None = None
|
||||
remaining_parallel_tasks: int | None = None
|
||||
max_tool_calls: int | None = None
|
||||
remaining_tool_calls: int | None = None
|
||||
max_iterations: int | None = None
|
||||
remaining_iterations: int | None = None
|
||||
escalation_threshold: int | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentTask(BaseModel):
|
||||
task_id: str
|
||||
title: str
|
||||
status: TaskLifecycleStatus = "pending"
|
||||
owner_agent_id: str | None = None
|
||||
role: str | None = None
|
||||
goal: str | None = None
|
||||
parent_task_id: str | None = None
|
||||
child_task_ids: list[str] = Field(default_factory=list)
|
||||
thread_id: str | None = None
|
||||
message_id: str | None = None
|
||||
message_index: int | None = None
|
||||
expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
|
||||
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
|
||||
collaboration_budget: CollaborationBudget | dict[str, Any] | None = None
|
||||
result_summary: str | None = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
task_id: str
|
||||
status: TaskResultStatus
|
||||
summary: str | None = None
|
||||
evidence: list[dict[str, Any]] = Field(default_factory=list)
|
||||
owner_agent_id: str | None = None
|
||||
parent_task_id: str | None = None
|
||||
child_task_ids: list[str] = Field(default_factory=list)
|
||||
thread_id: str | None = None
|
||||
message_id: str | None = None
|
||||
message_index: int | None = None
|
||||
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
|
||||
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
|
||||
budget_snapshot: CollaborationBudget | dict[str, Any] | None = None
|
||||
next_action: str | None = None
|
||||
output_data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CodeTaskType(str, Enum):
|
||||
DEMO = "demo"
|
||||
PROJECT = "project"
|
||||
MODIFICATION = "modification"
|
||||
|
||||
|
||||
class CodeTask(BaseModel):
|
||||
"""代码任务请求模型"""
|
||||
|
||||
task_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
task_type: CodeTaskType
|
||||
ai_provider: CodeProviderType
|
||||
sandbox_mode: bool = False
|
||||
workspace_path: str | None = None
|
||||
user_prompt: str
|
||||
parent_task_id: str | None = None
|
||||
thread_id: str | None = None
|
||||
message_id: str | None = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class CodeExecutionResultSchema(BaseModel):
|
||||
"""代码执行结果模型 (API 响应用)"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
files_created: list[str] = Field(default_factory=list)
|
||||
output: str = ""
|
||||
error: str | None = None
|
||||
exit_code: int = 0
|
||||
execution_time: float | None = None
|
||||
sandbox_session_id: str | None = None
|
||||
17
backend/app/agents/session/__init__.py
Normal file
17
backend/app/agents/session/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Agent Session Management - Phase 10.3"""
|
||||
|
||||
from app.agents.session.manager import (
|
||||
AgentSession,
|
||||
SessionContext,
|
||||
SessionPersistence,
|
||||
create_agent_session,
|
||||
get_agent_session,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentSession",
|
||||
"SessionContext",
|
||||
"SessionPersistence",
|
||||
"create_agent_session",
|
||||
"get_agent_session",
|
||||
]
|
||||
238
backend/app/agents/session/manager.py
Normal file
238
backend/app/agents/session/manager.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Agent Session 管理 - Phase 10.3
|
||||
|
||||
支持会话层级管理和子会话创建。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""会话上下文"""
|
||||
|
||||
session_id: str
|
||||
parent_session_id: str | None = None
|
||||
root_session_id: str | None = None
|
||||
depth: int = 0
|
||||
user_id: str | None = None
|
||||
created_at: str | None = None
|
||||
last_active: str | None = None
|
||||
message_count: int = 0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.now().isoformat()
|
||||
if self.last_active is None:
|
||||
self.last_active = self.created_at
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionPersistence:
|
||||
"""会话持久化"""
|
||||
|
||||
def __init__(self, persistence_dir: str | None = None):
|
||||
if persistence_dir is None:
|
||||
persistence_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
|
||||
)
|
||||
self.persistence_dir = persistence_dir
|
||||
|
||||
def _get_session_path(self, session_id: str) -> str:
|
||||
return os.path.join(self.persistence_dir, f"{session_id}.json")
|
||||
|
||||
def save(self, session: "AgentSession") -> bool:
|
||||
"""保存会话"""
|
||||
try:
|
||||
os.makedirs(self.persistence_dir, exist_ok=True)
|
||||
path = self._get_session_path(session.session_id)
|
||||
data = {
|
||||
"session_id": session.session_id,
|
||||
"parent_session_id": session.context.parent_session_id,
|
||||
"root_session_id": session.context.root_session_id,
|
||||
"depth": session.context.depth,
|
||||
"user_id": session.context.user_id,
|
||||
"created_at": session.context.created_at,
|
||||
"last_active": session.context.last_active,
|
||||
"message_count": session.context.message_count,
|
||||
"metadata": session.context.metadata,
|
||||
"history": session._history,
|
||||
}
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def load(self, session_id: str) -> dict[str, Any] | None:
|
||||
"""加载会话"""
|
||||
try:
|
||||
path = self._get_session_path(session_id)
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete(self, session_id: str) -> bool:
|
||||
"""删除会话"""
|
||||
try:
|
||||
path = self._get_session_path(session_id)
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
||||
"""列出所有会话"""
|
||||
sessions = []
|
||||
try:
|
||||
os.makedirs(self.persistence_dir, exist_ok=True)
|
||||
for filename in os.listdir(self.persistence_dir):
|
||||
if filename.endswith(".json"):
|
||||
path = os.path.join(self.persistence_dir, filename)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if user_id is None or data.get("user_id") == user_id:
|
||||
sessions.append(data)
|
||||
except Exception:
|
||||
pass
|
||||
return sessions
|
||||
|
||||
|
||||
class AgentSession:
|
||||
"""Agent 会话管理器
|
||||
|
||||
支持:
|
||||
- 会话层级(parent/root/depth)
|
||||
- 子会话创建
|
||||
- 会话摘要
|
||||
- 持久化
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
):
|
||||
self.session_id = session_id or str(uuid.uuid4())[:8]
|
||||
self.context = SessionContext(
|
||||
session_id=self.session_id,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
depth=0 if parent_session_id is None else 1,
|
||||
)
|
||||
self._history: list[dict[str, Any]] = []
|
||||
self._persistence = SessionPersistence()
|
||||
|
||||
# 如果有父会话,设置 root_session_id
|
||||
if parent_session_id:
|
||||
parent_data = self._persistence.load(parent_session_id)
|
||||
if parent_data:
|
||||
self.context.root_session_id = (
|
||||
parent_data.get("root_session_id") or parent_session_id
|
||||
)
|
||||
self.context.depth = parent_data.get("depth", 0) + 1
|
||||
|
||||
async def initialize(self) -> dict[str, Any]:
|
||||
"""初始化会话"""
|
||||
self.context.last_active = datetime.now().isoformat()
|
||||
self._persistence.save(self)
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"depth": self.context.depth,
|
||||
"parent_session_id": self.context.parent_session_id,
|
||||
"root_session_id": self.context.root_session_id,
|
||||
}
|
||||
|
||||
async def process_message(self, message: str, response: str) -> None:
|
||||
"""处理消息并记录到历史"""
|
||||
self.context.message_count += 1
|
||||
self.context.last_active = datetime.now().isoformat()
|
||||
self._history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": message,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
self._history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
self._persistence.save(self)
|
||||
|
||||
async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
|
||||
"""创建子会话"""
|
||||
child = AgentSession(
|
||||
user_id=user_id or self.context.user_id,
|
||||
parent_session_id=self.session_id,
|
||||
)
|
||||
child.context.root_session_id = self.context.root_session_id or self.session_id
|
||||
await child.initialize()
|
||||
return child
|
||||
|
||||
async def get_session_summary(self) -> dict[str, Any]:
|
||||
"""获取会话摘要"""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"parent_session_id": self.context.parent_session_id,
|
||||
"root_session_id": self.context.root_session_id,
|
||||
"depth": self.context.depth,
|
||||
"user_id": self.context.user_id,
|
||||
"created_at": self.context.created_at,
|
||||
"last_active": self.context.last_active,
|
||||
"message_count": self.context.message_count,
|
||||
"history_length": len(self._history),
|
||||
}
|
||||
|
||||
async def persist(self) -> bool:
|
||||
"""持久化会话"""
|
||||
return self._persistence.save(self)
|
||||
|
||||
def get_history(self) -> list[dict[str, Any]]:
|
||||
"""获取会话历史"""
|
||||
return self._history.copy()
|
||||
|
||||
def add_metadata(self, key: str, value: Any) -> None:
|
||||
"""添加会话元数据"""
|
||||
self.context.metadata[key] = value
|
||||
|
||||
def get_metadata(self, key: str) -> Any:
|
||||
"""获取会话元数据"""
|
||||
return self.context.metadata.get(key)
|
||||
|
||||
|
||||
# 全局会话存储(内存中)
|
||||
_sessions: dict[str, AgentSession] = {}
|
||||
|
||||
|
||||
def get_agent_session(session_id: str) -> AgentSession | None:
|
||||
"""获取会话"""
|
||||
return _sessions.get(session_id)
|
||||
|
||||
|
||||
def create_agent_session(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> AgentSession:
|
||||
"""创建新会话"""
|
||||
session = AgentSession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
_sessions[session.session_id] = session
|
||||
return session
|
||||
1
backend/app/agents/skills/__init__.py
Normal file
1
backend/app/agents/skills/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Skill package."""
|
||||
72
backend/app/agents/skills/bundled.py
Normal file
72
backend/app/agents/skills/bundled.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Built-in Skills - Phase 9.4
|
||||
|
||||
This module contains bundled skills that are always available
|
||||
without requiring external skill loaders.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
# SkillMetadata-compatible structure for bundled skills
|
||||
BUNDLED_SKILLS: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": "code-analysis",
|
||||
"name": "Code Analysis",
|
||||
"description": "Analyze code structure, patterns, and quality. Helps understand codebase architecture, find issues, and suggest improvements.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Analyze the code structure and identify key components, their relationships, and responsibilities.",
|
||||
"Review the code for potential issues like bugs, security vulnerabilities, or performance problems.",
|
||||
"Explain how the code works and what it does in simple terms.",
|
||||
],
|
||||
"tools": ["grep", "read", "glob", "lsp_symbols", "lsp_find_references"],
|
||||
},
|
||||
{
|
||||
"id": "git-helper",
|
||||
"name": "Git Helper",
|
||||
"description": "Assists with Git operations including commit, branch management, merge conflicts, and repository exploration.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Show me the current git status and any uncommitted changes.",
|
||||
"Help me create a meaningful commit message for these changes.",
|
||||
"Explain the git history and branch structure of this repository.",
|
||||
],
|
||||
"tools": ["bash"],
|
||||
},
|
||||
{
|
||||
"id": "web-research",
|
||||
"name": "Web Research",
|
||||
"description": "Search the web for information, documentation, and resources. Helps find answers and learn about technologies.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Search the web for information about {topic} and summarize the key findings.",
|
||||
"Find official documentation or reliable resources about {topic}.",
|
||||
"Look up the latest news or developments in {topic}.",
|
||||
],
|
||||
"tools": ["search_brave_web_search", "websearch_web_search_exa", "webfetch"],
|
||||
},
|
||||
{
|
||||
"id": "file-management",
|
||||
"name": "File Management",
|
||||
"description": "Helps with file operations like creating, editing, organizing, and managing project files and directories.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Create a new file at {path} with the following content: {content}",
|
||||
"Organize the files in the project structure and suggest improvements.",
|
||||
"Find all files related to {topic} or matching {pattern}.",
|
||||
],
|
||||
"tools": ["read", "write", "glob", "bash"],
|
||||
},
|
||||
{
|
||||
"id": "task-planning",
|
||||
"name": "Task Planning",
|
||||
"description": "Helps break down complex tasks into smaller steps, create implementation plans, and track progress.",
|
||||
"version": "1.0.0",
|
||||
"prompts": [
|
||||
"Break down this task into smaller, manageable steps: {task}",
|
||||
"Create an implementation plan for building {feature} with clear phases.",
|
||||
"Review the current progress and suggest next steps for completing {goal}.",
|
||||
],
|
||||
"tools": ["todowrite", "read", "write"],
|
||||
},
|
||||
]
|
||||
14
backend/app/agents/skills/effectiveness.py
Normal file
14
backend/app/agents/skills/effectiveness.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.skill import Skill
|
||||
|
||||
|
||||
def summarize_skill_effectiveness(skill: Skill) -> dict[str, object]:
|
||||
return {
|
||||
"name": skill.name,
|
||||
"status": skill.status,
|
||||
"effectiveness": skill.effectiveness,
|
||||
"activation_count": skill.activation_count,
|
||||
"candidate_count": getattr(skill, "candidate_count", 0),
|
||||
"last_activated_at": skill.last_activated_at.isoformat() if skill.last_activated_at else None,
|
||||
}
|
||||
58
backend/app/agents/skills/evaluator.py
Normal file
58
backend/app/agents/skills/evaluator.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.agents.schemas.learning import SessionRetrospective, SkillCandidate
|
||||
from app.agents.skills.models import SkillLifecycleDecision
|
||||
from app.services.skill_service import SkillService
|
||||
|
||||
|
||||
class SkillPromotionEvaluator:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.skill_service = SkillService(db)
|
||||
|
||||
async def sync_retrospective(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
retrospective: SessionRetrospective,
|
||||
) -> list[SkillLifecycleDecision]:
|
||||
decisions: list[SkillLifecycleDecision] = []
|
||||
|
||||
for candidate in retrospective.skill_candidates:
|
||||
decisions.append(
|
||||
await self.skill_service.upsert_learned_candidate(
|
||||
user_id=user_id,
|
||||
candidate=candidate,
|
||||
primary_agent=retrospective.primary_agent,
|
||||
evidence_refs=candidate.evidence_refs,
|
||||
)
|
||||
)
|
||||
|
||||
outcome_score = self._derive_outcome_score(retrospective)
|
||||
for skill_name in retrospective.used_skill_names:
|
||||
decision = await self.skill_service.record_activation_feedback(
|
||||
user_id=user_id,
|
||||
skill_name=skill_name,
|
||||
outcome_score=outcome_score,
|
||||
evidence_refs=retrospective.evidence_refs,
|
||||
)
|
||||
if decision is not None:
|
||||
decisions.append(decision)
|
||||
|
||||
return decisions
|
||||
|
||||
@staticmethod
|
||||
def _derive_outcome_score(retrospective: SessionRetrospective) -> float:
|
||||
if retrospective.verification_status == "passed":
|
||||
return 0.9
|
||||
if retrospective.verification_status == "skipped":
|
||||
return 0.55
|
||||
if retrospective.verification_status == "failed":
|
||||
return 0.15
|
||||
return 0.7 if retrospective.outcome == "completed" else 0.2
|
||||
|
||||
|
||||
def next_review_after(days: int = 7) -> datetime:
|
||||
return datetime.now(UTC) + timedelta(days=days)
|
||||
12
backend/app/agents/skills/loaders/__init__.py
Normal file
12
backend/app/agents/skills/loaders/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Skills 加载器包"""
|
||||
|
||||
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
|
||||
from app.agents.skills.loaders.mcp_loader import MCPSkillLoader, get_mcp_skill_loader
|
||||
|
||||
__all__ = [
|
||||
"LocalSkillLoader",
|
||||
"PluginSkillLoader",
|
||||
"MCPSkillLoader",
|
||||
"get_mcp_skill_loader",
|
||||
]
|
||||
100
backend/app/agents/skills/loaders/local_loader.py
Normal file
100
backend/app/agents/skills/loaders/local_loader.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""本地 Skills 加载器 - Phase 9.2"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class LocalSkillLoader:
|
||||
"""本地 Skills 加载器
|
||||
|
||||
从 skills_dir 目录加载 SKILL.md 文件。
|
||||
"""
|
||||
|
||||
def __init__(self, skills_dir: str):
|
||||
self.skills_dir = skills_dir
|
||||
|
||||
def load_all(self) -> list[SkillMetadata]:
|
||||
"""加载所有本地 Skills
|
||||
|
||||
Returns:
|
||||
Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
if not os.path.exists(self.skills_dir):
|
||||
return skills
|
||||
|
||||
for root, dirs, files in os.walk(self.skills_dir):
|
||||
# 跳过隐藏目录
|
||||
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
||||
|
||||
if "SKILL.md" in files:
|
||||
skill = self._load_skill_from_dir(root)
|
||||
if skill:
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
|
||||
def _load_skill_from_dir(self, skill_dir: str) -> SkillMetadata | None:
|
||||
"""从目录加载 Skill
|
||||
|
||||
Args:
|
||||
skill_dir: Skill 目录
|
||||
|
||||
Returns:
|
||||
Skill 元数据
|
||||
"""
|
||||
skill_path = os.path.join(skill_dir, "SKILL.md")
|
||||
|
||||
try:
|
||||
with open(skill_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 解析 frontmatter
|
||||
metadata = self._parse_frontmatter(content)
|
||||
|
||||
# 获取 Skill 名称(目录名)
|
||||
name = os.path.basename(skill_dir)
|
||||
|
||||
return SkillMetadata(
|
||||
name=metadata.get("name", name),
|
||||
description=metadata.get("description", ""),
|
||||
version=metadata.get("version", "1.0.0"),
|
||||
author=metadata.get("author", ""),
|
||||
tags=metadata.get("tags", []),
|
||||
triggers=metadata.get("triggers", []),
|
||||
content=content,
|
||||
source="local",
|
||||
source_id=skill_dir,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _parse_frontmatter(self, content: str) -> dict[str, Any]:
|
||||
"""解析 frontmatter"""
|
||||
metadata = {}
|
||||
|
||||
# 匹配 --- 包裹的 frontmatter
|
||||
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
||||
if match:
|
||||
frontmatter = match.group(1)
|
||||
|
||||
for line in frontmatter.split("\n"):
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# 处理列表
|
||||
if value.startswith("[") and value.endswith("]"):
|
||||
value = [v.strip().strip('"').strip("'") for v in value[1:-1].split(",")]
|
||||
elif value.lower() in ("true", "false"):
|
||||
value = value.lower() == "true"
|
||||
|
||||
metadata[key] = value
|
||||
|
||||
return metadata
|
||||
169
backend/app/agents/skills/loaders/mcp_loader.py
Normal file
169
backend/app/agents/skills/loaders/mcp_loader.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""MCP Skill 加载器 - Phase 9.2
|
||||
|
||||
从 MCP (Model Context Protocol) 服务器发现和加载 Skills。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class MCPSkillLoader:
|
||||
"""MCP Skill 加载器
|
||||
|
||||
从 MCP 服务器发现可用的 Skills。
|
||||
"""
|
||||
|
||||
def __init__(self, mcp_servers: list[dict[str, Any]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
mcp_servers: MCP 服务器列表,每项包含 name, command, env 等
|
||||
"""
|
||||
self.mcp_servers = mcp_servers or []
|
||||
self._discovered_skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
def discover_skills(self) -> list[SkillMetadata]:
|
||||
"""从所有配置的 MCP 服务器发现 Skills
|
||||
|
||||
Returns:
|
||||
发现的 Skill 列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
for server in self.mcp_servers:
|
||||
server_skills = self._discover_from_server(server)
|
||||
skills.extend(server_skills)
|
||||
|
||||
return skills
|
||||
|
||||
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
|
||||
"""从单个 MCP 服务器发现 Skills
|
||||
|
||||
Args:
|
||||
server: 服务器配置
|
||||
|
||||
Returns:
|
||||
Skill 列表
|
||||
"""
|
||||
skills = []
|
||||
server_name = server.get("name", "unknown")
|
||||
|
||||
# 模拟从 MCP 服务器获取工具列表
|
||||
# 实际实现时,这里会调用 MCP 服务器的 list_tools 接口
|
||||
try:
|
||||
tools = self._call_mcp_list_tools(server)
|
||||
for tool in tools:
|
||||
skill = self._tool_to_skill(tool, server_name)
|
||||
if skill:
|
||||
skills.append(skill)
|
||||
self._discovered_skills[skill.name] = skill
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return skills
|
||||
|
||||
def _call_mcp_list_tools(self, server: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""调用 MCP 服务器的 list_tools 接口
|
||||
|
||||
Args:
|
||||
server: 服务器配置
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
# TODO: 实现实际的 MCP 协议调用
|
||||
# 目前返回空列表,实际使用时需要实现 MCP 客户端
|
||||
return []
|
||||
|
||||
def _tool_to_skill(self, tool: dict[str, Any], server: str) -> SkillMetadata | None:
|
||||
"""将 MCP 工具转换为 Skill
|
||||
|
||||
Args:
|
||||
tool: MCP 工具定义
|
||||
server: 服务器名
|
||||
|
||||
Returns:
|
||||
Skill 元数据或 None
|
||||
"""
|
||||
tool_name = tool.get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
|
||||
return SkillMetadata(
|
||||
id=f"mcp_{server}_{tool_name}",
|
||||
name=f"{server}:{tool_name}",
|
||||
description=tool.get("description", f"MCP tool: {tool_name}"),
|
||||
version="1.0.0",
|
||||
content=self._generate_skill_content(tool),
|
||||
triggers=[f"@{server}", f"/{tool_name}"],
|
||||
tools=[tool_name],
|
||||
tags=["mcp", server],
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
def _generate_skill_content(self, tool: dict[str, Any]) -> str:
|
||||
"""生成 Skill 内容
|
||||
|
||||
Args:
|
||||
tool: MCP 工具定义
|
||||
|
||||
Returns:
|
||||
Skill 内容字符串
|
||||
"""
|
||||
name = tool.get("name", "unknown")
|
||||
description = tool.get("description", "No description")
|
||||
input_schema = tool.get("inputSchema", {})
|
||||
|
||||
content = f"""# MCP Tool: {name}
|
||||
|
||||
**Description**: {description}
|
||||
|
||||
**Server**: {tool.get("server", "unknown")}
|
||||
|
||||
**Input Schema**:
|
||||
```json
|
||||
{input_schema}
|
||||
```
|
||||
|
||||
**Usage**:
|
||||
Use the `/{name}` command or `@{tool.get("server", "server")}` to invoke this tool.
|
||||
|
||||
**Examples**:
|
||||
```
|
||||
/{name} arg1=value1 arg2=value2
|
||||
@{tool.get("server", "server")} {name} --arg1 value1
|
||||
```
|
||||
"""
|
||||
return content
|
||||
|
||||
def get_skill(self, name: str) -> SkillMetadata | None:
|
||||
"""获取已发现的 Skill
|
||||
|
||||
Args:
|
||||
name: Skill 名称
|
||||
|
||||
Returns:
|
||||
Skill 元数据或 None
|
||||
"""
|
||||
return self._discovered_skills.get(name)
|
||||
|
||||
def list_skills(self) -> list[SkillMetadata]:
|
||||
"""列出所有已发现的 Skills
|
||||
|
||||
Returns:
|
||||
Skill 列表
|
||||
"""
|
||||
return list(self._discovered_skills.values())
|
||||
|
||||
|
||||
# 全局加载器
|
||||
_loader: MCPSkillLoader | None = None
|
||||
|
||||
|
||||
def get_mcp_skill_loader() -> MCPSkillLoader:
|
||||
"""获取全局 MCP Skill 加载器"""
|
||||
global _loader
|
||||
if _loader is None:
|
||||
_loader = MCPSkillLoader()
|
||||
return _loader
|
||||
53
backend/app/agents/skills/loaders/plugin_loader.py
Normal file
53
backend/app/agents/skills/loaders/plugin_loader.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""插件 Skills 加载器 - Phase 9.2"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.plugins.manager import get_plugin_manager
|
||||
|
||||
|
||||
class PluginSkillLoader:
|
||||
"""插件 Skills 加载器
|
||||
|
||||
从已安装的插件中加载 Skills。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_manager = get_plugin_manager()
|
||||
|
||||
def load_all(self) -> list[SkillMetadata]:
|
||||
"""从所有已启用的插件加载 Skills
|
||||
|
||||
Returns:
|
||||
Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
for plugin in self.plugin_manager.list_plugins():
|
||||
if not self.plugin_manager.is_enabled(plugin.id):
|
||||
continue
|
||||
|
||||
# 从插件加载 Skills
|
||||
plugin_skills = self._load_from_plugin(plugin)
|
||||
skills.extend(plugin_skills)
|
||||
|
||||
return skills
|
||||
|
||||
def _load_from_plugin(self, plugin: Any) -> list[SkillMetadata]:
|
||||
"""从单个插件加载 Skills"""
|
||||
skills = []
|
||||
|
||||
for skill_name in plugin.skills:
|
||||
skill = SkillMetadata(
|
||||
name=f"{plugin.id}/{skill_name}",
|
||||
description=f"Skill from plugin: {plugin.name}",
|
||||
version=plugin.version,
|
||||
author=plugin.author,
|
||||
tags=["plugin", plugin.id],
|
||||
content=f"# {skill_name}\n\nFrom plugin: {plugin.name}",
|
||||
source="plugin",
|
||||
source_id=plugin.id,
|
||||
)
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
32
backend/app/agents/skills/matcher.py
Normal file
32
backend/app/agents/skills/matcher.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def extract_match_terms(text: str | None) -> list[str]:
|
||||
source = (text or "").lower()
|
||||
terms = [token for token in re.findall(r"[a-z0-9_]+", source) if len(token) >= 3]
|
||||
|
||||
for chunk in re.findall(r"[\u4e00-\u9fff]+", text or ""):
|
||||
if len(chunk) >= 2:
|
||||
terms.append(chunk)
|
||||
if len(chunk) > 4:
|
||||
for index in range(len(chunk) - 1):
|
||||
terms.append(chunk[index : index + 2])
|
||||
|
||||
return list(dict.fromkeys(terms))
|
||||
|
||||
|
||||
def score_text_match(query_text: str, *corpus_parts: str | None) -> tuple[float, list[str]]:
|
||||
query_terms = extract_match_terms(query_text)
|
||||
if not query_terms:
|
||||
return 0.0, []
|
||||
|
||||
corpus = " ".join(part for part in corpus_parts if part).lower()
|
||||
matched_terms = [term for term in query_terms if term and term in corpus]
|
||||
if not matched_terms:
|
||||
return 0.0, []
|
||||
|
||||
coverage = len(matched_terms) / max(len(query_terms), 1)
|
||||
density = min(len(matched_terms), 4) / 4
|
||||
return round(min(1.0, coverage * 0.7 + density * 0.3), 3), matched_terms
|
||||
100
backend/app/agents/skills/mcp_builder.py
Normal file
100
backend/app/agents/skills/mcp_builder.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""MCP Skill Builder - Phase 9.3"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class MCPSkillBuilder:
|
||||
"""MCP Skill Builder
|
||||
|
||||
从 MCP 服务器发现和构建 Skills。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
def discover_skills_from_mcp(self, mcp_servers: list[dict[str, Any]]) -> list[SkillMetadata]:
|
||||
"""从 MCP 服务器发现 Skills
|
||||
|
||||
Args:
|
||||
mcp_servers: MCP 服务器配置列表
|
||||
|
||||
Returns:
|
||||
发现的 Skill 元数据列表
|
||||
"""
|
||||
skills = []
|
||||
|
||||
for server in mcp_servers:
|
||||
server_skills = self._discover_from_server(server)
|
||||
skills.extend(server_skills)
|
||||
|
||||
return skills
|
||||
|
||||
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
|
||||
"""从单个 MCP 服务器发现 Skills"""
|
||||
skills = []
|
||||
server_name = server.get("name", "unknown")
|
||||
tools = server.get("tools", [])
|
||||
|
||||
# 按工具分组
|
||||
tool_groups: dict[str, list[str]] = {}
|
||||
for tool in tools:
|
||||
group = tool.get("group", "default")
|
||||
if group not in tool_groups:
|
||||
tool_groups[group] = []
|
||||
tool_groups[group].append(tool)
|
||||
|
||||
# 为每个组创建一个 Skill
|
||||
for group_name, group_tools in tool_groups.items():
|
||||
skill = self._tool_to_skill(group_name, group_tools, server_name)
|
||||
skills.append(skill)
|
||||
|
||||
return skills
|
||||
|
||||
def _tool_to_skill(self, group: str, tools: list[dict[str, Any]], server: str) -> SkillMetadata:
|
||||
"""将 MCP 工具转换为 Skill"""
|
||||
tool_summaries = []
|
||||
for tool in tools:
|
||||
name = tool.get("name", "unknown")
|
||||
description = tool.get("description", "")
|
||||
input_schema = tool.get("inputSchema", {})
|
||||
|
||||
tool_summaries.append(f"### {name}\n{description}\n\nInput: {input_schema}")
|
||||
|
||||
content = f"""# MCP Skill: {group}
|
||||
|
||||
来自 MCP 服务器: {server}
|
||||
|
||||
## 工具列表
|
||||
|
||||
{chr(10).join(tool_summaries)}
|
||||
|
||||
## 使用说明
|
||||
|
||||
使用这些工具前请确保理解每个工具的输入输出格式。
|
||||
"""
|
||||
|
||||
return SkillMetadata(
|
||||
name=f"mcp-{server}-{group}",
|
||||
description=f"MCP skill from {server}: {group}",
|
||||
version="1.0.0",
|
||||
tags=["mcp", server, group],
|
||||
triggers=[group, server],
|
||||
content=content,
|
||||
source="mcp",
|
||||
source_id=f"{server}:{group}",
|
||||
)
|
||||
|
||||
def _group_to_skill(self, group: str, tools: list[str], server: str) -> SkillMetadata:
|
||||
"""将 MCP 工具组转换为 Skill"""
|
||||
return SkillMetadata(
|
||||
name=f"mcp-{server}-{group}",
|
||||
description=f"MCP skill from {server}: {group}",
|
||||
version="1.0.0",
|
||||
tags=["mcp", server, group],
|
||||
triggers=[group, server],
|
||||
content=f"# {group}\n\nTools: {', '.join(tools)}",
|
||||
source="mcp",
|
||||
source_id=f"{server}:{group}",
|
||||
)
|
||||
50
backend/app/agents/skills/metadata.py
Normal file
50
backend/app/agents/skills/metadata.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Skill 元数据定义 - Phase 9.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillMetadata:
|
||||
"""Skill 元数据"""
|
||||
|
||||
id: str = "" # Skill ID
|
||||
name: str = "" # Skill 名称
|
||||
description: str = "" # 描述
|
||||
version: str = "1.0.0" # 版本
|
||||
author: str = "" # 作者
|
||||
tags: list[str] = field(default_factory=list) # 标签
|
||||
triggers: list[str] = field(default_factory=list) # 触发关键词
|
||||
content: str = "" # Skill 内容(markdown)
|
||||
source: str = "local" # 来源:local, plugin, mcp, bundled
|
||||
source_id: str = "" # 来源 ID
|
||||
enabled: bool = True # 是否启用
|
||||
tools: list[str] = field(default_factory=list) # 关联的工具
|
||||
status: str = "active" # candidate/shadow/active/deprecated/retired
|
||||
scope: list[str] = field(default_factory=list)
|
||||
effectiveness: float | None = None
|
||||
review_after: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"version": self.version,
|
||||
"author": self.author,
|
||||
"tags": self.tags,
|
||||
"triggers": self.triggers,
|
||||
"content": self.content,
|
||||
"source": self.source,
|
||||
"source_id": self.source_id,
|
||||
"enabled": self.enabled,
|
||||
"tools": self.tools,
|
||||
"status": self.status,
|
||||
"scope": self.scope,
|
||||
"effectiveness": self.effectiveness,
|
||||
"review_after": self.review_after,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SkillMetadata":
|
||||
return cls(**data)
|
||||
29
backend/app/agents/skills/models.py
Normal file
29
backend/app/agents/skills/models.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
SkillLifecycleAction = Literal[
|
||||
"created_candidate",
|
||||
"promoted_to_shadow",
|
||||
"promoted_to_active",
|
||||
"degraded_to_deprecated",
|
||||
"retired",
|
||||
"reactivated",
|
||||
"feedback_recorded",
|
||||
"no_change",
|
||||
]
|
||||
|
||||
|
||||
class SkillLifecycleDecision(BaseModel):
|
||||
skill_name: str
|
||||
action: SkillLifecycleAction
|
||||
previous_status: str | None = None
|
||||
new_status: str
|
||||
reason: str
|
||||
evidence_refs: list[dict[str, object]] = Field(default_factory=list)
|
||||
confidence: float | None = None
|
||||
review_after: datetime | None = None
|
||||
27
backend/app/agents/skills/policy.py
Normal file
27
backend/app/agents/skills/policy.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.schemas.skills import SkillInjectionMode, SkillShortlistEntry
|
||||
|
||||
MAX_SUMMARY_CHARS = 120
|
||||
|
||||
|
||||
def choose_injection_mode(score: float, summary_available: bool) -> SkillInjectionMode:
|
||||
if score >= 0.75 and summary_available:
|
||||
return "summary"
|
||||
return "metadata_only"
|
||||
|
||||
|
||||
def render_skill_shortlist_context(entries: list[SkillShortlistEntry]) -> str:
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
lines = ["[Task-Scoped Skills]"]
|
||||
for entry in entries[:3]:
|
||||
detail = entry.summary or "Relevant to the current request."
|
||||
detail = detail[:MAX_SUMMARY_CHARS]
|
||||
lines.append(f"- {entry.skill_name} | mode={entry.injection_mode} | score={entry.score:.2f}")
|
||||
lines.append(f" {detail}")
|
||||
if entry.matched_terms:
|
||||
lines.append(f" matched_terms={', '.join(entry.matched_terms[:6])}")
|
||||
|
||||
return "\n".join(lines)
|
||||
133
backend/app/agents/skills/registry.py
Normal file
133
backend/app/agents/skills/registry.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Skills 注册表 - Phase 9.1"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
from app.agents.skills.loaders.local_loader import LocalSkillLoader
|
||||
|
||||
|
||||
class SkillRegistry:
|
||||
"""Skills 注册表
|
||||
|
||||
管理所有 Skills 的注册、发现和加载。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, SkillMetadata] = {}
|
||||
self._loaders: list[Any] = []
|
||||
|
||||
def load_all(self, skills_dir: str | None = None) -> int:
|
||||
"""加载所有 Skills
|
||||
|
||||
Args:
|
||||
skills_dir: Skills 目录,None 则使用默认目录
|
||||
|
||||
Returns:
|
||||
加载的 Skill 数量
|
||||
"""
|
||||
if skills_dir is None:
|
||||
skills_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", ".claude", "skills"
|
||||
)
|
||||
|
||||
count = 0
|
||||
|
||||
# 本地加载器
|
||||
local_loader = LocalSkillLoader(skills_dir)
|
||||
local_skills = local_loader.load_all()
|
||||
for skill in local_skills:
|
||||
self.register(skill)
|
||||
count += 1
|
||||
|
||||
# 插件加载器
|
||||
for loader in self._loaders:
|
||||
try:
|
||||
external_skills = loader.load_all()
|
||||
for skill in external_skills:
|
||||
self.register(skill)
|
||||
count += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return count
|
||||
|
||||
def register(self, skill: SkillMetadata) -> None:
|
||||
"""注册 Skill"""
|
||||
self._skills[skill.name] = skill
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销 Skill"""
|
||||
if name in self._skills:
|
||||
del self._skills[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_skill(self, name: str) -> SkillMetadata | None:
|
||||
"""获取 Skill"""
|
||||
return self._skills.get(name)
|
||||
|
||||
def search(self, query: str) -> list[SkillMetadata]:
|
||||
"""搜索 Skills
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
|
||||
Returns:
|
||||
匹配的 Skills 列表
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
results = []
|
||||
|
||||
for skill in self._skills.values():
|
||||
if not skill.enabled:
|
||||
continue
|
||||
|
||||
# 匹配名称、描述、标签
|
||||
if (
|
||||
query_lower in skill.name.lower()
|
||||
or query_lower in skill.description.lower()
|
||||
or any(query_lower in tag.lower() for tag in skill.tags)
|
||||
or any(query_lower in trigger.lower() for trigger in skill.triggers)
|
||||
):
|
||||
results.append(skill)
|
||||
|
||||
return results
|
||||
|
||||
def get_skill_context(self, names: list[str]) -> str:
|
||||
"""获取 Skill 上下文
|
||||
|
||||
Args:
|
||||
names: Skill 名称列表
|
||||
|
||||
Returns:
|
||||
拼接的 Skill 内容
|
||||
"""
|
||||
contexts = []
|
||||
|
||||
for name in names:
|
||||
skill = self._skills.get(name)
|
||||
if skill and skill.enabled:
|
||||
contexts.append(f"# {skill.name}\n\n{skill.content}")
|
||||
|
||||
return "\n\n---\n\n".join(contexts)
|
||||
|
||||
def add_loader(self, loader: Any) -> None:
|
||||
"""添加加载器"""
|
||||
self._loaders.append(loader)
|
||||
|
||||
def list_all(self) -> list[SkillMetadata]:
|
||||
"""列出所有 Skills"""
|
||||
return list(self._skills.values())
|
||||
|
||||
|
||||
# 全局单例
|
||||
_registry: SkillRegistry | None = None
|
||||
|
||||
|
||||
def get_skill_registry() -> SkillRegistry:
|
||||
"""获取全局 Skills 注册表"""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = SkillRegistry()
|
||||
return _registry
|
||||
153
backend/app/agents/skills/retriever.py
Normal file
153
backend/app/agents/skills/retriever.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from app.agents.schemas.skills import SkillShortlistEntry
|
||||
from app.agents.skills.matcher import score_text_match
|
||||
from app.agents.skills.policy import choose_injection_mode, render_skill_shortlist_context
|
||||
from app.agents.skills.registry import get_skill_registry
|
||||
from app.services.skill_service import SkillService
|
||||
|
||||
|
||||
class RuntimeSkillRetriever:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
|
||||
async def shortlist(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
query_text: str,
|
||||
memory_context: str | None = None,
|
||||
retrospectives: list[dict] | None = None,
|
||||
include_learned: bool = True,
|
||||
limit: int = 3,
|
||||
) -> list[SkillShortlistEntry]:
|
||||
deduped: "OrderedDict[str, SkillShortlistEntry]" = OrderedDict()
|
||||
retrospective_text = "\n".join(
|
||||
(item.get("summary") or item.get("summary_text") or "")
|
||||
for item in (retrospectives or [])
|
||||
if isinstance(item, dict)
|
||||
)
|
||||
|
||||
service = SkillService(self.db)
|
||||
for skill in await service.list_runtime_candidates(user_id, include_learned=include_learned):
|
||||
score, matched_terms = score_text_match(
|
||||
query_text,
|
||||
skill.name,
|
||||
skill.description,
|
||||
skill.instructions,
|
||||
retrospective_text,
|
||||
memory_context,
|
||||
)
|
||||
if score <= 0:
|
||||
continue
|
||||
entry = SkillShortlistEntry(
|
||||
skill_name=skill.name,
|
||||
source="database",
|
||||
source_id=skill.id,
|
||||
scope=[skill.agent_type, skill.visibility],
|
||||
status=skill.status,
|
||||
effectiveness=skill.effectiveness,
|
||||
score=score,
|
||||
matched_terms=matched_terms,
|
||||
rationale=(
|
||||
"Shadow skill matched current request; keep metadata-only injection."
|
||||
if skill.status == "shadow"
|
||||
else "Matched against DB skill metadata and instructions."
|
||||
),
|
||||
summary=skill.description or (skill.instructions[:160] if skill.instructions else None),
|
||||
injection_mode=(
|
||||
"metadata_only"
|
||||
if skill.status == "shadow"
|
||||
else choose_injection_mode(score, bool(skill.description or skill.instructions))
|
||||
),
|
||||
)
|
||||
self._upsert(deduped, entry)
|
||||
|
||||
registry = get_skill_registry()
|
||||
if not registry.list_all():
|
||||
try:
|
||||
registry.load_all()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for skill in registry.list_all():
|
||||
score, matched_terms = score_text_match(
|
||||
query_text,
|
||||
skill.name,
|
||||
skill.description,
|
||||
" ".join(skill.tags),
|
||||
" ".join(skill.triggers),
|
||||
skill.content[:400],
|
||||
retrospective_text,
|
||||
memory_context,
|
||||
)
|
||||
if score <= 0:
|
||||
continue
|
||||
entry = SkillShortlistEntry(
|
||||
skill_name=skill.name,
|
||||
source=skill.source,
|
||||
source_id=skill.source_id or skill.id,
|
||||
scope=skill.scope or list(skill.tags),
|
||||
status=skill.status,
|
||||
effectiveness=skill.effectiveness,
|
||||
score=score,
|
||||
matched_terms=matched_terms,
|
||||
rationale="Matched against local or external skill metadata.",
|
||||
summary=skill.description or skill.content[:160],
|
||||
injection_mode=choose_injection_mode(
|
||||
score,
|
||||
bool(skill.description or skill.content),
|
||||
),
|
||||
)
|
||||
self._upsert(deduped, entry)
|
||||
|
||||
return sorted(deduped.values(), key=lambda item: item.score, reverse=True)[:limit]
|
||||
|
||||
@staticmethod
|
||||
def _upsert(
|
||||
deduped: "OrderedDict[str, SkillShortlistEntry]",
|
||||
entry: SkillShortlistEntry,
|
||||
) -> None:
|
||||
existing = deduped.get(entry.skill_name)
|
||||
if existing is None or existing.score < entry.score:
|
||||
deduped[entry.skill_name] = entry
|
||||
|
||||
|
||||
def build_shortlisted_skill_context(
|
||||
shortlist: list[dict] | list[SkillShortlistEntry] | None,
|
||||
*,
|
||||
agent_type: str | None = None,
|
||||
) -> str:
|
||||
if not shortlist:
|
||||
return ""
|
||||
|
||||
entries: list[SkillShortlistEntry] = []
|
||||
for item in shortlist:
|
||||
entry = item if isinstance(item, SkillShortlistEntry) else SkillShortlistEntry.model_validate(item)
|
||||
if agent_type and entry.scope and agent_type not in entry.scope:
|
||||
continue
|
||||
entries.append(entry)
|
||||
|
||||
return render_skill_shortlist_context(entries)
|
||||
|
||||
|
||||
async def shortlist_skills_for_request(
|
||||
db,
|
||||
*,
|
||||
user_id: str,
|
||||
user_query: str,
|
||||
memory_context: str | None = None,
|
||||
retrospectives: list[dict] | None = None,
|
||||
include_learned: bool = True,
|
||||
limit: int = 3,
|
||||
) -> list[SkillShortlistEntry]:
|
||||
return await RuntimeSkillRetriever(db).shortlist(
|
||||
user_id=user_id,
|
||||
query_text=user_query,
|
||||
memory_context=memory_context,
|
||||
retrospectives=retrospectives,
|
||||
include_learned=include_learned,
|
||||
limit=limit,
|
||||
)
|
||||
140
backend/app/agents/skills/trigger.py
Normal file
140
backend/app/agents/skills/trigger.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Skill 触发检测器 - Phase 9.5
|
||||
|
||||
检测消息中的 Skill 触发条件。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.metadata import SkillMetadata
|
||||
|
||||
|
||||
class SkillTriggerDetector:
|
||||
"""Skill 触发检测器
|
||||
|
||||
检测用户消息中是否触发了某个 Skill。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
def register_skill(self, skill: SkillMetadata) -> None:
|
||||
"""注册 Skill
|
||||
|
||||
Args:
|
||||
skill: Skill 元数据
|
||||
"""
|
||||
self._skills[skill.name] = skill
|
||||
|
||||
def unregister_skill(self, name: str) -> bool:
|
||||
"""注销 Skill
|
||||
|
||||
Args:
|
||||
name: Skill 名称
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if name in self._skills:
|
||||
del self._skills[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def detect_triggered_skills(self, message: str) -> list[str]:
|
||||
"""检测触发的 Skills
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
触发的 Skill 名称列表
|
||||
"""
|
||||
triggered = []
|
||||
message_lower = message.lower()
|
||||
|
||||
for skill in self._skills.values():
|
||||
if not skill.enabled:
|
||||
continue
|
||||
|
||||
if self._matches_triggers(message, message_lower, skill):
|
||||
triggered.append(skill.name)
|
||||
|
||||
return triggered
|
||||
|
||||
def _matches_triggers(self, message: str, message_lower: str, skill: SkillMetadata) -> bool:
|
||||
"""检查消息是否匹配 Skill 触发条件
|
||||
|
||||
Args:
|
||||
message: 原始消息
|
||||
message_lower: 小写消息
|
||||
skill: Skill 元数据
|
||||
|
||||
Returns:
|
||||
是否匹配
|
||||
"""
|
||||
for trigger in skill.triggers:
|
||||
trigger_lower = trigger.lower()
|
||||
|
||||
# 前缀匹配,如 "/code" 或 "@git"
|
||||
if trigger_lower.startswith("/") or trigger_lower.startswith("@"):
|
||||
if message_lower.startswith(trigger_lower):
|
||||
return True
|
||||
|
||||
# 命令格式,如 "//analyze"
|
||||
if trigger_lower.startswith("//"):
|
||||
pattern = trigger_lower[2:]
|
||||
if re.search(rf"\b{re.escape(pattern)}\b", message_lower):
|
||||
return True
|
||||
|
||||
# 关键词匹配
|
||||
if trigger_lower in message_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_skill_prompt(self, skill_name: str) -> str | None:
|
||||
"""获取 Skill 的提示词
|
||||
|
||||
Args:
|
||||
skill_name: Skill 名称
|
||||
|
||||
Returns:
|
||||
Skill 内容或 None
|
||||
"""
|
||||
skill = self._skills.get(skill_name)
|
||||
if skill:
|
||||
return skill.content
|
||||
return None
|
||||
|
||||
def get_triggered_skill_context(self, message: str) -> str:
|
||||
"""获取触发的 Skills 上下文
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
拼接的 Skill 上下文
|
||||
"""
|
||||
triggered = self.detect_triggered_skills(message)
|
||||
if not triggered:
|
||||
return ""
|
||||
|
||||
contexts = []
|
||||
for skill_name in triggered:
|
||||
skill = self._skills.get(skill_name)
|
||||
if skill:
|
||||
contexts.append(f"# {skill.name}\n\n{skill.content}")
|
||||
|
||||
return "\n\n---\n\n".join(contexts)
|
||||
|
||||
|
||||
# 全局检测器
|
||||
_detector: SkillTriggerDetector | None = None
|
||||
|
||||
|
||||
def get_skill_trigger_detector() -> SkillTriggerDetector:
|
||||
"""获取全局 Skill 触发检测器"""
|
||||
global _detector
|
||||
if _detector is None:
|
||||
_detector = SkillTriggerDetector()
|
||||
return _detector
|
||||
@@ -1,10 +1,28 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict, Annotated, Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from app.agents.schemas.event import AgentEvent
|
||||
from app.agents.schemas.message import AgentMessage
|
||||
from app.agents.schemas.task import (
|
||||
AgentTask,
|
||||
CollaborationBudget,
|
||||
InterruptRecord,
|
||||
RecoveryRecord,
|
||||
TaskResult,
|
||||
VerificationStatus,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
AgentPhase = Literal[
|
||||
"phase_0_bootstrap",
|
||||
"phase_1_routing",
|
||||
"phase_2_controlled_collaboration",
|
||||
"phase_3_dynamic_collaboration",
|
||||
"phase_4_visibility_and_verification",
|
||||
]
|
||||
|
||||
|
||||
class AgentRole(str, Enum):
|
||||
MASTER = "master"
|
||||
@@ -12,6 +30,7 @@ class AgentRole(str, Enum):
|
||||
EXECUTOR = "executor"
|
||||
LIBRARIAN = "librarian"
|
||||
ANALYST = "analyst"
|
||||
CODE_COMMANDER = "code_commander"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -22,41 +41,133 @@ class ConversationTurn:
|
||||
model: str | None = None
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
# Core message history with add_messages reducer
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
def turn_to_message(turn: ConversationTurn) -> BaseMessage:
|
||||
if turn.role == "user":
|
||||
return HumanMessage(content=turn.content)
|
||||
return AIMessage(content=turn.content)
|
||||
|
||||
# Session identifiers
|
||||
|
||||
class AgentState(TypedDict):
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
parent_conversation_id: str | None
|
||||
thread_id: str | None
|
||||
last_message_id: str | None
|
||||
message_sequence: int
|
||||
agent_id: str | None
|
||||
parent_agent_id: str | None
|
||||
root_agent_id: str | None
|
||||
collaboration_depth: int
|
||||
spawned_agent_ids: list[str]
|
||||
|
||||
# Agent routing state
|
||||
execution_mode: Literal["direct", "collaboration", "delegated", "verified"]
|
||||
current_agent: str | None
|
||||
next_step: str | None # For explicit graph routing
|
||||
|
||||
# Traceability
|
||||
next_step: str | None
|
||||
active_agents: list[AgentRole]
|
||||
current_sub_commander: str | None
|
||||
active_sub_commanders: list[str]
|
||||
sub_commander_trace: list[dict[str, Any]]
|
||||
agent_trace: list[str]
|
||||
event_trace: list[AgentEvent | dict[str, Any]]
|
||||
message_trace: list[AgentMessage | dict[str, Any]]
|
||||
|
||||
# Task & Entity Tracking (Business Logic)
|
||||
pending_tasks: list[dict]
|
||||
completed_tasks: list[dict]
|
||||
created_entities: list[dict]
|
||||
pending_tasks: list[dict[str, Any]]
|
||||
completed_tasks: list[dict[str, Any]]
|
||||
active_tasks: list[AgentTask | dict[str, Any]]
|
||||
task_results: list[TaskResult | dict[str, Any]]
|
||||
task_hierarchy: dict[str, list[str]]
|
||||
interrupted_tasks: list[InterruptRecord | dict[str, Any]]
|
||||
recovery_trace: list[RecoveryRecord | dict[str, Any]]
|
||||
recovery_points: list[dict[str, Any]]
|
||||
tool_calls: list[dict[str, Any]]
|
||||
last_tool_result: str | None
|
||||
action_results: list[dict[str, Any]]
|
||||
created_entities: list[dict[str, Any]]
|
||||
tool_outcomes: list[dict[str, Any]]
|
||||
task_result_summary: dict[str, Any] | None
|
||||
verifier_hints: dict[str, Any] | None
|
||||
|
||||
verification_status: VerificationStatus | None
|
||||
verification_summary: str | None
|
||||
verification_evidence: list[dict[str, Any]]
|
||||
isolation_mode: str
|
||||
isolation_id: str | None
|
||||
isolation_workspace_path: str | None
|
||||
isolation_parent_conversation_id: str | None
|
||||
isolation_metadata: dict[str, Any]
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
estimated_cost: float | None
|
||||
budget_warning: bool
|
||||
cost_by_agent: dict[str, dict[str, Any]]
|
||||
cost_thresholds: dict[str, Any]
|
||||
budget_state: CollaborationBudget | dict[str, Any] | None
|
||||
collaboration_budget_history: list[CollaborationBudget | dict[str, Any]]
|
||||
current_phase: AgentPhase
|
||||
phase_history: list[dict[str, Any]]
|
||||
current_checkpoint: str | None
|
||||
checkpoint_history: list[dict[str, Any]]
|
||||
|
||||
tool_strategy_used: str | None
|
||||
tool_round_count: int
|
||||
max_tool_rounds: int
|
||||
retry_count: int
|
||||
max_retries: int
|
||||
iteration_count: int
|
||||
max_iterations: int
|
||||
routing_hops: int
|
||||
max_routing_hops: int
|
||||
terminated_due_to_loop_guard: bool
|
||||
retrieval_trace: list[dict[str, Any]]
|
||||
stop_reason: str | None
|
||||
|
||||
clarification_needed: bool
|
||||
clarification_question: str | None
|
||||
fallback_parse_error: str | None
|
||||
should_respond: bool
|
||||
|
||||
# Context summaries (for long-term or cross-agent context)
|
||||
knowledge_context: str | None
|
||||
graph_context: str | None
|
||||
schedule_context_summary: str | None
|
||||
plan: str | None
|
||||
plan_steps: list[dict[str, Any]]
|
||||
analysis_report: str | None
|
||||
|
||||
# Output control
|
||||
final_response: str | None
|
||||
|
||||
# Memory & Environment
|
||||
memory_context: str | None
|
||||
current_datetime_context: str | None
|
||||
current_datetime_reference: dict[str, str] | None
|
||||
runtime_request_context: dict[str, Any] | None
|
||||
task_graph: dict[str, Any] | None
|
||||
scheduled_subtasks: list[dict[str, Any]]
|
||||
recalled_retrospectives: list[dict[str, Any]]
|
||||
retrospective_shortlist: list[dict[str, Any]]
|
||||
skill_shortlist: list[dict[str, Any]]
|
||||
skill_activation_records: list[dict[str, Any]]
|
||||
execution_decision: dict[str, Any] | None
|
||||
merge_report: dict[str, Any] | None
|
||||
verification_report: dict[str, Any] | None
|
||||
feature_flags: dict[str, bool]
|
||||
observability_report: dict[str, Any] | None
|
||||
|
||||
# Configuration
|
||||
user_llm_config: dict | None
|
||||
provider_capabilities: dict | None
|
||||
turn_context: dict[str, Any] | None
|
||||
routing_decision: dict[str, Any] | None
|
||||
continuity_state: dict[str, Any] | None
|
||||
pending_action: dict[str, Any] | None
|
||||
last_completed_action: dict[str, Any] | None
|
||||
clarification_context: dict[str, Any] | None
|
||||
|
||||
user_llm_config: dict[str, Any] | None
|
||||
provider_capabilities: dict[str, Any] | None
|
||||
|
||||
# Code Commander state
|
||||
code_task_type: Literal["demo", "project", "modification"] | None
|
||||
code_ai_provider: Literal["claude", "gemini", "codex", "opencode"] | None
|
||||
code_sandbox_mode: bool | None
|
||||
code_workspace_path: str | None
|
||||
code_execution_session_id: str | None
|
||||
code_execution_result: dict[str, Any] | None
|
||||
|
||||
|
||||
def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
@@ -64,18 +175,115 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
messages=[],
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
parent_conversation_id=None,
|
||||
thread_id=None,
|
||||
last_message_id=None,
|
||||
message_sequence=0,
|
||||
agent_id=AgentRole.MASTER.value,
|
||||
parent_agent_id=None,
|
||||
root_agent_id=AgentRole.MASTER.value,
|
||||
collaboration_depth=0,
|
||||
spawned_agent_ids=[],
|
||||
execution_mode="direct",
|
||||
current_agent=AgentRole.MASTER.value,
|
||||
next_step=None,
|
||||
active_agents=[AgentRole.MASTER],
|
||||
current_sub_commander=None,
|
||||
active_sub_commanders=[],
|
||||
sub_commander_trace=[],
|
||||
agent_trace=[AgentRole.MASTER.value],
|
||||
event_trace=[],
|
||||
message_trace=[],
|
||||
pending_tasks=[],
|
||||
completed_tasks=[],
|
||||
active_tasks=[],
|
||||
task_results=[],
|
||||
task_hierarchy={},
|
||||
interrupted_tasks=[],
|
||||
recovery_trace=[],
|
||||
recovery_points=[],
|
||||
tool_calls=[],
|
||||
last_tool_result=None,
|
||||
action_results=[],
|
||||
created_entities=[],
|
||||
tool_outcomes=[],
|
||||
task_result_summary=None,
|
||||
verifier_hints=None,
|
||||
verification_status=None,
|
||||
verification_summary=None,
|
||||
verification_evidence=[],
|
||||
isolation_mode="none",
|
||||
isolation_id=None,
|
||||
isolation_workspace_path=None,
|
||||
isolation_parent_conversation_id=None,
|
||||
isolation_metadata={},
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
estimated_cost=None,
|
||||
budget_warning=False,
|
||||
cost_by_agent={},
|
||||
cost_thresholds={},
|
||||
budget_state=None,
|
||||
collaboration_budget_history=[],
|
||||
current_phase="phase_0_bootstrap",
|
||||
phase_history=[
|
||||
{
|
||||
"phase": "phase_0_bootstrap",
|
||||
"reason": "initial_state_created",
|
||||
}
|
||||
],
|
||||
current_checkpoint="bootstrap.initialized",
|
||||
checkpoint_history=[
|
||||
{
|
||||
"checkpoint": "bootstrap.initialized",
|
||||
"phase": "phase_0_bootstrap",
|
||||
"reason": "initial_state_created",
|
||||
}
|
||||
],
|
||||
tool_strategy_used=None,
|
||||
tool_round_count=0,
|
||||
max_tool_rounds=2,
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
iteration_count=0,
|
||||
max_iterations=3,
|
||||
routing_hops=0,
|
||||
max_routing_hops=2,
|
||||
terminated_due_to_loop_guard=False,
|
||||
retrieval_trace=[],
|
||||
stop_reason=None,
|
||||
clarification_needed=False,
|
||||
clarification_question=None,
|
||||
fallback_parse_error=None,
|
||||
should_respond=True,
|
||||
knowledge_context=None,
|
||||
graph_context=None,
|
||||
schedule_context_summary=None,
|
||||
plan=None,
|
||||
plan_steps=[],
|
||||
analysis_report=None,
|
||||
final_response=None,
|
||||
memory_context=None,
|
||||
current_datetime_context=None,
|
||||
current_datetime_reference=None,
|
||||
runtime_request_context=None,
|
||||
task_graph=None,
|
||||
scheduled_subtasks=[],
|
||||
recalled_retrospectives=[],
|
||||
retrospective_shortlist=[],
|
||||
skill_shortlist=[],
|
||||
skill_activation_records=[],
|
||||
execution_decision=None,
|
||||
merge_report=None,
|
||||
verification_report=None,
|
||||
feature_flags={},
|
||||
observability_report=None,
|
||||
turn_context=None,
|
||||
routing_decision=None,
|
||||
continuity_state=None,
|
||||
pending_action=None,
|
||||
last_completed_action=None,
|
||||
clarification_context=None,
|
||||
user_llm_config=None,
|
||||
provider_capabilities=None,
|
||||
)
|
||||
|
||||
13
backend/app/agents/team/__init__.py
Normal file
13
backend/app/agents/team/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Team 多 Agent 协作"""
|
||||
|
||||
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
|
||||
from app.agents.team.member import TeamMember, MemberStatus, MemberTask
|
||||
|
||||
__all__ = [
|
||||
"TeamLeader",
|
||||
"TeamTask",
|
||||
"TaskStatus",
|
||||
"TeamMember",
|
||||
"MemberStatus",
|
||||
"MemberTask",
|
||||
]
|
||||
121
backend/app/agents/team/leader.py
Normal file
121
backend/app/agents/team/leader.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Team 多 Agent 协作 - Phase 10.1"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeamTask:
|
||||
"""团队任务"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
assignee: str | None = None
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class TeamLeader:
|
||||
"""团队领导者
|
||||
|
||||
协调多个 Agent 成员执行任务。
|
||||
"""
|
||||
|
||||
def __init__(self, team_id: str, members: list[str]):
|
||||
"""
|
||||
Args:
|
||||
team_id: 团队 ID
|
||||
members: 成员 ID 列表
|
||||
"""
|
||||
self.team_id = team_id
|
||||
self.members = members
|
||||
self._tasks: dict[str, TeamTask] = {}
|
||||
|
||||
def create_task(self, description: str) -> str:
|
||||
"""创建任务
|
||||
|
||||
Args:
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
任务 ID
|
||||
"""
|
||||
import uuid
|
||||
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
self._tasks[task_id] = TeamTask(
|
||||
id=task_id,
|
||||
description=description,
|
||||
)
|
||||
return task_id
|
||||
|
||||
def assign_task(self, task_id: str, member: str) -> bool:
|
||||
"""分配任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
member: 成员 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
if member not in self.members:
|
||||
return False
|
||||
|
||||
self._tasks[task_id].assignee = member
|
||||
self._tasks[task_id].status = TaskStatus.IN_PROGRESS
|
||||
return True
|
||||
|
||||
def broadcast_task(self, description: str) -> list[str]:
|
||||
"""广播任务给所有成员
|
||||
|
||||
Args:
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
创建的任务 ID 列表
|
||||
"""
|
||||
task_ids = []
|
||||
for member in self.members:
|
||||
task_id = self.create_task(description)
|
||||
self.assign_task(task_id, member)
|
||||
task_ids.append(task_id)
|
||||
return task_ids
|
||||
|
||||
def collect_results(self) -> dict[str, Any]:
|
||||
"""收集所有任务结果
|
||||
|
||||
Returns:
|
||||
任务 ID -> 结果的映射
|
||||
"""
|
||||
return {
|
||||
task_id: task.result
|
||||
for task_id, task in self._tasks.items()
|
||||
if task.status == TaskStatus.COMPLETED
|
||||
}
|
||||
|
||||
def get_team_status(self) -> dict[str, Any]:
|
||||
"""获取团队状态
|
||||
|
||||
Returns:
|
||||
团队状态摘要
|
||||
"""
|
||||
return {
|
||||
"team_id": self.team_id,
|
||||
"members": self.members,
|
||||
"task_count": len(self._tasks),
|
||||
"completed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.COMPLETED),
|
||||
"failed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.FAILED),
|
||||
}
|
||||
166
backend/app/agents/team/member.py
Normal file
166
backend/app/agents/team/member.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""TeamMember 实现 - Phase 10.1
|
||||
|
||||
团队成员实现,负责执行分配的任务。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MemberStatus(Enum):
|
||||
"""成员状态"""
|
||||
|
||||
IDLE = "idle"
|
||||
BUSY = "busy"
|
||||
OFFLINE = "offline"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemberTask:
|
||||
"""成员任务"""
|
||||
|
||||
task_id: str
|
||||
description: str
|
||||
status: str = "pending" # pending, in_progress, completed, failed
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class TeamMember:
|
||||
"""团队成员
|
||||
|
||||
代表团队中的一个 Agent 成员,负责执行分配的任务。
|
||||
"""
|
||||
|
||||
def __init__(self, member_id: str, name: str, capabilities: list[str] | None = None):
|
||||
"""
|
||||
Args:
|
||||
member_id: 成员 ID
|
||||
name: 成员名称
|
||||
capabilities: 成员能力列表
|
||||
"""
|
||||
self.member_id = member_id
|
||||
self.name = name
|
||||
self.capabilities = capabilities or []
|
||||
self.status = MemberStatus.IDLE
|
||||
self._tasks: dict[str, MemberTask] = {}
|
||||
self._metadata: dict[str, Any] = {}
|
||||
|
||||
def assign_task(self, task_id: str, description: str) -> MemberTask:
|
||||
"""接收任务分配
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
description: 任务描述
|
||||
|
||||
Returns:
|
||||
创建的任务对象
|
||||
"""
|
||||
task = MemberTask(task_id=task_id, description=description)
|
||||
self._tasks[task_id] = task
|
||||
self.status = MemberStatus.BUSY
|
||||
return task
|
||||
|
||||
def update_task_status(
|
||||
self, task_id: str, status: str, result: Any = None, error: str | None = None
|
||||
) -> bool:
|
||||
"""更新任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
status: 新状态
|
||||
result: 任务结果
|
||||
error: 错误信息
|
||||
|
||||
Returns:
|
||||
是否更新成功
|
||||
"""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
task = self._tasks[task_id]
|
||||
task.status = status
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
|
||||
if status in ("completed", "failed"):
|
||||
self.status = MemberStatus.IDLE
|
||||
|
||||
return True
|
||||
|
||||
def get_task(self, task_id: str) -> MemberTask | None:
|
||||
"""获取任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
任务对象或 None
|
||||
"""
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def get_pending_tasks(self) -> list[MemberTask]:
|
||||
"""获取待处理任务
|
||||
|
||||
Returns:
|
||||
待处理任务列表
|
||||
"""
|
||||
return [t for t in self._tasks.values() if t.status == "pending"]
|
||||
|
||||
def get_active_task(self) -> MemberTask | None:
|
||||
"""获取当前执行中的任务
|
||||
|
||||
Returns:
|
||||
当前任务或 None
|
||||
"""
|
||||
for task in self._tasks.values():
|
||||
if task.status == "in_progress":
|
||||
return task
|
||||
return None
|
||||
|
||||
def get_completed_tasks(self) -> list[MemberTask]:
|
||||
"""获取已完成任务
|
||||
|
||||
Returns:
|
||||
已完成任务列表
|
||||
"""
|
||||
return [t for t in self._tasks.values() if t.status == "completed"]
|
||||
|
||||
def set_metadata(self, key: str, value: Any) -> None:
|
||||
"""设置元数据
|
||||
|
||||
Args:
|
||||
key: 元数据键
|
||||
value: 元数据值
|
||||
"""
|
||||
self._metadata[key] = value
|
||||
|
||||
def get_metadata(self, key: str) -> Any:
|
||||
"""获取元数据
|
||||
|
||||
Args:
|
||||
key: 元数据键
|
||||
|
||||
Returns:
|
||||
元数据值或 None
|
||||
"""
|
||||
return self._metadata.get(key)
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""获取成员状态
|
||||
|
||||
Returns:
|
||||
状态字典
|
||||
"""
|
||||
return {
|
||||
"member_id": self.member_id,
|
||||
"name": self.name,
|
||||
"status": self.status.value,
|
||||
"capabilities": self.capabilities,
|
||||
"task_count": len(self._tasks),
|
||||
"pending_count": len(self.get_pending_tasks()),
|
||||
"active_task": self.get_active_task().__dict__ if self.get_active_task() else None,
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search, web_search,
|
||||
search_knowledge,
|
||||
get_knowledge_graph_context,
|
||||
build_knowledge_graph,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
@@ -13,6 +16,58 @@ from app.agents.tools.schedule import (
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
|
||||
# Phase 6.1: Tool Registry exports
|
||||
from app.agents.tools.registry import (
|
||||
ToolRegistry,
|
||||
get_tool_registry,
|
||||
reset_tool_registry,
|
||||
)
|
||||
from app.agents.tools.manifest import (
|
||||
HookConfig,
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
from app.agents.tools.migration import (
|
||||
migrate_tool,
|
||||
migrate_all_tools,
|
||||
get_tool_executor,
|
||||
BackwardCompatTool,
|
||||
)
|
||||
|
||||
# Phase 6.2: Hook System exports
|
||||
from app.agents.tools.hooks import (
|
||||
HookManager,
|
||||
HookExecutor,
|
||||
HookType,
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
ExecutionContext,
|
||||
get_hook_manager,
|
||||
get_hook_executor,
|
||||
)
|
||||
|
||||
# Phase 6.3: Streaming Executor exports
|
||||
from app.agents.tools.streaming import (
|
||||
StreamingToolExecutor,
|
||||
get_streaming_executor,
|
||||
)
|
||||
|
||||
# Phase 6.4: Builtin Tools exports
|
||||
from app.agents.tools.builtins import (
|
||||
GlobTool,
|
||||
GrepTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
BashTool,
|
||||
PowerShellTool,
|
||||
LSPTools,
|
||||
GitTool,
|
||||
TeamAgentTool,
|
||||
TaskBroadcastTool,
|
||||
)
|
||||
|
||||
TASK_TOOLS = [
|
||||
get_tasks,
|
||||
create_task,
|
||||
@@ -83,3 +138,12 @@ SUB_COMMANDER_TOOLSETS = {
|
||||
"analyst_progress": ANALYST_PROGRESS_TOOLS,
|
||||
"analyst_insights": ANALYST_INSIGHT_TOOLS,
|
||||
}
|
||||
|
||||
# Code Commander toolset (tools implemented in later phases)
|
||||
CODE_COMMANDER_TOOLSET_NAMES = [
|
||||
"execute_code_task",
|
||||
"get_execution_status",
|
||||
"send_interactive_input",
|
||||
"download_workspace",
|
||||
"cleanup_workspace",
|
||||
]
|
||||
|
||||
196
backend/app/agents/tools/ai_adapter.py
Normal file
196
backend/app/agents/tools/ai_adapter.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
AI CLI Adapter - 统一接口适配不同 AI CLI (Claude/Gemini/Codex/OpenCode)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeExecutionResult:
|
||||
"""代码执行结果"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
files_created: list[str] = field(default_factory=list)
|
||||
output: str = ""
|
||||
error: str | None = None
|
||||
exit_code: int = 0
|
||||
|
||||
|
||||
class AICLIAdapter(ABC):
|
||||
"""AI CLI 适配器抽象基类"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def cli_name(self) -> str:
|
||||
"""CLI 命令名称,如 'claude', 'gemini'"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def requires_workspace(self) -> bool:
|
||||
"""是否需要工作目录"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def provider(self) -> Literal["claude", "gemini", "codex", "opencode"]:
|
||||
"""AI 提供商标识"""
|
||||
return self.cli_name
|
||||
|
||||
@abstractmethod
|
||||
def build_command(self, prompt: str, workspace: Path | None) -> list[str]:
|
||||
"""构建 CLI 命令"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_output(self, output: str) -> CodeExecutionResult:
|
||||
"""解析 CLI 输出"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_installed(self) -> bool:
|
||||
"""检查 CLI 是否已安装"""
|
||||
pass
|
||||
|
||||
|
||||
class ClaudeAdapter(AICLIAdapter):
|
||||
"""Claude CLI 适配器"""
|
||||
|
||||
cli_name = "claude"
|
||||
requires_workspace = True
|
||||
|
||||
def build_command(self, prompt: str, workspace: Path | None) -> list[str]:
|
||||
cmd = ["claude", "-p", prompt]
|
||||
if workspace:
|
||||
cmd.extend(["--output-format", "stream-json"])
|
||||
cmd.append("--dangerously-skip-permissions")
|
||||
return cmd
|
||||
|
||||
def parse_output(self, output: str) -> CodeExecutionResult:
|
||||
# Claude CLI 输出可能是纯文本或 JSON
|
||||
# 简化处理:直接返回输出
|
||||
if not output.strip():
|
||||
return CodeExecutionResult(
|
||||
success=False,
|
||||
message="No output from Claude CLI",
|
||||
output=output,
|
||||
)
|
||||
return CodeExecutionResult(
|
||||
success=True,
|
||||
message="Execution completed",
|
||||
output=output,
|
||||
)
|
||||
|
||||
def is_installed(self) -> bool:
|
||||
import shutil
|
||||
|
||||
return shutil.which("claude") is not None
|
||||
|
||||
|
||||
class GeminiAdapter(AICLIAdapter):
|
||||
"""Gemini CLI 适配器"""
|
||||
|
||||
cli_name = "gemini"
|
||||
requires_workspace = False
|
||||
|
||||
def build_command(self, prompt: str, workspace: Path | None) -> list[str]:
|
||||
cmd = ["gemini", "-p", prompt]
|
||||
return cmd
|
||||
|
||||
def parse_output(self, output: str) -> CodeExecutionResult:
|
||||
if not output.strip():
|
||||
return CodeExecutionResult(
|
||||
success=False,
|
||||
message="No output from Gemini CLI",
|
||||
output=output,
|
||||
)
|
||||
return CodeExecutionResult(
|
||||
success=True,
|
||||
message="Execution completed",
|
||||
output=output,
|
||||
)
|
||||
|
||||
def is_installed(self) -> bool:
|
||||
import shutil
|
||||
|
||||
return shutil.which("gemini") is not None
|
||||
|
||||
|
||||
class CodexAdapter(AICLIAdapter):
|
||||
"""Codex CLI 适配器"""
|
||||
|
||||
cli_name = "codex"
|
||||
requires_workspace = True
|
||||
|
||||
def build_command(self, prompt: str, workspace: Path | None) -> list[str]:
|
||||
cmd = ["codex", "-p", prompt]
|
||||
return cmd
|
||||
|
||||
def parse_output(self, output: str) -> CodeExecutionResult:
|
||||
if not output.strip():
|
||||
return CodeExecutionResult(
|
||||
success=False,
|
||||
message="No output from Codex CLI",
|
||||
output=output,
|
||||
)
|
||||
return CodeExecutionResult(
|
||||
success=True,
|
||||
message="Execution completed",
|
||||
output=output,
|
||||
)
|
||||
|
||||
def is_installed(self) -> bool:
|
||||
import shutil
|
||||
|
||||
return shutil.which("codex") is not None
|
||||
|
||||
|
||||
class OpenCodeAdapter(AICLIAdapter):
|
||||
"""OpenCode CLI 适配器"""
|
||||
|
||||
cli_name = "opencode"
|
||||
requires_workspace = True
|
||||
|
||||
def build_command(self, prompt: str, workspace: Path | None) -> list[str]:
|
||||
cmd = ["opencode", "-p", prompt]
|
||||
return cmd
|
||||
|
||||
def parse_output(self, output: str) -> CodeExecutionResult:
|
||||
if not output.strip():
|
||||
return CodeExecutionResult(
|
||||
success=False,
|
||||
message="No output from OpenCode CLI",
|
||||
output=output,
|
||||
)
|
||||
return CodeExecutionResult(
|
||||
success=True,
|
||||
message="Execution completed",
|
||||
output=output,
|
||||
)
|
||||
|
||||
def is_installed(self) -> bool:
|
||||
import shutil
|
||||
|
||||
return shutil.which("opencode") is not None
|
||||
|
||||
|
||||
# 提供商注册表
|
||||
ADAPTER_REGISTRY: dict[str, AICLIAdapter] = {
|
||||
"claude": ClaudeAdapter(),
|
||||
"gemini": GeminiAdapter(),
|
||||
"codex": CodexAdapter(),
|
||||
"opencode": OpenCodeAdapter(),
|
||||
}
|
||||
|
||||
|
||||
def get_adapter(provider: str) -> AICLIAdapter:
|
||||
"""获取指定提供商的适配器"""
|
||||
adapter = ADAPTER_REGISTRY.get(provider.lower())
|
||||
if adapter is None:
|
||||
raise ValueError(
|
||||
f"Unknown AI provider: {provider}. Available: {list(ADAPTER_REGISTRY.keys())}"
|
||||
)
|
||||
return adapter
|
||||
18
backend/app/agents/tools/async_bridge.py
Normal file
18
backend/app/agents/tools/async_bridge.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def run_async(coro: Any, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
__all__ = ["run_async"]
|
||||
161
backend/app/agents/tools/base.py
Normal file
161
backend/app/agents/tools/base.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""工具基类 - 工具系统重构 Phase 6.1"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseTool(ABC, Generic[T]):
|
||||
"""工具基类
|
||||
|
||||
提供工具的标准接口和默认实现。
|
||||
所有自定义工具应继承此类。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
category: ToolCategory,
|
||||
permission_class: PermissionClass,
|
||||
side_effect_scope: SideEffectScope = SideEffectScope.NONE,
|
||||
requires_confirmation: bool = False,
|
||||
is_streaming: bool = False,
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.category = category
|
||||
self.permission_class = permission_class
|
||||
self.side_effect_scope = side_effect_scope
|
||||
self.requires_confirmation = requires_confirmation
|
||||
self.is_streaming = is_streaming
|
||||
self.tags = tags or []
|
||||
|
||||
def get_manifest(self) -> ToolManifest:
|
||||
"""获取工具元数据
|
||||
|
||||
Returns:
|
||||
工具元数据
|
||||
"""
|
||||
return ToolManifest(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
category=self.category,
|
||||
parameters=self.get_parameters(),
|
||||
return_schema=self.get_return_schema(),
|
||||
permission_class=self.permission_class,
|
||||
side_effect_scope=self.side_effect_scope,
|
||||
requires_confirmation=self.requires_confirmation,
|
||||
is_streaming=self.is_streaming,
|
||||
tags=self.tags,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
"""获取参数 Schema(JSON Schema 格式)
|
||||
|
||||
Returns:
|
||||
参数 schema
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
"""获取返回值 Schema
|
||||
|
||||
Returns:
|
||||
返回值 schema
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> T:
|
||||
"""执行工具
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
pass
|
||||
|
||||
async def execute_safe(self, **kwargs) -> dict[str, Any]:
|
||||
"""安全执行工具,捕获异常
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
包含 success 和 result/error 的字典
|
||||
"""
|
||||
try:
|
||||
result = await self.execute(**kwargs)
|
||||
return {"success": True, "result": result}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self.name!r})>"
|
||||
|
||||
|
||||
class ReadTool(BaseTool):
|
||||
"""只读工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.READ)
|
||||
kwargs.setdefault("permission_class", PermissionClass.READ)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NONE)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class WriteTool(BaseTool):
|
||||
"""写入工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.WRITE)
|
||||
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.LOCAL_STATE)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class DBWriteTool(BaseTool):
|
||||
"""数据库写入工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.DB_WRITE)
|
||||
kwargs.setdefault("permission_class", PermissionClass.WRITE)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.DB_WRITE)
|
||||
kwargs.setdefault("requires_confirmation", True)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class ExternalTool(BaseTool):
|
||||
"""外部工具基类(执行外部命令等)"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.EXTERNAL)
|
||||
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||
kwargs.setdefault("requires_confirmation", True)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class NetworkTool(BaseTool):
|
||||
"""网络工具基类"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("category", ToolCategory.NETWORK)
|
||||
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
|
||||
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
|
||||
super().__init__(**kwargs)
|
||||
43
backend/app/agents/tools/builtins/__init__.py
Normal file
43
backend/app/agents/tools/builtins/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""内置工具集 - Phase 6.4
|
||||
|
||||
新的内置工具,使用 BaseTool 基类。
|
||||
"""
|
||||
|
||||
from app.agents.tools.builtins.file_tools import (
|
||||
GlobTool,
|
||||
GrepTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.system_tools import (
|
||||
BashTool,
|
||||
PowerShellTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.dev_tools import (
|
||||
LSPTools,
|
||||
GitTool,
|
||||
)
|
||||
|
||||
from app.agents.tools.builtins.collaboration_tools import (
|
||||
TeamAgentTool,
|
||||
TaskBroadcastTool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# File tools
|
||||
"GlobTool",
|
||||
"GrepTool",
|
||||
"ReadFileTool",
|
||||
"WriteFileTool",
|
||||
# System tools
|
||||
"BashTool",
|
||||
"PowerShellTool",
|
||||
# Dev tools
|
||||
"LSPTools",
|
||||
"GitTool",
|
||||
# Collaboration tools
|
||||
"TeamAgentTool",
|
||||
"TaskBroadcastTool",
|
||||
]
|
||||
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
129
backend/app/agents/tools/builtins/collaboration_tools.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""协作工具 - Phase 6.4"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class TeamAgentTool(WriteTool):
|
||||
"""团队 Agent 通信工具
|
||||
|
||||
用于与其他 Agent 进行消息传递和协作。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="team_agent",
|
||||
description="向团队 Agent 发送消息或请求协作",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
tags=["collaboration", "team", "agent"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "目标 Agent 名称",
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "要发送的消息",
|
||||
},
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["send", "request", "delegate"],
|
||||
"description": "操作类型",
|
||||
},
|
||||
},
|
||||
"required": ["agent_name", "message"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"response": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, agent_name: str, message: str, action: str = "send") -> dict[str, Any]:
|
||||
# 注意:实际实现需要通过 Agent 通信协议
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": True,
|
||||
"response": f"Message '{action}' to agent '{agent_name}': {message}",
|
||||
"agent_name": agent_name,
|
||||
"action": action,
|
||||
}
|
||||
|
||||
|
||||
class TaskBroadcastTool(WriteTool):
|
||||
"""任务广播工具
|
||||
|
||||
向多个 Agent 广播任务。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="task_broadcast",
|
||||
description="向多个 Agent 广播任务",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
tags=["collaboration", "broadcast", "task"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_names": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "目标 Agent 列表",
|
||||
},
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "要广播的任务描述",
|
||||
},
|
||||
"priority": {
|
||||
"type": "string",
|
||||
"enum": ["low", "normal", "high", "urgent"],
|
||||
"description": "任务优先级",
|
||||
},
|
||||
},
|
||||
"required": ["agent_names", "task"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"broadcast_to": {"type": "array", "items": {"type": "string"}},
|
||||
"responses": {"type": "array"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
agent_names: list[str],
|
||||
task: str,
|
||||
priority: str = "normal",
|
||||
) -> dict[str, Any]:
|
||||
# 注意:实际实现需要通过 Agent 通信协议
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": True,
|
||||
"broadcast_to": agent_names,
|
||||
"task": task,
|
||||
"priority": priority,
|
||||
"responses": [f"Acknowledged by {agent}" for agent in agent_names],
|
||||
}
|
||||
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
155
backend/app/agents/tools/builtins/dev_tools.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""开发工具 - Phase 6.4"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ReadTool, WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class LSPTools(ReadTool):
|
||||
"""语言服务器协议工具集
|
||||
|
||||
提供代码导航、查找引用等 LSP 功能。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="lsp_tools",
|
||||
description="LSP 代码导航和查找引用",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["development", "lsp", "code"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["goto_definition", "find_references", "document_symbols"],
|
||||
"description": "LSP 操作类型",
|
||||
},
|
||||
"file": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"line": {
|
||||
"type": "integer",
|
||||
"description": "行号(1-based)",
|
||||
},
|
||||
"character": {
|
||||
"type": "integer",
|
||||
"description": "列号(0-based)",
|
||||
},
|
||||
},
|
||||
"required": ["action", "file"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"results": {"type": "array"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
action: str,
|
||||
file: str,
|
||||
line: int = 1,
|
||||
character: int = 0,
|
||||
) -> dict[str, Any]:
|
||||
# 注意:实际 LSP 调用需要通过 lsp-utils 或类似库
|
||||
# 这里只是一个框架实现
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"LSP action '{action}' not fully implemented - requires LSP server integration",
|
||||
"action": action,
|
||||
"file": file,
|
||||
"position": {"line": line, "character": character},
|
||||
}
|
||||
|
||||
|
||||
class GitTool(ReadTool):
|
||||
"""Git 操作工具
|
||||
|
||||
提供常用的 Git 操作。
|
||||
"""
|
||||
|
||||
def __init__(self, repo_path: str = "."):
|
||||
super().__init__(
|
||||
name="git",
|
||||
description="执行 Git 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["development", "git", "version-control"],
|
||||
)
|
||||
self.repo_path = repo_path
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Git 子命令和参数,如 'status' 或 'log --oneline -10'",
|
||||
},
|
||||
"repo_path": {
|
||||
"type": "string",
|
||||
"description": "仓库路径(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, command: str, repo_path: str | None = None) -> dict[str, Any]:
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
|
||||
repo = repo_path or self.repo_path
|
||||
|
||||
# 构建完整的 git 命令
|
||||
if platform.system() == "Windows":
|
||||
full_command = f'git -C "{repo}" {command}'
|
||||
else:
|
||||
full_command = f"git -C '{repo}' {command}"
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
full_command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
255
backend/app/agents/tools/builtins/file_tools.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""文件操作工具 - Phase 6.4"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ExternalTool, ReadTool, WriteTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
)
|
||||
|
||||
|
||||
class GlobTool(ReadTool):
|
||||
"""文件路径匹配工具
|
||||
|
||||
使用 glob 模式查找文件。
|
||||
"""
|
||||
|
||||
def __init__(self, root_dir: str = "."):
|
||||
super().__init__(
|
||||
name="glob",
|
||||
description="使用 glob 模式查找文件路径",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "search", "glob"],
|
||||
)
|
||||
self.root_dir = root_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob 模式,如 **/*.py",
|
||||
},
|
||||
"root_dir": {
|
||||
"type": "string",
|
||||
"description": "搜索根目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
|
||||
async def execute(self, pattern: str, root_dir: str | None = None) -> list[str]:
|
||||
import glob as glob_module
|
||||
|
||||
root = root_dir or self.root_dir
|
||||
return glob_module.glob(pattern, root_dir=root, recursive=True)
|
||||
|
||||
|
||||
class GrepTool(ReadTool):
|
||||
"""文件内容搜索工具
|
||||
|
||||
在文件中搜索匹配的行。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="grep",
|
||||
description="在文件中搜索匹配的文本行",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "search", "text"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "正则表达式模式",
|
||||
},
|
||||
"paths": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "要搜索的文件路径列表",
|
||||
},
|
||||
"case_sensitive": {
|
||||
"type": "boolean",
|
||||
"description": "是否区分大小写",
|
||||
},
|
||||
},
|
||||
"required": ["pattern", "paths"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file": {"type": "string"},
|
||||
"line": {"type": "integer"},
|
||||
"content": {"type": "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, pattern: str, paths: list[str], case_sensitive: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
import re
|
||||
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
regex = re.compile(pattern, flags)
|
||||
results = []
|
||||
|
||||
for path in paths:
|
||||
if not os.path.isfile(path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
if regex.search(line):
|
||||
results.append(
|
||||
{
|
||||
"file": path,
|
||||
"line": line_num,
|
||||
"content": line.rstrip(),
|
||||
}
|
||||
)
|
||||
except (UnicodeDecodeError, PermissionError):
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class ReadFileTool(ReadTool):
|
||||
"""文件读取工具"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="read_file",
|
||||
description="读取文件内容",
|
||||
permission_class=PermissionClass.READ,
|
||||
side_effect_scope=SideEffectScope.NONE,
|
||||
tags=["file", "read"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "最大行数",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "起始行号",
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string"},
|
||||
"lines": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, path: str, limit: int | None = None, offset: int = 0) -> dict[str, Any]:
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
total_lines = len(lines)
|
||||
start = max(0, offset)
|
||||
end = len(lines) if limit is None else min(start + limit, len(lines))
|
||||
|
||||
content = "".join(lines[start:end])
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"lines": total_lines,
|
||||
"truncated": limit is not None and end < len(lines),
|
||||
}
|
||||
|
||||
|
||||
class WriteFileTool(WriteTool):
|
||||
"""文件写入工具"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="write_file",
|
||||
description="写入文件内容",
|
||||
permission_class=PermissionClass.WRITE,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["file", "write"],
|
||||
)
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "文件路径",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "文件内容",
|
||||
},
|
||||
"append": {
|
||||
"type": "boolean",
|
||||
"description": "是否追加模式",
|
||||
},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"success": {"type": "boolean"},
|
||||
"bytes_written": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, path: str, content: str, append: bool = False) -> dict[str, Any]:
|
||||
mode = "a" if append else "w"
|
||||
|
||||
# 确保目录存在
|
||||
directory = os.path.dirname(path)
|
||||
if directory and not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
with open(path, mode, encoding="utf-8") as f:
|
||||
bytes_written = f.write(content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"bytes_written": bytes_written,
|
||||
}
|
||||
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
193
backend/app/agents/tools/builtins/system_tools.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""系统工具 - Phase 6.4"""
|
||||
|
||||
import asyncio
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.base import ExternalTool
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
)
|
||||
|
||||
|
||||
class BashTool(ExternalTool):
|
||||
"""Bash 命令执行工具"""
|
||||
|
||||
def __init__(self, working_dir: str = "."):
|
||||
super().__init__(
|
||||
name="bash",
|
||||
description="执行 Bash 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["system", "bash", "shell"],
|
||||
)
|
||||
self.working_dir = working_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的 Bash 命令",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "超时时间(秒)",
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "工作目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
import os
|
||||
|
||||
cwd = working_dir or self.working_dir
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout} seconds",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
|
||||
class PowerShellTool(ExternalTool):
|
||||
"""PowerShell 命令执行工具"""
|
||||
|
||||
def __init__(self, working_dir: str = "."):
|
||||
super().__init__(
|
||||
name="powershell",
|
||||
description="执行 PowerShell 命令",
|
||||
permission_class=PermissionClass.EXTERNAL,
|
||||
side_effect_scope=SideEffectScope.LOCAL_STATE,
|
||||
requires_confirmation=True,
|
||||
tags=["system", "powershell", "shell"],
|
||||
)
|
||||
self.working_dir = working_dir
|
||||
|
||||
def get_parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的 PowerShell 命令",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "超时时间(秒)",
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "工作目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def get_return_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stdout": {"type": "string"},
|
||||
"stderr": {"type": "string"},
|
||||
"returncode": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, command: str, timeout: int = 30, working_dir: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
import platform
|
||||
|
||||
# 检测是否是 Windows 平台
|
||||
is_windows = platform.system() == "Windows"
|
||||
|
||||
if not is_windows:
|
||||
# 非 Windows 平台,可能没有 PowerShell
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": "PowerShell is not available on this platform",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
cwd = working_dir or self.working_dir
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"powershell.exe",
|
||||
"-NoProfile",
|
||||
"-Command",
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout} seconds",
|
||||
"returncode": -1,
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": stderr.decode("utf-8", errors="replace"),
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
217
backend/app/agents/tools/collaboration.py
Normal file
217
backend/app/agents/tools/collaboration.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Agent Collaboration Protocol
|
||||
|
||||
Inter-agent tool collaboration messaging system.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""Collaboration message types"""
|
||||
|
||||
REQUEST = "request" # Request collaboration
|
||||
RESPONSE = "response" # Response result
|
||||
PROGRESS = "progress" # Progress update
|
||||
CANCEL = "cancel" # Cancel request
|
||||
|
||||
|
||||
class CollaborationMessage(BaseModel):
|
||||
"""Collaboration message model"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
type: MessageType
|
||||
from_agent: str
|
||||
to_agent: str
|
||||
content: Dict[str, Any]
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
def is_request(self) -> bool:
|
||||
return self.type == MessageType.REQUEST
|
||||
|
||||
def is_response(self) -> bool:
|
||||
return self.type == MessageType.RESPONSE
|
||||
|
||||
|
||||
class CollaborationProtocol:
|
||||
"""Agent collaboration protocol for inter-agent tool requests"""
|
||||
|
||||
def __init__(self):
|
||||
self._pending_requests: Dict[str, CollaborationMessage] = {}
|
||||
self._handlers: Dict[str, Callable] = {}
|
||||
self._response_futures: Dict[str, asyncio.Future] = {}
|
||||
|
||||
def register_handler(self, tool_name: str, handler: Callable) -> None:
|
||||
"""Register a tool handler for collaboration
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
handler: Async callable to handle the tool execution
|
||||
"""
|
||||
self._handlers[tool_name] = handler
|
||||
|
||||
async def request_collaboration(
|
||||
self,
|
||||
from_agent: str,
|
||||
to_agent: str,
|
||||
tool_name: str,
|
||||
parameters: Dict[str, Any],
|
||||
timeout_ms: int = 30000,
|
||||
) -> Dict[str, Any]:
|
||||
"""Request collaboration from another agent
|
||||
|
||||
Args:
|
||||
from_agent: Source agent name
|
||||
to_agent: Target agent name
|
||||
tool_name: Tool to execute
|
||||
parameters: Tool parameters
|
||||
timeout_ms: Timeout in milliseconds
|
||||
|
||||
Returns:
|
||||
Execution result dict with status and result/error
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
message = CollaborationMessage(
|
||||
id=request_id,
|
||||
type=MessageType.REQUEST,
|
||||
from_agent=from_agent,
|
||||
to_agent=to_agent,
|
||||
content={
|
||||
"tool": tool_name,
|
||||
"parameters": parameters,
|
||||
},
|
||||
metadata={"timeout": timeout_ms},
|
||||
)
|
||||
|
||||
self._pending_requests[request_id] = message
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._response_futures[request_id] = future
|
||||
|
||||
# Send the message
|
||||
await self._send_message(message)
|
||||
|
||||
# Wait for response with timeout
|
||||
try:
|
||||
result = await asyncio.wait_for(future, timeout=timeout_ms / 1000)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "Collaboration request timed out",
|
||||
}
|
||||
finally:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
self._response_futures.pop(request_id, None)
|
||||
|
||||
async def handle_request(self, message: CollaborationMessage) -> CollaborationMessage:
|
||||
"""Handle an incoming collaboration request
|
||||
|
||||
Args:
|
||||
message: The collaboration message
|
||||
|
||||
Returns:
|
||||
Response message with result or error
|
||||
"""
|
||||
import uuid
|
||||
|
||||
tool_name = message.content.get("tool")
|
||||
parameters = message.content.get("parameters", {})
|
||||
|
||||
handler = self._handlers.get(tool_name)
|
||||
if not handler:
|
||||
return CollaborationMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
type=MessageType.RESPONSE,
|
||||
from_agent=message.to_agent,
|
||||
to_agent=message.from_agent,
|
||||
content={
|
||||
"status": "error",
|
||||
"error": f"Unknown tool: {tool_name}",
|
||||
},
|
||||
metadata={},
|
||||
)
|
||||
|
||||
try:
|
||||
result = await handler(**parameters)
|
||||
return CollaborationMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
type=MessageType.RESPONSE,
|
||||
from_agent=message.to_agent,
|
||||
to_agent=message.from_agent,
|
||||
content={"status": "success", "result": result},
|
||||
metadata={},
|
||||
)
|
||||
except Exception as e:
|
||||
return CollaborationMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
type=MessageType.RESPONSE,
|
||||
from_agent=message.to_agent,
|
||||
to_agent=message.from_agent,
|
||||
content={"status": "error", "error": str(e)},
|
||||
metadata={},
|
||||
)
|
||||
|
||||
async def handle_response(self, message: CollaborationMessage) -> None:
|
||||
"""Handle an incoming response message
|
||||
|
||||
Args:
|
||||
message: The response message
|
||||
"""
|
||||
request_id = None
|
||||
for req_id, pending in self._pending_requests.items():
|
||||
if pending.id == message.id:
|
||||
request_id = req_id
|
||||
break
|
||||
|
||||
if request_id and request_id in self._response_futures:
|
||||
future = self._response_futures[request_id]
|
||||
if not future.done():
|
||||
future.set_result(message.content)
|
||||
|
||||
async def _send_message(self, message: CollaborationMessage) -> None:
|
||||
"""Send a collaboration message
|
||||
|
||||
This is a placeholder for actual transport implementation.
|
||||
In production, this would use WebSocket, message queue, or shared storage.
|
||||
|
||||
Args:
|
||||
message: The message to send
|
||||
"""
|
||||
# TODO: Implement actual message transport
|
||||
# Options: WebSocket, Redis pub/sub, shared database
|
||||
pass
|
||||
|
||||
def get_pending_requests(self) -> list:
|
||||
"""Get list of pending requests"""
|
||||
return [
|
||||
{
|
||||
"id": msg.id,
|
||||
"from": msg.from_agent,
|
||||
"to": msg.to_agent,
|
||||
"tool": msg.content.get("tool"),
|
||||
}
|
||||
for msg in self._pending_requests.values()
|
||||
]
|
||||
|
||||
|
||||
# Global collaboration protocol instance
|
||||
_collaboration_protocol: Optional[CollaborationProtocol] = None
|
||||
|
||||
|
||||
def get_collaboration_protocol() -> CollaborationProtocol:
|
||||
"""Get the global collaboration protocol instance"""
|
||||
global _collaboration_protocol
|
||||
if _collaboration_protocol is None:
|
||||
_collaboration_protocol = CollaborationProtocol()
|
||||
return _collaboration_protocol
|
||||
112
backend/app/agents/tools/direct_executor.py
Normal file
112
backend/app/agents/tools/direct_executor.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Direct Executor - 直接执行器
|
||||
用于低风险任务,直接执行不隔离
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from app.agents.tools.ai_adapter import AICLIAdapter
|
||||
|
||||
|
||||
class ExecutionResult:
|
||||
"""执行结果"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
success: bool,
|
||||
exit_code: int,
|
||||
stdout: str,
|
||||
stderr: str,
|
||||
):
|
||||
self.success = success
|
||||
self.exit_code = exit_code
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
class DirectExecutor:
|
||||
"""直接执行器(用于低风险任务)"""
|
||||
|
||||
def __init__(self, adapter: AICLIAdapter, timeout: int = 60):
|
||||
self.adapter = adapter
|
||||
self.timeout = timeout
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
prompt: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
直接执行,不需要沙盒
|
||||
|
||||
Args:
|
||||
prompt: 任务描述
|
||||
|
||||
Yields:
|
||||
str: 实时输出
|
||||
"""
|
||||
# 1. 检查 CLI 是否安装
|
||||
if not self.adapter.is_installed():
|
||||
yield f"[ERROR] {self.adapter.cli_name} is not installed\n"
|
||||
yield f"[ERROR] Please install {self.adapter.cli_name} first\n"
|
||||
return
|
||||
|
||||
# 2. 构建命令
|
||||
cmd = self.adapter.build_command(prompt, None)
|
||||
|
||||
# 3. 异步执行,实时 yield 输出
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env={**os.environ, "TERM": "xterm-256color"},
|
||||
)
|
||||
|
||||
# 4. 实时读取输出
|
||||
stdout_lines = []
|
||||
stderr_lines = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
line_bytes = await asyncio.wait_for(
|
||||
process.stdout.readline(),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if not line_bytes:
|
||||
break
|
||||
line = line_bytes.decode("utf-8", errors="replace")
|
||||
stdout_lines.append(line)
|
||||
yield line
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
yield f"\n[ERROR] Execution timed out after {self.timeout}s\n"
|
||||
break
|
||||
|
||||
# 5. 读取 stderr
|
||||
stderr_bytes = await process.communicate()
|
||||
if stderr_bytes[1]:
|
||||
stderr = stderr_bytes[1].decode("utf-8", errors="replace")
|
||||
stderr_lines.append(stderr)
|
||||
yield f"\n[STDERR]\n{stderr}\n"
|
||||
|
||||
# 6. 完成标记
|
||||
yield f"\n[EXIT_CODE] {process.returncode or 0}\n"
|
||||
yield f"\n[COMPLETE] success={process.returncode == 0}\n"
|
||||
|
||||
async def execute_sync(self, prompt: str) -> ExecutionResult:
|
||||
"""同步执行并返回完整结果"""
|
||||
output_parts = []
|
||||
async for line in self.execute(prompt):
|
||||
output_parts.append(line)
|
||||
|
||||
output = "".join(output_parts)
|
||||
return ExecutionResult(
|
||||
success="[COMPLETE] success=True" in output,
|
||||
exit_code=0,
|
||||
stdout=output,
|
||||
stderr="",
|
||||
)
|
||||
@@ -4,19 +4,12 @@ from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
@tool
|
||||
|
||||
46
backend/app/agents/tools/hooks/__init__.py
Normal file
46
backend/app/agents/tools/hooks/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Hook 系统 - Phase 6.2"""
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookStage,
|
||||
HookTrigger,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
HookHandler,
|
||||
PreToolHook,
|
||||
PostToolHook,
|
||||
ErrorToolHook,
|
||||
SkipToolHook,
|
||||
)
|
||||
from app.agents.tools.hooks.manager import (
|
||||
HookManager,
|
||||
get_hook_manager,
|
||||
reset_hook_manager,
|
||||
)
|
||||
from app.agents.tools.hooks.executor import (
|
||||
HookExecutor,
|
||||
get_hook_executor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
"HookType",
|
||||
"HookStage",
|
||||
"HookTrigger",
|
||||
"HookDefinition",
|
||||
"HookResult",
|
||||
"ExecutionContext",
|
||||
"HookHandler",
|
||||
"PreToolHook",
|
||||
"PostToolHook",
|
||||
"ErrorToolHook",
|
||||
"SkipToolHook",
|
||||
# Manager
|
||||
"HookManager",
|
||||
"get_hook_manager",
|
||||
"reset_hook_manager",
|
||||
# Executor
|
||||
"HookExecutor",
|
||||
"get_hook_executor",
|
||||
]
|
||||
11
backend/app/agents/tools/hooks/builtins/__init__.py
Normal file
11
backend/app/agents/tools/hooks/builtins/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""内置 Hook 集合 - Phase 7"""
|
||||
|
||||
from app.agents.tools.hooks.builtins.audit_log import AuditLogHook
|
||||
from app.agents.tools.hooks.builtins.dangerous_confirmation import DangerousConfirmationHook
|
||||
from app.agents.tools.hooks.builtins.security_scan import SecurityScanHook
|
||||
|
||||
__all__ = [
|
||||
"AuditLogHook",
|
||||
"DangerousConfirmationHook",
|
||||
"SecurityScanHook",
|
||||
]
|
||||
115
backend/app/agents/tools/hooks/builtins/audit_log.py
Normal file
115
backend/app/agents/tools/hooks/builtins/audit_log.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""审计日志 Hook - Phase 7.2
|
||||
|
||||
记录所有工具调用到审计日志。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
HookType,
|
||||
)
|
||||
from app.agents.tools.manifest import ToolCategory
|
||||
|
||||
|
||||
class AuditLogHook:
|
||||
"""审计日志 Hook
|
||||
|
||||
记录所有工具调用的详细信息,包括:
|
||||
- 调用时间
|
||||
- 工具名称
|
||||
- 输入参数
|
||||
- 执行结果
|
||||
- 执行时长
|
||||
- 用户 ID
|
||||
"""
|
||||
|
||||
def __init__(self, log_path: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
log_path: 日志文件路径,None 则输出到 stdout
|
||||
"""
|
||||
self.log_path = log_path
|
||||
self._logs: list[dict[str, Any]] = []
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""工具执行前记录"""
|
||||
log_entry = {
|
||||
"event": "pre_tool",
|
||||
"tool_name": context.tool_name,
|
||||
"input": context.tool_input,
|
||||
"user_id": context.user_id,
|
||||
"session_id": context.session_id,
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
|
||||
"""工具执行后记录"""
|
||||
log_entry = {
|
||||
"event": "post_tool",
|
||||
"tool_name": context.tool_name,
|
||||
"result": str(result)[:500] if result else None,
|
||||
"duration_ms": (
|
||||
(context.end_time - context.start_time) * 1000
|
||||
if context.start_time and context.end_time
|
||||
else None
|
||||
),
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=result,
|
||||
)
|
||||
|
||||
async def tool_error(self, context: ExecutionContext, error: Exception) -> HookResult:
|
||||
"""工具出错时记录"""
|
||||
log_entry = {
|
||||
"event": "tool_error",
|
||||
"tool_name": context.tool_name,
|
||||
"error": str(error),
|
||||
"error_type": type(error).__name__,
|
||||
}
|
||||
self._logs.append(log_entry)
|
||||
self._write_log(log_entry)
|
||||
return HookResult(
|
||||
hook_name="audit_log",
|
||||
success=False,
|
||||
continue_execution=True,
|
||||
error=str(error),
|
||||
)
|
||||
|
||||
def _write_log(self, entry: dict[str, Any]) -> None:
|
||||
"""写入日志"""
|
||||
import json
|
||||
import datetime
|
||||
|
||||
entry["timestamp"] = datetime.datetime.now().isoformat()
|
||||
|
||||
if self.log_path:
|
||||
try:
|
||||
with open(self.log_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
except Exception:
|
||||
# 日志写入失败不影响主流程
|
||||
pass
|
||||
else:
|
||||
# 输出到 stdout
|
||||
print(f"[AUDIT] {json.dumps(entry, ensure_ascii=False)}")
|
||||
|
||||
def get_logs(self) -> list[dict[str, Any]]:
|
||||
"""获取所有日志"""
|
||||
return self._logs.copy()
|
||||
|
||||
def clear_logs(self) -> None:
|
||||
"""清空日志"""
|
||||
self._logs.clear()
|
||||
@@ -0,0 +1,142 @@
|
||||
"""危险操作确认 Hook - Phase 7.2
|
||||
|
||||
对危险操作要求用户确认。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
)
|
||||
from app.agents.tools.manifest import SideEffectScope
|
||||
|
||||
|
||||
# 危险操作关键词
|
||||
DANGEROUS_PATTERNS = [
|
||||
# 文件操作
|
||||
"delete",
|
||||
"remove",
|
||||
"rm ",
|
||||
"rmdir",
|
||||
"unlink",
|
||||
"format",
|
||||
"truncate",
|
||||
# 系统操作
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"kill",
|
||||
"pkill",
|
||||
"sudo",
|
||||
"chmod",
|
||||
"chown",
|
||||
# 数据操作
|
||||
"drop",
|
||||
"truncate",
|
||||
"delete from",
|
||||
"delete.*where",
|
||||
"insert into.*select",
|
||||
"update.*set",
|
||||
# 网络操作
|
||||
"curl",
|
||||
"wget",
|
||||
"nc ",
|
||||
"netcat",
|
||||
"ssh ",
|
||||
"scp ",
|
||||
"sftp ",
|
||||
# 环境变量
|
||||
"export.*secret",
|
||||
"export.*key",
|
||||
"export.*token",
|
||||
]
|
||||
|
||||
|
||||
class DangerousConfirmationHook:
|
||||
"""危险操作确认 Hook
|
||||
|
||||
检查工具调用是否包含危险操作,如是则要求确认。
|
||||
"""
|
||||
|
||||
def __init__(self, auto_block: bool = False):
|
||||
"""
|
||||
Args:
|
||||
auto_block: True 表示自动拦截危险操作,False 表示仅警告
|
||||
"""
|
||||
self.auto_block = auto_block
|
||||
self._pending_confirmations: dict[str, bool] = {}
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""检查是否为危险操作"""
|
||||
is_dangerous = self._check_dangerous(context.tool_name, context.tool_input)
|
||||
|
||||
if is_dangerous:
|
||||
if self.auto_block:
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=False,
|
||||
continue_execution=False,
|
||||
error=f"危险操作被自动拦截: {context.tool_name}",
|
||||
metadata={"dangerous": True, "auto_blocked": True},
|
||||
)
|
||||
else:
|
||||
# 标记需要确认
|
||||
context.metadata["requires_confirmation"] = True
|
||||
context.metadata["dangerous_operation"] = True
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
metadata={"dangerous": True, "requires_confirmation": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="dangerous_confirmation",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
def _check_dangerous(self, tool_name: str, tool_input: dict[str, Any]) -> bool:
|
||||
"""检查是否为危险操作"""
|
||||
# 检查工具名称
|
||||
dangerous_tools = [
|
||||
"delete",
|
||||
"remove",
|
||||
"drop",
|
||||
"truncate",
|
||||
"kill",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"bash",
|
||||
"powershell",
|
||||
"shell",
|
||||
]
|
||||
|
||||
if tool_name.lower() in dangerous_tools:
|
||||
return True
|
||||
|
||||
# 检查输入参数
|
||||
input_str = str(tool_input).lower()
|
||||
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if pattern.lower() in input_str:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def confirm(self, session_id: str, confirmed: bool) -> None:
|
||||
"""确认危险操作
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID
|
||||
confirmed: True 表示用户确认,False 表示取消
|
||||
"""
|
||||
self._pending_confirmations[session_id] = confirmed
|
||||
|
||||
def is_confirmed(self, session_id: str) -> bool:
|
||||
"""检查是否已确认"""
|
||||
return self._pending_confirmations.get(session_id, False)
|
||||
|
||||
def clear_confirmation(self, session_id: str) -> None:
|
||||
"""清除确认状态"""
|
||||
self._pending_confirmations.pop(session_id, None)
|
||||
183
backend/app/agents/tools/hooks/builtins/security_scan.py
Normal file
183
backend/app/agents/tools/hooks/builtins/security_scan.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""安全扫描 Hook - Phase 7.2
|
||||
|
||||
扫描工具调用和结果中的敏感信息。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
ExecutionContext,
|
||||
HookResult,
|
||||
)
|
||||
|
||||
|
||||
# 敏感信息模式
|
||||
SENSITIVE_PATTERNS = {
|
||||
"api_key": [
|
||||
r"api[_-]?key['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
r"apikey['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
],
|
||||
"password": [
|
||||
r"password['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
|
||||
r"passwd['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
|
||||
r"secret['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
|
||||
],
|
||||
"token": [
|
||||
r"token['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-\.]{20,}",
|
||||
r"bearer\s+[a-zA-Z0-9_\-\.]+",
|
||||
r"ghp_[a-zA-Z0-9]{36}",
|
||||
r"sk-[a-zA-Z0-9]{48}",
|
||||
],
|
||||
"private_key": [
|
||||
r"-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
|
||||
r"-----END (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
|
||||
],
|
||||
"ip_address": [
|
||||
r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
|
||||
],
|
||||
"email": [
|
||||
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class SecurityScanHook:
|
||||
"""安全扫描 Hook
|
||||
|
||||
扫描工具输入和输出中的敏感信息,进行脱敏处理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redact: bool = True,
|
||||
block_on_detect: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
redact: 是否对敏感信息进行脱敏
|
||||
block_on_detect: 检测到敏感信息时是否阻止执行
|
||||
"""
|
||||
self.redact = redact
|
||||
self.block_on_detect = block_on_detect
|
||||
self._compiled_patterns = {
|
||||
name: [re.compile(p, re.IGNORECASE) for p in patterns]
|
||||
for name, patterns in SENSITIVE_PATTERNS.items()
|
||||
}
|
||||
|
||||
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
|
||||
"""扫描输入参数"""
|
||||
detected = self._scan_dict(context.tool_input)
|
||||
|
||||
if detected:
|
||||
context.metadata["security_detected"] = detected
|
||||
|
||||
if self.block_on_detect:
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=False,
|
||||
continue_execution=False,
|
||||
error=f"检测到敏感信息: {', '.join(detected.keys())}",
|
||||
metadata={"detected": detected, "blocked": True},
|
||||
)
|
||||
|
||||
if self.redact:
|
||||
redacted_input = self._redact_dict(context.tool_input.copy())
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_input=redacted_input,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
)
|
||||
|
||||
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
|
||||
"""扫描输出结果"""
|
||||
if isinstance(result, dict):
|
||||
detected = self._scan_dict(result)
|
||||
|
||||
if detected:
|
||||
context.metadata["security_detected_output"] = detected
|
||||
|
||||
if self.redact:
|
||||
redacted_result = self._redact_dict(result.copy())
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=redacted_result,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
elif isinstance(result, str):
|
||||
detected = self._scan_string(result)
|
||||
if detected:
|
||||
context.metadata["security_detected_output"] = detected
|
||||
|
||||
if self.redact:
|
||||
redacted_result = self._redact_string(result)
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=redacted_result,
|
||||
metadata={"detected": detected, "redacted": True},
|
||||
)
|
||||
|
||||
return HookResult(
|
||||
hook_name="security_scan",
|
||||
success=True,
|
||||
continue_execution=True,
|
||||
modified_output=result,
|
||||
)
|
||||
|
||||
def _scan_dict(self, data: dict[str, Any]) -> dict[str, list[str]]:
|
||||
"""扫描字典中的敏感信息"""
|
||||
result: dict[str, list[str]] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
found = self._scan_string(value)
|
||||
if found:
|
||||
result[key] = found
|
||||
|
||||
return result
|
||||
|
||||
def _scan_string(self, text: str) -> list[str]:
|
||||
"""扫描字符串中的敏感信息"""
|
||||
found_types = []
|
||||
|
||||
for name, patterns in self._compiled_patterns.items():
|
||||
for pattern in patterns:
|
||||
if pattern.search(text):
|
||||
if name not in found_types:
|
||||
found_types.append(name)
|
||||
break
|
||||
|
||||
return found_types
|
||||
|
||||
def _redact_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""脱敏字典中的敏感信息"""
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
data[key] = self._redact_string(value)
|
||||
elif isinstance(value, dict):
|
||||
data[key] = self._redact_dict(value)
|
||||
elif isinstance(value, list):
|
||||
data[key] = [self._redact_string(v) if isinstance(v, str) else v for v in value]
|
||||
|
||||
return data
|
||||
|
||||
def _redact_string(self, text: str) -> str:
|
||||
"""脱敏字符串中的敏感信息"""
|
||||
for name, patterns in self._compiled_patterns.items():
|
||||
for pattern in patterns:
|
||||
text = pattern.sub(f"[REDACTED:{name}]", text)
|
||||
|
||||
return text
|
||||
105
backend/app/agents/tools/hooks/config.py
Normal file
105
backend/app/agents/tools/hooks/config.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Hook 配置持久化 - Phase 7.3"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.manager import get_hook_manager
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookConfigEntry:
|
||||
"""Hook 配置条目"""
|
||||
|
||||
name: str
|
||||
hook_type: str
|
||||
enabled: bool
|
||||
tool_names: list[str] | None = None
|
||||
categories: list[str] | None = None
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class HookConfigPersistence:
|
||||
"""Hook 配置持久化"""
|
||||
|
||||
def __init__(self, config_path: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
config_path: 配置文件路径,None 则使用默认路径
|
||||
"""
|
||||
if config_path is None:
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "..", "config", "hooks.json"
|
||||
)
|
||||
self.config_path = config_path
|
||||
|
||||
def load_config(self) -> list[HookConfigEntry]:
|
||||
"""从文件加载 Hook 配置"""
|
||||
if not os.path.exists(self.config_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return [HookConfigEntry(**entry) for entry in data]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def save_config(self, entries: list[HookConfigEntry]) -> bool:
|
||||
"""保存 Hook 配置到文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump([asdict(e) for e in entries], f, indent=2, ensure_ascii=False)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def apply_config(self) -> int:
|
||||
"""应用配置到 HookManager
|
||||
|
||||
Returns:
|
||||
应用的 Hook 数量
|
||||
"""
|
||||
from app.agents.tools.hooks.types import HookType
|
||||
|
||||
manager = get_hook_manager()
|
||||
entries = self.load_config()
|
||||
count = 0
|
||||
|
||||
for entry in entries:
|
||||
if entry.enabled:
|
||||
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
|
||||
|
||||
trigger = HookTrigger(
|
||||
tool_names=entry.tool_names,
|
||||
categories=entry.categories,
|
||||
)
|
||||
|
||||
# 创建空的 handler,只是注册配置
|
||||
hook_def = HookDefinition(
|
||||
name=entry.name,
|
||||
hook_type=HookType(entry.hook_type),
|
||||
trigger=trigger,
|
||||
handler=lambda ctx, *args: ctx,
|
||||
priority=entry.priority,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
manager.register(hook_def)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
# 全局单例
|
||||
_persistence: HookConfigPersistence | None = None
|
||||
|
||||
|
||||
def get_hook_config_persistence() -> HookConfigPersistence:
|
||||
"""获取全局 Hook 配置持久化实例"""
|
||||
global _persistence
|
||||
if _persistence is None:
|
||||
_persistence = HookConfigPersistence()
|
||||
return _persistence
|
||||
5
backend/app/agents/tools/hooks/custom/__init__.py
Normal file
5
backend/app/agents/tools/hooks/custom/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""自定义 Hook 加载器包"""
|
||||
|
||||
from app.agents.tools.hooks.custom.loader import CustomHookLoader, get_custom_hook_loader
|
||||
|
||||
__all__ = ["CustomHookLoader", "get_custom_hook_loader"]
|
||||
153
backend/app/agents/tools/hooks/custom/loader.py
Normal file
153
backend/app/agents/tools/hooks/custom/loader.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""自定义 Hook 加载器 - Phase 7.4
|
||||
|
||||
支持动态加载用户自定义的 Hook。
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import HookDefinition, HookType, HookTrigger, HookResult
|
||||
|
||||
|
||||
class CustomHookLoader:
|
||||
"""自定义 Hook 加载器
|
||||
|
||||
从指定目录动态加载自定义 Hook 模块。
|
||||
"""
|
||||
|
||||
def __init__(self, hooks_dir: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
hooks_dir: Hook 目录,None 则使用默认目录
|
||||
"""
|
||||
if hooks_dir is None:
|
||||
hooks_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "data", "custom_hooks"
|
||||
)
|
||||
self.hooks_dir = hooks_dir
|
||||
self._loaded_hooks: dict[str, HookDefinition] = {}
|
||||
|
||||
def load_all(self) -> list[HookDefinition]:
|
||||
"""加载所有自定义 Hook
|
||||
|
||||
Returns:
|
||||
Hook 定义列表
|
||||
"""
|
||||
hooks = []
|
||||
|
||||
if not os.path.exists(self.hooks_dir):
|
||||
return hooks
|
||||
|
||||
for filename in os.listdir(self.hooks_dir):
|
||||
if filename.endswith(".py") and not filename.startswith("_"):
|
||||
hook_path = os.path.join(self.hooks_dir, filename)
|
||||
hook_def = self._load_hook_from_file(hook_path, filename[:-3])
|
||||
if hook_def:
|
||||
hooks.append(hook_def)
|
||||
self._loaded_hooks[hook_def.name] = hook_def
|
||||
|
||||
return hooks
|
||||
|
||||
def _load_hook_from_file(self, hook_path: str, module_name: str) -> HookDefinition | None:
|
||||
"""从文件加载 Hook
|
||||
|
||||
Args:
|
||||
hook_path: Hook 文件路径
|
||||
module_name: 模块名
|
||||
|
||||
Returns:
|
||||
Hook 定义或 None
|
||||
"""
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, hook_path)
|
||||
if not spec or not spec.loader:
|
||||
return None
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 查找 HOOK_DEFINITION 或 hook_definition
|
||||
hook_def = getattr(module, "HOOK_DEFINITION", None) or getattr(
|
||||
module, "hook_definition", None
|
||||
)
|
||||
|
||||
if hook_def and isinstance(hook_def, HookDefinition):
|
||||
return hook_def
|
||||
|
||||
# 如果没有定义,尝试从函数自动推断
|
||||
if hasattr(module, "pre_tool_hook") or hasattr(module, "post_tool_hook"):
|
||||
return self._infer_hook_definition(module, module_name)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _infer_hook_definition(self, module: Any, module_name: str) -> HookDefinition | None:
|
||||
"""从模块函数推断 Hook 定义
|
||||
|
||||
Args:
|
||||
module: 模块对象
|
||||
module_name: 模块名
|
||||
|
||||
Returns:
|
||||
Hook 定义或 None
|
||||
"""
|
||||
hook_type = None
|
||||
handler = None
|
||||
|
||||
if hasattr(module, "pre_tool_hook"):
|
||||
handler = module.pre_tool_hook
|
||||
hook_type = HookType.PRE_TOOL_USE
|
||||
elif hasattr(module, "post_tool_hook"):
|
||||
handler = module.post_tool_hook
|
||||
hook_type = HookType.POST_TOOL_USE
|
||||
elif hasattr(module, "error_tool_hook"):
|
||||
handler = module.error_tool_hook
|
||||
hook_type = HookType.TOOL_ERROR
|
||||
|
||||
if not handler or not hook_type:
|
||||
return None
|
||||
|
||||
return HookDefinition(
|
||||
name=module_name,
|
||||
hook_type=hook_type,
|
||||
trigger=HookTrigger(),
|
||||
handler=handler,
|
||||
priority=0,
|
||||
enabled=True,
|
||||
description=f"Auto-loaded hook from {module_name}",
|
||||
)
|
||||
|
||||
def get_hook(self, name: str) -> HookDefinition | None:
|
||||
"""获取已加载的 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
Hook 定义或 None
|
||||
"""
|
||||
return self._loaded_hooks.get(name)
|
||||
|
||||
def reload(self) -> list[HookDefinition]:
|
||||
"""重新加载所有 Hook
|
||||
|
||||
Returns:
|
||||
重新加载的 Hook 列表
|
||||
"""
|
||||
self._loaded_hooks.clear()
|
||||
return self.load_all()
|
||||
|
||||
|
||||
# 全局加载器
|
||||
_loader: CustomHookLoader | None = None
|
||||
|
||||
|
||||
def get_custom_hook_loader() -> CustomHookLoader:
|
||||
"""获取全局自定义 Hook 加载器"""
|
||||
global _loader
|
||||
if _loader is None:
|
||||
_loader = CustomHookLoader()
|
||||
return _loader
|
||||
170
backend/app/agents/tools/hooks/executor.py
Normal file
170
backend/app/agents/tools/hooks/executor.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Hook 执行器 - Phase 6.2
|
||||
|
||||
执行 Hook 拦截逻辑。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.manager import get_hook_manager
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
)
|
||||
|
||||
|
||||
class HookExecutor:
|
||||
"""Hook 执行器
|
||||
|
||||
负责在工具执行前后执行 Hook 逻辑。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._manager = get_hook_manager()
|
||||
|
||||
async def execute_pre_hooks(
|
||||
self, context: ExecutionContext
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""执行 pre-tool Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
(是否继续执行, 修改后的输入)
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name)
|
||||
modified_input = context.tool_input
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
# 调用 hook handler
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context)
|
||||
if result and not result.continue_execution:
|
||||
# Hook 决定中断执行
|
||||
return False, modified_input
|
||||
if result.modified_input is not None:
|
||||
modified_input = result.modified_input
|
||||
except Exception as e:
|
||||
# Hook 出错,默认继续执行
|
||||
pass
|
||||
|
||||
return True, modified_input
|
||||
|
||||
async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any:
|
||||
"""执行 post-tool Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
result: 工具执行结果
|
||||
|
||||
Returns:
|
||||
修改后的结果
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name)
|
||||
modified_result = result
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
hook_result = await self._call_hook(handler, context, modified_result)
|
||||
if hook_result and hook_result.modified_output is not None:
|
||||
modified_result = hook_result.modified_output
|
||||
except Exception:
|
||||
# Hook 出错,默认保留原结果
|
||||
pass
|
||||
|
||||
return modified_result
|
||||
|
||||
async def execute_error_hooks(
|
||||
self, context: ExecutionContext, error: Exception
|
||||
) -> HookResult | None:
|
||||
"""执行 error Hook
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
error: 异常
|
||||
|
||||
Returns:
|
||||
Hook 结果,如果返回 None 则继续传播错误
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name)
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context, error)
|
||||
if result is not None and result.continue_execution:
|
||||
return result
|
||||
except Exception:
|
||||
# Hook 出错,继续执行其他 error hooks
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def execute_skip_check(self, context: ExecutionContext) -> bool:
|
||||
"""检查是否应跳过工具执行
|
||||
|
||||
Args:
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
True 表示跳过,False 表示执行
|
||||
"""
|
||||
hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name)
|
||||
|
||||
for hook in hooks:
|
||||
try:
|
||||
handler = hook.handler
|
||||
if callable(handler):
|
||||
result = await self._call_hook(handler, context)
|
||||
if result is not None and isinstance(result, bool):
|
||||
return result
|
||||
except Exception:
|
||||
# Hook 出错,默认不跳过
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
async def _call_hook(
|
||||
self, handler: Any, context: ExecutionContext, *args: Any
|
||||
) -> HookResult | None:
|
||||
"""调用 Hook 处理函数
|
||||
|
||||
Args:
|
||||
handler: Hook 处理函数
|
||||
context: 执行上下文
|
||||
*args: 额外参数
|
||||
|
||||
Returns:
|
||||
Hook 结果
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 如果是普通函数,直接调用
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
return await handler(context, *args)
|
||||
else:
|
||||
return handler(context, *args)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_executor: HookExecutor | None = None
|
||||
|
||||
|
||||
def get_hook_executor() -> HookExecutor:
|
||||
"""获取全局 Hook 执行器
|
||||
|
||||
Returns:
|
||||
全局 HookExecutor 实例
|
||||
"""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = HookExecutor()
|
||||
return _executor
|
||||
174
backend/app/agents/tools/hooks/manager.py
Normal file
174
backend/app/agents/tools/hooks/manager.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Hook 管理器 - Phase 6.2
|
||||
|
||||
管理 Hook 的注册、查找和配置。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.hooks.types import (
|
||||
HookDefinition,
|
||||
HookResult,
|
||||
HookTrigger,
|
||||
HookType,
|
||||
ExecutionContext,
|
||||
)
|
||||
|
||||
|
||||
class HookManager:
|
||||
"""Hook 管理器
|
||||
|
||||
管理全局 Hook 的注册和配置。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks: dict[HookType, list[HookDefinition]] = {
|
||||
HookType.PRE_TOOL_USE: [],
|
||||
HookType.POST_TOOL_USE: [],
|
||||
HookType.TOOL_ERROR: [],
|
||||
HookType.TOOL_SKIP: [],
|
||||
}
|
||||
self._global_hooks: list[HookDefinition] = [] # 全局 Hook(对所有工具生效)
|
||||
|
||||
def register(self, definition: HookDefinition) -> None:
|
||||
"""注册 Hook
|
||||
|
||||
Args:
|
||||
definition: Hook 定义
|
||||
"""
|
||||
if definition.trigger.tool_names is None and definition.trigger.categories is None:
|
||||
# 全局 Hook
|
||||
self._global_hooks.append(definition)
|
||||
else:
|
||||
# 特定工具 Hook
|
||||
self._hooks[definition.hook_type].append(definition)
|
||||
|
||||
# 按优先级排序
|
||||
self._hooks[definition.hook_type].sort(key=lambda h: h.priority, reverse=True)
|
||||
self._global_hooks.sort(key=lambda h: h.priority, reverse=True)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""注销 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功注销
|
||||
"""
|
||||
# 从特定工具 Hook 中移除
|
||||
for hooks in self._hooks.values():
|
||||
for i, hook in enumerate(hooks):
|
||||
if hook.name == name:
|
||||
hooks.pop(i)
|
||||
return True
|
||||
|
||||
# 从全局 Hook 中移除
|
||||
for i, hook in enumerate(self._global_hooks):
|
||||
if hook.name == name:
|
||||
self._global_hooks.pop(i)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_hooks(self, hook_type: HookType, tool_name: str | None = None) -> list[HookDefinition]:
|
||||
"""获取指定类型和工具的 Hook
|
||||
|
||||
Args:
|
||||
hook_type: Hook 类型
|
||||
tool_name: 工具名称(可选)
|
||||
|
||||
Returns:
|
||||
匹配的 Hook 列表
|
||||
"""
|
||||
result: list[HookDefinition] = []
|
||||
|
||||
# 添加全局 Hook
|
||||
for hook in self._global_hooks:
|
||||
if hook.hook_type == hook_type and hook.enabled:
|
||||
result.append(hook)
|
||||
|
||||
# 添加特定工具 Hook
|
||||
for hook in self._hooks[hook_type]:
|
||||
if not hook.enabled:
|
||||
continue
|
||||
|
||||
if hook.trigger.tool_names is None and hook.trigger.categories is None:
|
||||
continue
|
||||
|
||||
# 检查是否匹配
|
||||
if hook.trigger.tool_names and tool_name not in hook.trigger.tool_names:
|
||||
continue
|
||||
|
||||
result.append(hook)
|
||||
|
||||
return result
|
||||
|
||||
def list_all(self) -> list[HookDefinition]:
|
||||
"""列出所有已注册的 Hook
|
||||
|
||||
Returns:
|
||||
Hook 列表
|
||||
"""
|
||||
all_hooks = list(self._global_hooks)
|
||||
for hooks in self._hooks.values():
|
||||
all_hooks.extend(hooks)
|
||||
return all_hooks
|
||||
|
||||
def enable(self, name: str) -> bool:
|
||||
"""启用 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功启用
|
||||
"""
|
||||
for hook in self.list_all():
|
||||
if hook.name == name:
|
||||
hook.enabled = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable(self, name: str) -> bool:
|
||||
"""禁用 Hook
|
||||
|
||||
Args:
|
||||
name: Hook 名称
|
||||
|
||||
Returns:
|
||||
是否成功禁用
|
||||
"""
|
||||
for hook in self.list_all():
|
||||
if hook.name == name:
|
||||
hook.enabled = False
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有 Hook"""
|
||||
self._hooks = {ht: [] for ht in HookType}
|
||||
self._global_hooks = []
|
||||
|
||||
|
||||
# 全局单例
|
||||
_global_hook_manager: HookManager | None = None
|
||||
|
||||
|
||||
def get_hook_manager() -> HookManager:
|
||||
"""获取全局 Hook 管理器
|
||||
|
||||
Returns:
|
||||
全局 HookManager 实例
|
||||
"""
|
||||
global _global_hook_manager
|
||||
if _global_hook_manager is None:
|
||||
_global_hook_manager = HookManager()
|
||||
return _global_hook_manager
|
||||
|
||||
|
||||
def reset_hook_manager() -> None:
|
||||
"""重置全局 Hook 管理器(用于测试)"""
|
||||
global _global_hook_manager
|
||||
if _global_hook_manager is not None:
|
||||
_global_hook_manager.clear()
|
||||
_global_hook_manager = None
|
||||
90
backend/app/agents/tools/hooks/types.py
Normal file
90
backend/app/agents/tools/hooks/types.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Hook 类型定义 - Phase 6.2
|
||||
|
||||
Hook 拦截系统类型定义。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class HookType(Enum):
|
||||
"""Hook 类型"""
|
||||
|
||||
PRE_TOOL_USE = "pre_tool_use" # 工具执行前
|
||||
POST_TOOL_USE = "post_tool_use" # 工具执行后
|
||||
TOOL_ERROR = "tool_error" # 工具执行出错
|
||||
TOOL_SKIP = "tool_skip" # 工具跳过(条件执行)
|
||||
|
||||
|
||||
class HookStage(Enum):
|
||||
"""Hook 执行阶段"""
|
||||
|
||||
BEFORE = "before"
|
||||
AFTER = "after"
|
||||
ON_ERROR = "on_error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookTrigger:
|
||||
"""Hook 触发条件"""
|
||||
|
||||
tool_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||
categories: list[str] | None = None # 只对特定类别生效
|
||||
conditions: dict[str, Any] | None = None # 自定义条件
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookDefinition:
|
||||
"""Hook 定义"""
|
||||
|
||||
name: str
|
||||
hook_type: HookType
|
||||
trigger: HookTrigger
|
||||
handler: Callable[..., Any] # Hook 处理函数
|
||||
priority: int = 0 # 优先级,数字越大越先执行
|
||||
enabled: bool = True
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""Hook 执行结果"""
|
||||
|
||||
hook_name: str
|
||||
success: bool
|
||||
continue_execution: bool = True # False 表示中断执行
|
||||
modified_input: Any = None # 修改后的输入
|
||||
modified_output: Any = None # 修改后的输出
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""工具执行上下文"""
|
||||
|
||||
tool_name: str
|
||||
tool_input: dict[str, Any]
|
||||
user_id: str | None = None
|
||||
session_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 执行结果(由 HookExecutor 填充)
|
||||
result: Any = None
|
||||
error: Exception | None = None
|
||||
start_time: float | None = None
|
||||
end_time: float | None = None
|
||||
|
||||
|
||||
# Hook 处理函数类型
|
||||
HookHandler = Callable[[ExecutionContext, HookDefinition], HookResult]
|
||||
|
||||
# Pre-hook: 在工具执行前调用,可以修改输入或决定是否跳过
|
||||
PreToolHook = Callable[[ExecutionContext], tuple[bool, dict[str, Any] | None]]
|
||||
# post-hook: 在工具执行后调用,可以修改输出
|
||||
PostToolHook = Callable[[ExecutionContext, Any], Any]
|
||||
# Error hook: 在工具出错时调用
|
||||
ErrorToolHook = Callable[[ExecutionContext, Exception], HookResult | None]
|
||||
# Skip hook: 决定是否跳过工具执行
|
||||
SkipToolHook = Callable[[ExecutionContext], bool]
|
||||
58
backend/app/agents/tools/interactive_input.py
Normal file
58
backend/app/agents/tools/interactive_input.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
InteractiveInputHandler - 交互输入处理
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.terminal_engine import PTYManager
|
||||
|
||||
|
||||
class InteractiveInputHandler:
|
||||
"""交互输入处理器"""
|
||||
|
||||
def __init__(self, pty_manager: PTYManager):
|
||||
self.pty_manager = pty_manager
|
||||
self._pending_inputs: dict[str, asyncio.Event] = {}
|
||||
self._input_cache: dict[str, str] = {}
|
||||
|
||||
async def wait_for_input(self, session_id: str, prompt: str) -> str:
|
||||
"""等待用户输入(如 "y" 确认)"""
|
||||
event = asyncio.Event()
|
||||
self._pending_inputs[session_id] = event
|
||||
|
||||
# 发送提示
|
||||
from app.routers.terminal import manager
|
||||
|
||||
try:
|
||||
await manager.send(session_id, f"\n{prompt}\n")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 等待输入完成
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
del self._pending_inputs[session_id]
|
||||
return self._input_cache.get(session_id, "")
|
||||
|
||||
del self._pending_inputs[session_id]
|
||||
|
||||
return self._input_cache.get(session_id, "")
|
||||
|
||||
async def send_input(self, session_id: str, data: str):
|
||||
"""用户发送输入"""
|
||||
self._input_cache[session_id] = data
|
||||
if session_id in self._pending_inputs:
|
||||
self._pending_inputs[session_id].set()
|
||||
|
||||
# 同时写入 PTY
|
||||
await self.pty_manager.write(session_id, data + "\n")
|
||||
|
||||
def clear_input(self, session_id: str):
|
||||
"""清除输入缓存"""
|
||||
if session_id in self._input_cache:
|
||||
del self._input_cache[session_id]
|
||||
if session_id in self._pending_inputs:
|
||||
self._pending_inputs[session_id].set() # 取消等待
|
||||
del self._pending_inputs[session_id]
|
||||
77
backend/app/agents/tools/manifest.py
Normal file
77
backend/app/agents/tools/manifest.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""工具元数据和数据类型定义"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolCategory(Enum):
|
||||
"""工具类别"""
|
||||
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXTERNAL = "external"
|
||||
DB_WRITE = "db_write"
|
||||
NETWORK = "network"
|
||||
|
||||
|
||||
class SideEffectScope(Enum):
|
||||
"""副作用范围"""
|
||||
|
||||
NONE = "none"
|
||||
LOCAL_STATE = "local_state"
|
||||
DB_WRITE = "db_write"
|
||||
NETWORK = "network"
|
||||
|
||||
|
||||
class PermissionClass(Enum):
|
||||
"""权限级别"""
|
||||
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXTERNAL = "external"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolManifest:
|
||||
"""工具元数据"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
category: ToolCategory
|
||||
parameters: dict[str, Any] # JSON Schema
|
||||
return_schema: dict[str, Any]
|
||||
permission_class: PermissionClass
|
||||
side_effect_scope: SideEffectScope
|
||||
requires_confirmation: bool = False
|
||||
is_streaming: bool = False
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"category": self.category.value,
|
||||
"parameters": self.parameters,
|
||||
"return_schema": self.return_schema,
|
||||
"permission_class": self.permission_class.value,
|
||||
"side_effect_scope": self.side_effect_scope.value,
|
||||
"requires_confirmation": self.requires_confirmation,
|
||||
"is_streaming": self.is_streaming,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookConfig:
|
||||
"""Hook 配置"""
|
||||
|
||||
name: str
|
||||
hook_type: str # "pre_tool_use", "post_tool_use", "tool_error", "tool_skip"
|
||||
filter_names: list[str] | None = None # 只对特定工具生效,None 表示全部
|
||||
|
||||
def matches_tool(self, tool_name: str) -> bool:
|
||||
"""检查 Hook 是否对指定工具生效"""
|
||||
if self.filter_names is None:
|
||||
return True
|
||||
return tool_name in self.filter_names
|
||||
251
backend/app/agents/tools/migration.py
Normal file
251
backend/app/agents/tools/migration.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""工具迁移和向后兼容层 - Phase 6.1
|
||||
|
||||
将现有 @tool 装饰的工具迁移到 ToolRegistry,同时保持向后兼容。
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.agents.tools.manifest import (
|
||||
PermissionClass,
|
||||
SideEffectScope,
|
||||
ToolCategory,
|
||||
ToolManifest,
|
||||
)
|
||||
from app.agents.tools.registry import get_tool_registry
|
||||
|
||||
|
||||
# 现有工具的类别映射
|
||||
_TOOL_CATEGORY_MAP: dict[str, tuple[ToolCategory, PermissionClass, SideEffectScope]] = {
|
||||
# 知识检索 - 只读
|
||||
"search_knowledge": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"get_knowledge_graph_context": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"hybrid_search": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"web_search": (ToolCategory.NETWORK, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||
# 知识构建 - 写入
|
||||
"build_knowledge_graph": (
|
||||
ToolCategory.WRITE,
|
||||
PermissionClass.WRITE,
|
||||
SideEffectScope.LOCAL_STATE,
|
||||
),
|
||||
# 任务工具
|
||||
"get_tasks": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_task": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"update_task_status": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
# 日程工具
|
||||
"get_schedule_day": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_todo": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"create_schedule_task": (
|
||||
ToolCategory.WRITE,
|
||||
PermissionClass.WRITE,
|
||||
SideEffectScope.LOCAL_STATE,
|
||||
),
|
||||
"create_reminder": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"create_goal": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"resolve_time_expression": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
# 论坛工具
|
||||
"get_forum_posts": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
"create_forum_post": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
|
||||
"scan_forum_for_instructions": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
|
||||
}
|
||||
|
||||
|
||||
def get_tool_category(name: str) -> tuple[ToolCategory, PermissionClass, SideEffectScope]:
|
||||
"""获取工具的类别信息"""
|
||||
return _TOOL_CATEGORY_MAP.get(
|
||||
name,
|
||||
(ToolCategory.EXTERNAL, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
|
||||
)
|
||||
|
||||
|
||||
def infer_tags_from_docstring(docstring: str | None) -> list[str]:
|
||||
"""从 docstring 推断工具标签"""
|
||||
if not docstring:
|
||||
return []
|
||||
tags = []
|
||||
doc_lower = docstring.lower()
|
||||
if "搜索" in docstring or "查询" in docstring or "search" in doc_lower:
|
||||
tags.append("search")
|
||||
if "创建" in docstring or "新建" in docstring or "create" in doc_lower:
|
||||
tags.append("create")
|
||||
if "获取" in docstring or "读取" in docstring or "get" in doc_lower:
|
||||
tags.append("read")
|
||||
if "更新" in docstring or "修改" in docstring or "update" in doc_lower:
|
||||
tags.append("update")
|
||||
return tags
|
||||
|
||||
|
||||
def migrate_tool(tool_func: Callable) -> Callable:
|
||||
"""将现有 @tool 装饰的函数迁移到 ToolRegistry
|
||||
|
||||
Args:
|
||||
tool_func: LangChain @tool 装饰的函数
|
||||
|
||||
Returns:
|
||||
原函数(已注册到 registry)
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
|
||||
# 如果已经注册,跳过
|
||||
if registry.get(tool_func.name):
|
||||
return tool_func
|
||||
|
||||
# 获取类别信息
|
||||
category, permission, side_effect = get_tool_category(tool_func.name)
|
||||
|
||||
# 从 docstring 提取 description
|
||||
description = tool_func.description if hasattr(tool_func, "description") else ""
|
||||
|
||||
# 推断 tags
|
||||
tags = infer_tags_from_docstring(description)
|
||||
tags.append("migrated")
|
||||
|
||||
# 创建 manifest
|
||||
manifest = ToolManifest(
|
||||
name=tool_func.name,
|
||||
description=description,
|
||||
category=category,
|
||||
parameters={}, # LangChain @tool 动态处理参数
|
||||
return_schema={},
|
||||
permission_class=permission,
|
||||
side_effect_scope=side_effect,
|
||||
requires_confirmation=side_effect != SideEffectScope.NONE,
|
||||
is_streaming=False,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# 注册到 registry
|
||||
registry.register(manifest, tool_func)
|
||||
|
||||
return tool_func
|
||||
|
||||
|
||||
def migrate_all_tools() -> int:
|
||||
"""迁移所有现有工具到 ToolRegistry
|
||||
|
||||
Returns:
|
||||
迁移的工具数量
|
||||
"""
|
||||
from app.agents.tools import (
|
||||
ALL_TOOLS,
|
||||
KNOWLEDGE_GRAPH_TOOLS,
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
SCHEDULE_READ_TOOLS,
|
||||
SCHEDULE_WRITE_TOOLS,
|
||||
TASK_TOOLS,
|
||||
FORUM_TOOLS,
|
||||
)
|
||||
|
||||
all_tools = (
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||
+ KNOWLEDGE_GRAPH_TOOLS
|
||||
+ TASK_TOOLS
|
||||
+ SCHEDULE_READ_TOOLS
|
||||
+ SCHEDULE_WRITE_TOOLS
|
||||
+ FORUM_TOOLS
|
||||
)
|
||||
|
||||
count = 0
|
||||
for tool in all_tools:
|
||||
try:
|
||||
migrate_tool(tool)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
print(f"Failed to migrate tool {getattr(tool, 'name', 'unknown')}: {e}")
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class BackwardCompatTool:
|
||||
"""向后兼容工具包装器
|
||||
|
||||
确保现有代码通过 registry.get_executor() 仍能正常调用工具。
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self._registry = get_tool_registry()
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
executor = self._registry.get_executor(self.name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
def invoke(self, tool_input: dict[str, Any]) -> Any:
|
||||
"""LangChain 风格的 invoke 调用"""
|
||||
executor = self._registry.get_executor(self.name)
|
||||
if executor is None:
|
||||
raise ValueError(f"Tool not found in registry: {self.name}")
|
||||
|
||||
# 处理位置参数
|
||||
if isinstance(tool_input, dict):
|
||||
return executor(**tool_input)
|
||||
return executor(tool_input)
|
||||
|
||||
|
||||
def create_compat_layer() -> dict[str, BackwardCompatTool]:
|
||||
"""创建向后兼容层
|
||||
|
||||
返回一个字典,允许通过名称访问兼容的工具包装器。
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
tools = registry.list_all()
|
||||
|
||||
return {tool.name: BackwardCompatTool(tool.name) for tool in tools}
|
||||
|
||||
|
||||
# 自动迁移装饰器
|
||||
def auto_migrate(func: Callable) -> Callable:
|
||||
"""自动迁移装饰器
|
||||
|
||||
用于装饰新的 @tool 函数,自动注册到 registry。
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 迁移到 registry
|
||||
migrate_tool(wrapper)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# 便捷函数:获取兼容的工具执行器
|
||||
def get_tool_executor(name: str) -> Callable | None:
|
||||
"""获取工具执行器(兼容层)
|
||||
|
||||
优先从 registry 获取,fallback 到直接导入。
|
||||
"""
|
||||
registry = get_tool_registry()
|
||||
executor = registry.get_executor(name)
|
||||
|
||||
if executor is not None:
|
||||
return executor
|
||||
|
||||
# Fallback: 直接从模块导入(仅用于迁移期间)
|
||||
try:
|
||||
from app.agents.tools import (
|
||||
TASK_TOOLS,
|
||||
SCHEDULE_READ_TOOLS,
|
||||
SCHEDULE_WRITE_TOOLS,
|
||||
FORUM_TOOLS,
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
)
|
||||
|
||||
all_tools = (
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS
|
||||
+ TASK_TOOLS
|
||||
+ SCHEDULE_READ_TOOLS
|
||||
+ SCHEDULE_WRITE_TOOLS
|
||||
+ FORUM_TOOLS
|
||||
)
|
||||
|
||||
for tool in all_tools:
|
||||
if hasattr(tool, "name") and tool.name == name:
|
||||
return tool
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user