Compare commits

..

8 Commits

Author SHA1 Message Date
fca7a7cf3d Phase 7-10: CustomHookLoader, MCPSkillLoader, SkillTriggerDetector, TeamMember, WebSocketManager 2026-04-05 10:56:21 +08:00
d18167826e feat(agents): Phase 8.4-10.5 built-in plugins, bundled skills, coordinator 2026-04-04 23:24:34 +08:00
88955ed550 feat(agents): Phase 7-10 API endpoints for hooks, plugins, skills, sessions 2026-04-04 23:13:47 +08:00
a3fe4d24fc feat(agents): Phase 7-10 hook system, plugins, skills, orchestration
Phase 7: Built-in Hooks (audit_log, dangerous_confirmation, security_scan)
Phase 8: Plugin system (PluginManager, PluginSandbox, PluginManifest)
Phase 9: Skills registry (SkillRegistry, local/plugin/MCP loaders)
Phase 10: TeamLeader, RemoteTransport, BackgroundTaskManager
2026-04-04 22:56:27 +08:00
e5bd492d74 feat(agents): Phase 6 tool system refactoring
Phase 6.1: ToolRegistry infrastructure
- Add ToolManifest with ToolCategory, PermissionClass, SideEffectScope
- Add ToolRegistry singleton with register/get/unregister/list/search
- Add BaseTool abstract class with ReadTool/WriteTool/DBWriteTool/ExternalTool/NetworkTool subclasses
- Add migration layer for backward compatibility

Phase 6.2: Hook interception system
- Add HookType (PRE_TOOL_USE, POST_TOOL_USE, TOOL_ERROR, TOOL_SKIP)
- Add HookManager with singleton for hook registration
- Add HookExecutor for pre/post/error hook execution

Phase 6.3: Streaming execution
- Add StreamingToolExecutor with batch execution support

Phase 6.4: New builtin tools
- Add file_tools: GlobTool, GrepTool, ReadFileTool, WriteFileTool
- Add system_tools: BashTool, PowerShellTool
- Add dev_tools: LSPTools, GitTool
- Add collaboration_tools: TeamAgentTool, TaskBroadcastTool

Tests: 29 passed
2026-04-04 22:47:48 +08:00
a7b6b5eb90 feat: add agent visibility APIs and harden runtime verification
Add Day 4 visibility endpoints and response models, strengthen collaboration/task verification behavior, and patch conversation schema startup migration for agent_state compatibility. Extend backend regression coverage for runtime schemas, verifier behavior, visibility APIs, router auth, and legacy conversation list loading.
2026-04-04 00:56:03 +08:00
aa0ef0fbea feat: add Jarvis agent verification foundation
Add Day 1 agent runtime foundations with task and event schemas, verifier support, capability metadata, graph event tracing, and regression coverage while preserving the direct execution path.
2026-04-03 15:18:08 +08:00
4972b4e6b1 fix: harden L3 runtime continuity and tool execution
Align the L3 graph, agent service, and sync tool shims on one canonical continuity contract so clarification resumes and persisted snapshots behave consistently. Add targeted regressions and hardening notes covering system-message coalescing, async bridge usage, and continuity rehydration.
2026-04-03 13:14:59 +08:00
196 changed files with 35927 additions and 16161 deletions

View File

@@ -0,0 +1,220 @@
"""Background task executor - Phase 10.4"""
import asyncio
from collections.abc import Callable, Coroutine
from datetime import datetime
from typing import Any
from .manager import (
BackgroundTask,
BackgroundTaskManager,
BackgroundTaskStatus,
get_background_task_manager,
)
class BackgroundExecutor:
"""Executes background tasks with error handling and result storage.
Provides methods to execute tasks synchronously or asynchronously,
with full integration into BackgroundTaskManager for tracking.
"""
def __init__(self, task_manager: BackgroundTaskManager | None = None):
"""Initialize the executor.
Args:
task_manager: Optional BackgroundTaskManager instance.
If not provided, uses the global singleton.
"""
self._task_manager = task_manager or get_background_task_manager()
self._executors: dict[str, asyncio.Task] = {}
async def execute_task(
self,
task_id: str,
func: Callable[..., Coroutine[Any, Any, Any]],
*args: Any,
**kwargs: Any,
) -> BackgroundTask:
"""Execute a specific task by ID.
Args:
task_id: Unique task identifier
func: Async function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The BackgroundTask with result or error
"""
# Get or create task record
task = self._task_manager.get_task_status(task_id)
if task is None:
# Create a new task record if one doesn't exist
task = BackgroundTask(
id=task_id,
name=f"executor_task_{task_id}",
status=BackgroundTaskStatus.PENDING,
created_at=datetime.now(),
)
self._task_manager._tasks[task_id] = task
# Update status to running
task.status = BackgroundTaskStatus.RUNNING
task.started_at = datetime.now()
try:
# Execute the async function
result = await func(*args, **kwargs)
task.status = BackgroundTaskStatus.COMPLETED
task.result = result
except Exception as e:
task.status = BackgroundTaskStatus.FAILED
task.error = f"{type(e).__name__}: {str(e)}"
task.result = None
finally:
task.completed_at = datetime.now()
# Clean up executor reference
if task_id in self._executors:
del self._executors[task_id]
return task
async def execute_async(
self,
task_id: str,
func: Callable[..., Coroutine[Any, Any, Any]],
*args: Any,
**kwargs: Any,
) -> str:
"""Execute a task asynchronously in the background.
Args:
task_id: Unique task identifier
func: Async function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The task ID
"""
# Create task record if it doesn't exist
if self._task_manager.get_task_status(task_id) is None:
self._task_manager._tasks[task_id] = BackgroundTask(
id=task_id,
name=f"async_task_{task_id}",
status=BackgroundTaskStatus.PENDING,
created_at=datetime.now(),
)
# Create and store the asyncio task
async_task = asyncio.create_task(self.execute_task(task_id, func, *args, **kwargs))
self._executors[task_id] = async_task
return task_id
def cancel_task(self, task_id: str) -> bool:
"""Cancel a running task.
Args:
task_id: The task ID to cancel
Returns:
True if cancelled, False if not found or not running
"""
if task_id not in self._executors:
return False
self._executors[task_id].cancel()
del self._executors[task_id]
# Update task status
task = self._task_manager.get_task_status(task_id)
if task:
task.status = BackgroundTaskStatus.CANCELLED
task.completed_at = datetime.now()
return True
return False
def get_task_result(self, task_id: str) -> Any:
"""Get the result of a completed task.
Args:
task_id: The task ID
Returns:
The task result or None if not found/not completed
"""
task = self._task_manager.get_task_status(task_id)
if task and task.status == BackgroundTaskStatus.COMPLETED:
return task.result
return None
def get_task_error(self, task_id: str) -> str | None:
"""Get the error of a failed task.
Args:
task_id: The task ID
Returns:
The error message or None if not found/not failed
"""
task = self._task_manager.get_task_status(task_id)
if task and task.status == BackgroundTaskStatus.FAILED:
return task.error
return None
def is_task_running(self, task_id: str) -> bool:
"""Check if a task is currently running.
Args:
task_id: The task ID
Returns:
True if running, False otherwise
"""
return task_id in self._executors
def wait_for_task(self, task_id: str, timeout: float | None = None) -> BackgroundTask:
"""Wait for a task to complete.
Args:
task_id: The task ID to wait for
timeout: Optional timeout in seconds
Returns:
The completed BackgroundTask
Raises:
asyncio.TimeoutError: If task doesn't complete within timeout
asyncio.CancelledError: If task is cancelled
"""
if task_id not in self._executors:
task = self._task_manager.get_task_status(task_id)
if task:
return task
raise ValueError(f"Task {task_id} not found")
async def wait_task() -> BackgroundTask:
await self._executors[task_id]
return self._task_manager.get_task_status(task_id)
return asyncio.run_until_complete(asyncio.wait_for(wait_task(), timeout=timeout))
@property
def task_manager(self) -> BackgroundTaskManager:
"""Get the underlying task manager."""
return self._task_manager
# Global executor instance
_executor: BackgroundExecutor | None = None
def get_background_executor() -> BackgroundExecutor:
"""Get the global BackgroundExecutor instance."""
global _executor
if _executor is None:
_executor = BackgroundExecutor()
return _executor

View File

@@ -0,0 +1,119 @@
"""后台任务系统 - Phase 10.4"""
import asyncio
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from enum import Enum
class BackgroundTaskStatus(Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class BackgroundTask:
"""后台任务"""
id: str
name: str
status: BackgroundTaskStatus
created_at: datetime
started_at: datetime | None = None
completed_at: datetime | None = None
result: Any = None
error: str | None = None
class BackgroundTaskManager:
"""后台任务管理器"""
def __init__(self):
self._tasks: dict[str, BackgroundTask] = {}
self._.coroutines: dict[str, asyncio.Task] = {}
def submit_task(self, name: str, coro: Any, *args, **kwargs) -> str:
"""提交后台任务
Args:
name: 任务名称
coro: 协程函数
*args: 位置参数
**kwargs: 关键字参数
Returns:
任务 ID
"""
task_id = str(uuid.uuid4())[:8]
# 创建任务记录
self._tasks[task_id] = BackgroundTask(
id=task_id,
name=name,
status=BackgroundTaskStatus.PENDING,
created_at=datetime.now(),
)
# 创建 asyncio task
async def run_task():
self._tasks[task_id].status = BackgroundTaskStatus.RUNNING
self._tasks[task_id].started_at = datetime.now()
try:
result = await coro(*args, **kwargs)
self._tasks[task_id].status = BackgroundTaskStatus.COMPLETED
self._tasks[task_id].result = result
except Exception as e:
self._tasks[task_id].status = BackgroundTaskStatus.FAILED
self._tasks[task_id].error = str(e)
finally:
self._tasks[task_id].completed_at = datetime.now()
if task_id in self._coroutines:
del self._coroutines[task_id]
self._coroutines[task_id] = asyncio.create_task(run_task())
return task_id
def cancel_task(self, task_id: str) -> bool:
"""取消任务
Args:
task_id: 任务 ID
Returns:
是否成功取消
"""
if task_id not in self._tasks:
return False
if task_id in self._coroutines:
self._coroutines[task_id].cancel()
del self._coroutines[task_id]
self._tasks[task_id].status = BackgroundTaskStatus.CANCELLED
self._tasks[task_id].completed_at = datetime.now()
return True
def get_task_status(self, task_id: str) -> BackgroundTask | None:
"""获取任务状态"""
return self._tasks.get(task_id)
def list_tasks(self) -> list[BackgroundTask]:
"""列出所有任务"""
return list(self._tasks.values())
# 全局单例
_manager: BackgroundTaskManager | None = None
def get_background_task_manager() -> BackgroundTaskManager:
"""获取全局后台任务管理器"""
global _manager
if _manager is None:
_manager = BackgroundTaskManager()
return _manager

View File

@@ -0,0 +1,146 @@
"""Background task scheduler - Phase 10.4"""
from collections.abc import Callable, Coroutine
from typing import Any
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.base import BaseTrigger
from .manager import BackgroundTaskManager, get_background_task_manager
class BackgroundScheduler:
"""Background task scheduler using APScheduler.
Integrates with BackgroundTaskManager for task tracking and execution.
"""
def __init__(self, task_manager: BackgroundTaskManager | None = None):
"""Initialize the scheduler.
Args:
task_manager: Optional BackgroundTaskManager instance.
If not provided, uses the global singleton.
"""
self._scheduler = AsyncIOScheduler()
self._task_manager = task_manager or get_background_task_manager()
self._job_tasks: dict[str, str] = {} # Maps APScheduler job_id to task_id
def add_job(
self,
func: Callable[..., Coroutine[Any, Any, Any]],
trigger: BaseTrigger,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
id: str | None = None,
name: str | None = None,
**apscheduler_kwargs: Any,
) -> str:
"""Add a job to the scheduler.
Args:
func: Async function to execute
trigger: APScheduler trigger (date, interval, cron, etc.)
args: Positional arguments for the function
kwargs: Keyword arguments for the function
id: Unique job ID (auto-generated if not provided)
name: Job name for display purposes
**apscheduler_kwargs: Additional APScheduler options
Returns:
The job ID
"""
job_id = id or f"job_{len(self._job_tasks)}"
task_name = name or f"scheduled_task_{job_id}"
# Wrap the async function to integrate with BackgroundTaskManager
async def wrapped_func() -> None:
coro = func(*(args or ()), **(kwargs or {}))
task_id = self._task_manager.submit_task(task_name, coro)
self._job_tasks[job_id] = task_id
self._scheduler.add_job(
wrapped_func,
trigger=trigger,
id=job_id,
name=task_name,
**apscheduler_kwargs,
)
return job_id
def remove_job(self, job_id: str) -> bool:
"""Remove a job from the scheduler.
Args:
job_id: The ID of the job to remove
Returns:
True if job was removed, False if job didn't exist
"""
try:
self._scheduler.remove_job(job_id)
# Clean up task mapping if exists
if job_id in self._job_tasks:
task_id = self._job_tasks.pop(job_id)
# Cancel the background task if still running
self._task_manager.cancel_task(task_id)
return True
except Exception:
return False
def list_jobs(self) -> list[dict[str, Any]]:
"""List all scheduled jobs.
Returns:
List of job information dictionaries
"""
jobs = self._scheduler.get_jobs()
return [
{
"id": job.id,
"name": job.name,
"next_run_time": job.next_run_time,
"trigger": str(job.trigger),
}
for job in jobs
]
def start(self) -> None:
"""Start the scheduler."""
if not self._scheduler.running:
self._scheduler.start()
def shutdown(self, wait: bool = True) -> None:
"""Shutdown the scheduler.
Args:
wait: Whether to wait for running jobs to complete
"""
if self._scheduler.running:
self._scheduler.shutdown(wait=wait)
def pause(self) -> None:
"""Pause the scheduler."""
self._scheduler.pause()
def resume(self) -> None:
"""Resume the scheduler."""
self._scheduler.resume()
@property
def task_manager(self) -> BackgroundTaskManager:
"""Get the underlying task manager."""
return self._task_manager
# Global scheduler instance
_scheduler: BackgroundScheduler | None = None
def get_background_scheduler() -> BackgroundScheduler:
"""Get the global BackgroundScheduler instance."""
global _scheduler
if _scheduler is None:
_scheduler = BackgroundScheduler()
return _scheduler

View File

@@ -0,0 +1,508 @@
"""Agent 协调整器 - Phase 10.5
统一协调所有 Agent 组件TeamLeader, RemoteTransport, BackgroundTaskManager, SessionManager
"""
from typing import Any
from app.agents.background.manager import BackgroundTaskManager, get_background_task_manager
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
from app.agents.team.leader import TeamLeader
from app.agents.transport.remote import RemoteTransport
class AgentCoordinator:
"""Agent 协调整器
统一协调所有 Agent 组件,提供单一入口处理各类 Agent 操作。
"""
def __init__(
self,
background_manager: BackgroundTaskManager | None = None,
):
"""
Args:
background_manager: 后台任务管理器None 则使用全局单例
"""
self._team_leaders: dict[str, TeamLeader] = {}
self._remote_transport = RemoteTransport()
self._background_manager = background_manager or get_background_task_manager()
self._sessions: dict[str, AgentSession] = {}
# === Team 协作方法 ===
def create_team(self, team_id: str, members: list[str]) -> dict[str, Any]:
"""创建团队
Args:
team_id: 团队 ID
members: 成员 ID 列表
Returns:
团队创建结果
"""
if team_id in self._team_leaders:
return {"status": "error", "message": f"Team '{team_id}' already exists"}
leader = TeamLeader(team_id=team_id, members=members)
self._team_leaders[team_id] = leader
return {
"status": "created",
"team_id": team_id,
"members": members,
}
def get_team(self, team_id: str) -> TeamLeader | None:
"""获取团队
Args:
team_id: 团队 ID
Returns:
TeamLeader 或 None
"""
return self._team_leaders.get(team_id)
def assign_task(self, team_id: str, description: str, member: str) -> dict[str, Any]:
"""创建并分配任务
Args:
team_id: 团队 ID
description: 任务描述
member: 成员 ID
Returns:
分配结果
"""
leader = self._team_leaders.get(team_id)
if not leader:
return {"status": "error", "message": f"Team '{team_id}' not found"}
task_id = leader.create_task(description)
success = leader.assign_task(task_id, member)
return {
"status": "assigned" if success else "error",
"task_id": task_id,
"assignee": member,
}
def broadcast_task(self, team_id: str, description: str) -> dict[str, Any]:
"""广播任务给所有成员
Args:
team_id: 团队 ID
description: 任务描述
Returns:
广播结果
"""
leader = self._team_leaders.get(team_id)
if not leader:
return {"status": "error", "message": f"Team '{team_id}' not found"}
task_ids = leader.broadcast_task(description)
return {
"status": "broadcast",
"team_id": team_id,
"task_ids": task_ids,
"member_count": len(leader.members),
}
def collect_team_results(self, team_id: str) -> dict[str, Any]:
"""收集团队任务结果
Args:
team_id: 团队 ID
Returns:
收集结果
"""
leader = self._team_leaders.get(team_id)
if not leader:
return {"status": "error", "message": f"Team '{team_id}' not found"}
results = leader.collect_results()
status = leader.get_team_status()
return {
"status": "collected",
"team_id": team_id,
"results": results,
"completed": status["completed"],
"failed": status["failed"],
}
def get_team_status(self, team_id: str) -> dict[str, Any]:
"""获取团队状态
Args:
team_id: 团队 ID
Returns:
团队状态
"""
leader = self._team_leaders.get(team_id)
if not leader:
return {"status": "error", "message": f"Team '{team_id}' not found"}
return leader.get_team_status()
# === 后台任务方法 ===
def submit_background_task(
self,
name: str,
coro: Any,
*args,
**kwargs,
) -> dict[str, Any]:
"""提交后台任务
Args:
name: 任务名称
coro: 协程函数
*args: 位置参数
**kwargs: 关键字参数
Returns:
提交结果
"""
task_id = self._background_manager.submit_task(name, coro, *args, **kwargs)
return {
"status": "submitted",
"task_id": task_id,
"name": name,
}
def cancel_background_task(self, task_id: str) -> dict[str, Any]:
"""取消后台任务
Args:
task_id: 任务 ID
Returns:
取消结果
"""
success = self._background_manager.cancel_task(task_id)
return {
"status": "cancelled" if success else "error",
"task_id": task_id,
}
def get_background_task_status(self, task_id: str) -> dict[str, Any]:
"""获取后台任务状态
Args:
task_id: 任务 ID
Returns:
任务状态
"""
task = self._background_manager.get_task_status(task_id)
if not task:
return {"status": "error", "message": f"Task '{task_id}' not found"}
return {
"status": "found",
"task_id": task.id,
"name": task.name,
"task_status": task.status.value,
"result": task.result,
"error": task.error,
}
def list_background_tasks(self) -> dict[str, Any]:
"""列出所有后台任务
Returns:
任务列表
"""
tasks = self._background_manager.list_tasks()
return {
"status": "list",
"count": len(tasks),
"tasks": [
{
"id": t.id,
"name": t.name,
"status": t.status.value,
}
for t in tasks
],
}
# === 会话方法 ===
def create_session(
self,
user_id: str | None = None,
parent_session_id: str | None = None,
) -> dict[str, Any]:
"""创建会话
Args:
user_id: 用户 ID
parent_session_id: 父会话 ID
Returns:
创建结果
"""
session = create_agent_session(
user_id=user_id,
parent_session_id=parent_session_id,
)
self._sessions[session.session_id] = session
return {
"status": "created",
"session_id": session.session_id,
"user_id": user_id,
"parent_session_id": parent_session_id,
}
def get_session(self, session_id: str) -> AgentSession | None:
"""获取会话
Args:
session_id: 会话 ID
Returns:
AgentSession 或 None
"""
return self._sessions.get(session_id) or get_agent_session(session_id)
async def process_session_message(
self,
session_id: str,
message: str,
response: str,
) -> dict[str, Any]:
"""处理会话消息
Args:
session_id: 会话 ID
message: 用户消息
response: 助手响应
Returns:
处理结果
"""
session = self.get_session(session_id)
if not session:
return {"status": "error", "message": f"Session '{session_id}' not found"}
await session.process_message(message, response)
return {
"status": "processed",
"session_id": session_id,
"message_count": session.context.message_count,
}
async def spawn_child_session(
self,
session_id: str,
user_id: str | None = None,
) -> dict[str, Any]:
"""创建子会话
Args:
session_id: 父会话 ID
user_id: 用户 ID
Returns:
创建结果
"""
session = self.get_session(session_id)
if not session:
return {"status": "error", "message": f"Session '{session_id}' not found"}
child = await session.spawn_child_session(user_id=user_id)
self._sessions[child.session_id] = child
return {
"status": "spawned",
"parent_session_id": session_id,
"child_session_id": child.session_id,
"depth": child.context.depth,
}
def get_session_summary(self, session_id: str) -> dict[str, Any]:
"""获取会话摘要
Args:
session_id: 会话 ID
Returns:
会话摘要
"""
import asyncio
session = self.get_session(session_id)
if not session:
return {"status": "error", "message": f"Session '{session_id}' not found"}
# get_session_summary is async, so we need to run it
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# Create a future
future = asyncio.ensure_future(session.get_session_summary())
return {"status": "found", "summary": future}
else:
return {
"status": "found",
"summary": loop.run_until_complete(session.get_session_summary()),
}
except RuntimeError:
# No event loop, create one
return {"status": "found", "summary": asyncio.run(session.get_session_summary())}
# === 远程传输方法 ===
def register_remote_handler(self, event_type: str, handler: Any) -> None:
"""注册远程消息处理器
Args:
event_type: 事件类型
handler: 处理函数
"""
self._remote_transport.register_handler(event_type, handler)
async def send_remote_response(
self,
session_id: str,
response: dict[str, Any],
) -> bool:
"""发送远程响应
Args:
session_id: 会话 ID
response: 响应数据
Returns:
是否发送成功
"""
return await self._remote_transport.send_response(session_id, response)
async def send_remote_event(
self,
session_id: str,
event: dict[str, Any],
) -> bool:
"""发送远程事件
Args:
session_id: 会话 ID
event: 事件数据
Returns:
是否发送成功
"""
return await self._remote_transport.send_event(session_id, event)
async def send_remote_tool_call(
self,
session_id: str,
tool_call: dict[str, Any],
) -> bool:
"""发送远程工具调用
Args:
session_id: 会话 ID
tool_call: 工具调用数据
Returns:
是否发送成功
"""
return await self._remote_transport.send_tool_call(session_id, tool_call)
# === 统一协调入口 ===
async def coordinate(self, request: dict[str, Any]) -> dict[str, Any]:
"""统一协调入口
根据请求类型协调各类 Agent 操作。
Args:
request: 请求数据,包含:
- action: 操作类型 (team_create, team_assign, task_submit, session_create, etc.)
- 其他参数根据 action 不同而不同
Returns:
协调结果
"""
action = request.get("action")
if action == "team_create":
return self.create_team(
team_id=request["team_id"],
members=request["members"],
)
elif action == "team_assign":
return self.assign_task(
team_id=request["team_id"],
description=request["description"],
member=request["member"],
)
elif action == "team_broadcast":
return self.broadcast_task(
team_id=request["team_id"],
description=request["description"],
)
elif action == "team_collect":
return self.collect_team_results(team_id=request["team_id"])
elif action == "team_status":
return self.get_team_status(team_id=request["team_id"])
elif action == "task_submit":
return self.submit_background_task(
name=request["name"],
coro=request["coro"],
*request.get("args", []),
**request.get("kwargs", {}),
)
elif action == "task_cancel":
return self.cancel_background_task(task_id=request["task_id"])
elif action == "task_status":
return self.get_background_task_status(task_id=request["task_id"])
elif action == "session_create":
return self.create_session(
user_id=request.get("user_id"),
parent_session_id=request.get("parent_session_id"),
)
elif action == "session_message":
return await self.process_session_message(
session_id=request["session_id"],
message=request["message"],
response=request["response"],
)
elif action == "session_spawn":
return await self.spawn_child_session(
session_id=request["session_id"],
user_id=request.get("user_id"),
)
elif action == "session_summary":
return self.get_session_summary(session_id=request["session_id"])
else:
return {"status": "error", "message": f"Unknown action: {action}"}
# 全局单例
_coordinator: AgentCoordinator | None = None
def get_agent_coordinator() -> AgentCoordinator:
"""获取全局 Agent 协调整器"""
global _coordinator
if _coordinator is None:
_coordinator = AgentCoordinator()
return _coordinator

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,14 @@
from app.agents.isolation.session_isolation import prepare_session_isolation
from app.agents.isolation.strategy_selector import IsolationDecision, select_isolation_strategy
from app.agents.isolation.worktree_isolation import (
WorktreeIsolationError,
prepare_worktree_isolation,
)
__all__ = [
"IsolationDecision",
"WorktreeIsolationError",
"prepare_session_isolation",
"prepare_worktree_isolation",
"select_isolation_strategy",
]

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from typing import Any
from uuid import uuid4
from app.agents.isolation.strategy_selector import IsolationDecision
def prepare_session_isolation(
*,
state: dict[str, Any],
decision: IsolationDecision,
role_value: str,
sub_commander: str,
) -> dict[str, Any]:
isolation_id = f"session-{uuid4().hex[:8]}"
return {
"mode": "session",
"isolation_id": isolation_id,
"workspace_path": None,
"parent_conversation_id": str(state.get("conversation_id") or "") or None,
"metadata": {
**dict(decision.metadata or {}),
"reason": decision.reason,
"role": role_value,
"sub_commander": sub_commander,
"tool_names": list(decision.tool_names),
"capability_ids": list(decision.capability_ids),
"status": "active",
},
}

View File

@@ -0,0 +1,147 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
from app.agents.registry import load_builtin_registry_indexes
from app.agents.registry.models import CapabilityManifest, PermissionClass, SideEffectScope
IsolationMode = Literal["none", "session", "worktree"]
_WORKTREE_QUERY_MARKERS = (
"code",
"repo",
"repository",
"git",
"worktree",
"branch",
"patch",
"diff",
"refactor",
"build",
"test",
"fix",
"file",
"files",
"python",
"typescript",
"javascript",
"代码",
"仓库",
"分支",
"补丁",
"重构",
"构建",
"测试",
"修复",
"文件",
)
@dataclass(frozen=True)
class IsolationDecision:
mode: IsolationMode
reason: str
tool_names: tuple[str, ...] = ()
capability_ids: tuple[str, ...] = ()
metadata: dict[str, Any] = field(default_factory=dict)
def _capability_metadata(capability: CapabilityManifest | None) -> dict[str, Any]:
if capability is None:
return {}
return {
"capability_id": capability.capability_id,
"tool_name": capability.tool_name,
"permission_class": capability.permission_class.value,
"side_effect_scope": capability.side_effect_scope.value,
"supports_retry": capability.supports_retry,
"idempotent": capability.idempotent,
"safe_for_parallel_use": capability.safe_for_parallel_use,
"requires_confirmation": capability.requires_confirmation,
}
def select_isolation_strategy(
*,
user_query: str,
tool_names: list[str] | tuple[str, ...],
role_value: str,
execution_mode: str | None,
) -> IsolationDecision:
indexes = load_builtin_registry_indexes()
capabilities: list[CapabilityManifest] = []
capability_ids: list[str] = []
for tool_name in tool_names:
capability_id = indexes.capability_id_by_tool_name.get(tool_name)
capability = indexes.capability_by_id.get(capability_id) if capability_id else None
if capability is not None:
capabilities.append(capability)
capability_ids.append(capability.capability_id)
normalized_query = (user_query or "").strip().lower()
has_worktree_query_signal = any(marker in normalized_query for marker in _WORKTREE_QUERY_MARKERS)
has_write_capability = any(cap.permission_class == PermissionClass.WRITE for cap in capabilities)
has_external_capability = any(cap.permission_class == PermissionClass.EXTERNAL for cap in capabilities)
has_non_parallel_capability = any(not cap.safe_for_parallel_use for cap in capabilities)
has_stateful_side_effect = any(
cap.side_effect_scope in {SideEffectScope.LOCAL_STATE, SideEffectScope.DB_WRITE}
for cap in capabilities
)
metadata = {
"role": role_value,
"execution_mode": execution_mode,
"capabilities": [_capability_metadata(capability) for capability in capabilities],
"workspace_strategy": "inline",
"risk_level": "low",
}
if has_worktree_query_signal:
return IsolationDecision(
mode="worktree",
reason="workspace_mutation_signals_detected",
tool_names=tuple(tool_names),
capability_ids=tuple(capability_ids),
metadata={
**metadata,
"workspace_strategy": "ephemeral_worktree",
"risk_level": "high",
},
)
if has_write_capability or has_stateful_side_effect or has_non_parallel_capability:
return IsolationDecision(
mode="session",
reason="stateful_or_non_parallel_tooling",
tool_names=tuple(tool_names),
capability_ids=tuple(capability_ids),
metadata={
**metadata,
"workspace_strategy": "isolated_session",
"risk_level": "medium",
},
)
if execution_mode == "collaboration" or role_value in {"analyst", "librarian"} or has_external_capability:
return IsolationDecision(
mode="session",
reason="context_heavy_or_external_retrieval",
tool_names=tuple(tool_names),
capability_ids=tuple(capability_ids),
metadata={
**metadata,
"workspace_strategy": "isolated_session",
"risk_level": "medium",
},
)
return IsolationDecision(
mode="none",
reason="inline_execution_is_sufficient",
tool_names=tuple(tool_names),
capability_ids=tuple(capability_ids),
metadata=metadata,
)

View File

@@ -0,0 +1,83 @@
from __future__ import annotations
import re
import subprocess
from pathlib import Path
from typing import Any
from uuid import uuid4
from app.agents.isolation.strategy_selector import IsolationDecision
class WorktreeIsolationError(RuntimeError):
pass
def _slugify(value: str, *, fallback: str) -> str:
slug = re.sub(r"[^a-zA-Z0-9._-]+", "-", (value or "").strip()).strip("-").lower()
return slug or fallback
def _resolve_git_root() -> Path:
try:
result = subprocess.run(
["git", "rev-parse", "--show-toplevel"],
check=True,
capture_output=True,
text=True,
)
except subprocess.CalledProcessError as exc:
raise WorktreeIsolationError(exc.stderr.strip() or exc.stdout.strip() or "git_root_unavailable") from exc
git_root = Path(result.stdout.strip())
if not git_root.exists():
raise WorktreeIsolationError("git_root_not_found")
return git_root
def prepare_worktree_isolation(
*,
state: dict[str, Any],
decision: IsolationDecision,
role_value: str,
sub_commander: str,
create_workspace: bool = True,
) -> dict[str, Any]:
isolation_id = f"worktree-{uuid4().hex[:8]}"
conversation_slug = _slugify(str(state.get("conversation_id") or "conversation"), fallback="conversation")
role_slug = _slugify(role_value, fallback="agent")
git_root = _resolve_git_root()
workspace_root = git_root / ".worktrees" / "jarvis" / conversation_slug
workspace_path = workspace_root / f"{role_slug}-{isolation_id}"
branch = f"jarvis/{conversation_slug}/{role_slug}-{isolation_id}"
if create_workspace and not workspace_path.exists():
workspace_root.mkdir(parents=True, exist_ok=True)
try:
subprocess.run(
["git", "-C", str(git_root), "worktree", "add", "-b", branch, str(workspace_path), "HEAD"],
check=True,
capture_output=True,
text=True,
)
except subprocess.CalledProcessError as exc:
raise WorktreeIsolationError(exc.stderr.strip() or exc.stdout.strip() or "worktree_add_failed") from exc
return {
"mode": "worktree",
"isolation_id": isolation_id,
"workspace_path": str(workspace_path),
"parent_conversation_id": str(state.get("conversation_id") or "") or None,
"metadata": {
**dict(decision.metadata or {}),
"reason": decision.reason,
"role": role_value,
"sub_commander": sub_commander,
"tool_names": list(decision.tool_names),
"capability_ids": list(decision.capability_ids),
"repo_root": str(git_root),
"branch": branch,
"workspace_strategy": "ephemeral_worktree",
"cleanup_status": "pending",
"materialized": workspace_path.exists(),
},
}

View File

@@ -0,0 +1,20 @@
"""高级编排系统 - Phase 10"""
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
from app.agents.transport.remote import RemoteTransport, StructuredMessage
from app.agents.background.manager import (
BackgroundTaskManager,
BackgroundTask,
get_background_task_manager,
)
__all__ = [
"TeamLeader",
"TeamTask",
"TaskStatus",
"RemoteTransport",
"StructuredMessage",
"BackgroundTaskManager",
"BackgroundTask",
"get_background_task_manager",
]

View File

@@ -0,0 +1,12 @@
"""插件系统 - Phase 8"""
from app.agents.plugins.manager import PluginManager, get_plugin_manager
from app.agents.plugins.manifest import PluginManifest
from app.agents.plugins.sandbox import PluginSandbox
__all__ = [
"PluginManager",
"PluginManifest",
"PluginSandbox",
"get_plugin_manager",
]

View File

@@ -0,0 +1,19 @@
"""Code Helper Plugin - Linting, formatting, and code explanation tools"""
def lint_file(file_path: str) -> dict:
"""Lint a source file and return issues found."""
return {"status": "ok", "tool": "lint_file", "result": f"Linting {file_path}"}
def format_file(file_path: str) -> dict:
"""Format a source file and return the result."""
return {"status": "ok", "tool": "format_file", "result": f"Formatting {file_path}"}
def explain_code(code_snippet: str) -> dict:
"""Explain a code snippet and return the explanation."""
return {"status": "ok", "tool": "explain_code", "result": f"Explaining code snippet"}
tools = [lint_file, format_file, explain_code]

View File

@@ -0,0 +1,22 @@
{
"id": "code_helper",
"name": "Code Helper",
"version": "1.0.0",
"description": "Code linting, formatting, and explanation tools",
"author": "",
"homepage": "",
"license": "MIT",
"plugin_type": "tool",
"main": "__init__.py",
"hooks": [],
"tools": ["lint_file", "format_file", "explain_code"],
"skills": [],
"dependencies": {},
"peer_dependencies": {},
"permissions": [],
"allowed_paths": [],
"denied_paths": [],
"network_allowed": false,
"allowed_hosts": [],
"config_schema": {}
}

View File

@@ -0,0 +1,18 @@
"""File Organizer Plugin - File organization and duplicate detection tools"""
def organize_by_type(directory: str) -> dict:
"""Organize files in a directory by file type."""
return {"status": "ok", "tool": "organize_by_type", "result": f"Organizing {directory} by type"}
def find_duplicates(directory: str) -> dict:
"""Find duplicate files in a directory."""
return {
"status": "ok",
"tool": "find_duplicates",
"result": f"Finding duplicates in {directory}",
}
tools = [organize_by_type, find_duplicates]

View File

@@ -0,0 +1,22 @@
{
"id": "file_organizer",
"name": "File Organizer",
"version": "1.0.0",
"description": "File organization and duplicate detection tools",
"author": "",
"homepage": "",
"license": "MIT",
"plugin_type": "tool",
"main": "__init__.py",
"hooks": [],
"tools": ["organize_by_type", "find_duplicates"],
"skills": [],
"dependencies": {},
"peer_dependencies": {},
"permissions": [],
"allowed_paths": [],
"denied_paths": [],
"network_allowed": false,
"allowed_hosts": [],
"config_schema": {}
}

View File

@@ -0,0 +1,23 @@
"""Git Helper Plugin - Git status, log, and diff summary tools"""
def git_status_summary() -> dict:
"""Get a summary of git status."""
return {"status": "ok", "tool": "git_status_summary", "result": "Git status summary"}
def git_log_summary(limit: int = 10) -> dict:
"""Get a summary of recent git commits."""
return {"status": "ok", "tool": "git_log_summary", "result": f"Git log summary (limit={limit})"}
def git_diff_summary(ref1: str = "HEAD", ref2: str = "HEAD~1") -> dict:
"""Get a summary of changes between two refs."""
return {
"status": "ok",
"tool": "git_diff_summary",
"result": f"Git diff summary ({ref1}..{ref2})",
}
tools = [git_status_summary, git_log_summary, git_diff_summary]

View File

@@ -0,0 +1,22 @@
{
"id": "git_helper",
"name": "Git Helper",
"version": "1.0.0",
"description": "Git status, log, and diff summary tools",
"author": "",
"homepage": "",
"license": "MIT",
"plugin_type": "tool",
"main": "__init__.py",
"hooks": [],
"tools": ["git_status_summary", "git_log_summary", "git_diff_summary"],
"skills": [],
"dependencies": {},
"peer_dependencies": {},
"permissions": [],
"allowed_paths": [],
"denied_paths": [],
"network_allowed": false,
"allowed_hosts": [],
"config_schema": {}
}

View File

@@ -0,0 +1,14 @@
"""Web Helper Plugin - Web fetching and HTML parsing tools"""
def fetch_url_content(url: str) -> dict:
"""Fetch content from a URL."""
return {"status": "ok", "tool": "fetch_url_content", "result": f"Fetching {url}"}
def parse_html_links(html_content: str) -> dict:
"""Parse HTML content and extract links."""
return {"status": "ok", "tool": "parse_html_links", "result": "Extracted links from HTML"}
tools = [fetch_url_content, parse_html_links]

View File

@@ -0,0 +1,22 @@
{
"id": "web_helper",
"name": "Web Helper",
"version": "1.0.0",
"description": "Web fetching and HTML parsing tools",
"author": "",
"homepage": "",
"license": "MIT",
"plugin_type": "tool",
"main": "__init__.py",
"hooks": [],
"tools": ["fetch_url_content", "parse_html_links"],
"skills": [],
"dependencies": {},
"peer_dependencies": {},
"permissions": [],
"allowed_paths": [],
"denied_paths": [],
"network_allowed": true,
"allowed_hosts": [],
"config_schema": {}
}

View File

@@ -0,0 +1,207 @@
"""插件管理器 - Phase 8.2"""
import importlib.util
import os
import sys
from typing import Any
from app.agents.plugins.manifest import PluginManifest
from app.agents.plugins.sandbox import PluginSandbox
class PluginManager:
"""插件管理器
负责插件的安装、卸载、启用、禁用和生命周期管理。
"""
def __init__(self, plugins_dir: str | None = None):
"""
Args:
plugins_dir: 插件目录None 则使用默认目录
"""
if plugins_dir is None:
plugins_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "plugins")
self.plugins_dir = plugins_dir
self._plugins: dict[str, PluginManifest] = {}
self._enabled: dict[str, bool] = {}
self._modules: dict[str, Any] = {}
self._sandbox = PluginSandbox()
def install(self, plugin_path: str) -> bool:
"""安装插件
Args:
plugin_path: 插件目录路径或 manifest.json 所在目录
Returns:
是否安装成功
"""
try:
manifest_path = os.path.join(plugin_path, "manifest.json")
if not os.path.exists(manifest_path):
return False
with open(manifest_path, "r", encoding="utf-8") as f:
import json
data = json.load(f)
manifest = PluginManifest.from_dict(data)
# 验证 manifest
if not self._validate_manifest(manifest, plugin_path):
return False
# 复制插件到 plugins_dir
target_dir = os.path.join(self.plugins_dir, manifest.id)
os.makedirs(os.path.dirname(target_dir), exist_ok=True)
# 保存 manifest
with open(os.path.join(target_dir, "manifest.json"), "w", encoding="utf-8") as f:
json.dump(manifest.to_dict(), f, indent=2, ensure_ascii=False)
# 注册插件
self._plugins[manifest.id] = manifest
self._enabled[manifest.id] = True
return True
except Exception:
return False
def uninstall(self, plugin_id: str) -> bool:
"""卸载插件
Args:
plugin_id: 插件 ID
Returns:
是否卸载成功
"""
if plugin_id not in self._plugins:
return False
# 禁用插件
self.disable(plugin_id)
# 移除模块
if plugin_id in self._modules:
del self._modules[plugin_id]
# 移除插件
del self._plugins[plugin_id]
del self._enabled[plugin_id]
# 删除目录
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
if os.path.exists(plugin_dir):
import shutil
shutil.rmtree(plugin_dir)
return True
def enable(self, plugin_id: str) -> bool:
"""启用插件
Args:
plugin_id: 插件 ID
Returns:
是否启用成功
"""
if plugin_id not in self._plugins:
return False
self._enabled[plugin_id] = True
return True
def disable(self, plugin_id: str) -> bool:
"""禁用插件
Args:
plugin_id: 插件 ID
Returns:
是否禁用成功
"""
if plugin_id not in self._plugins:
return False
self._enabled[plugin_id] = False
return True
def reload(self, plugin_id: str) -> bool:
"""重新加载插件
Args:
plugin_id: 插件 ID
Returns:
是否重新加载成功
"""
if plugin_id not in self._plugins:
return False
# 卸载模块
if plugin_id in self._modules:
del self._modules[plugin_id]
# 重新加载
return self._load_plugin_module(plugin_id)
def list_plugins(self) -> list[PluginManifest]:
"""列出所有插件"""
return list(self._plugins.values())
def get_plugin(self, plugin_id: str) -> PluginManifest | None:
"""获取插件清单"""
return self._plugins.get(plugin_id)
def is_enabled(self, plugin_id: str) -> bool:
"""检查插件是否启用"""
return self._enabled.get(plugin_id, False)
def _validate_manifest(self, manifest: PluginManifest, plugin_path: str) -> bool:
"""验证 manifest"""
# 检查主入口文件是否存在
main_path = os.path.join(plugin_path, manifest.main)
if not os.path.exists(main_path):
return False
return True
def _load_plugin_module(self, plugin_id: str) -> bool:
"""加载插件模块"""
plugin_dir = os.path.join(self.plugins_dir, plugin_id)
manifest = self._plugins.get(plugin_id)
if not manifest:
return False
try:
main_path = os.path.join(plugin_dir, manifest.main)
spec = importlib.util.spec_from_file_location(plugin_id, main_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
sys.modules[plugin_id] = module
spec.loader.exec_module(module)
self._modules[plugin_id] = module
return True
except Exception:
pass
return False
# 全局单例
_manager: PluginManager | None = None
def get_plugin_manager() -> PluginManager:
"""获取全局插件管理器"""
global _manager
if _manager is None:
_manager = PluginManager()
return _manager

View File

@@ -0,0 +1,73 @@
"""插件清单定义 - Phase 8.1"""
from dataclasses import dataclass, field
from typing import Any
@dataclass
class PluginManifest:
"""插件清单
定义插件的元数据和接口。
"""
id: str # 唯一标识
name: str # 显示名称
version: str # 版本号
description: str # 描述
author: str = "" # 作者
homepage: str = "" # 主页
license: str = "MIT" # 许可证
# 插件类型
plugin_type: str = "tool" # tool, hook, skill, all
# 入口点
main: str = "index.py" # 主入口文件
hooks: list[str] = field(default_factory=list) # 提供的 Hook 列表
tools: list[str] = field(default_factory=list) # 提供的工具列表
skills: list[str] = field(default_factory=list) # 提供的 Skills 列表
# 依赖
dependencies: dict[str, str] = field(default_factory=dict) # pip 依赖
peer_dependencies: dict[str, str] = field(default_factory=dict) # 对等依赖
# 权限要求
permissions: list[str] = field(default_factory=list) # 需要的权限
allowed_paths: list[str] = field(default_factory=list) # 允许访问的路径
denied_paths: list[str] = field(default_factory=list) # 禁止访问的路径
# 网络权限
network_allowed: bool = False # 是否允许网络访问
allowed_hosts: list[str] = field(default_factory=list) # 允许访问的 host
# 配置
config_schema: dict[str, Any] = field(default_factory=dict) # 配置 schema
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"version": self.version,
"description": self.description,
"author": self.author,
"homepage": self.homepage,
"license": self.license,
"plugin_type": self.plugin_type,
"main": self.main,
"hooks": self.hooks,
"tools": self.tools,
"skills": self.skills,
"dependencies": self.dependencies,
"peer_dependencies": self.peer_dependencies,
"permissions": self.permissions,
"allowed_paths": self.allowed_paths,
"denied_paths": self.denied_paths,
"network_allowed": self.network_allowed,
"allowed_hosts": self.allowed_hosts,
"config_schema": self.config_schema,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "PluginManifest":
return cls(**data)

View File

@@ -0,0 +1,111 @@
"""插件沙箱隔离 - Phase 8.3"""
import os
import sys
from typing import Any
class PluginSandbox:
"""插件沙箱
提供插件执行隔离环境。
"""
def __init__(self):
self._allowed_paths: set[str] = set()
self._denied_paths: set[str] = set()
self._network_allowed: bool = False
self._allowed_hosts: set[str] = set()
def set_file_permissions(
self,
allowed_paths: list[str] | None = None,
denied_paths: list[str] | None = None,
) -> None:
"""设置文件访问权限
Args:
allowed_paths: 允许访问的路径列表
denied_paths: 禁止访问的路径列表
"""
self._allowed_paths = set(allowed_paths or [])
self._denied_paths = set(denied_paths or [])
def set_network_permissions(
self, allowed: bool, allowed_hosts: list[str] | None = None
) -> None:
"""设置网络访问权限
Args:
allowed: 是否允许网络访问
allowed_hosts: 允许访问的 host 列表
"""
self._network_allowed = allowed
self._allowed_hosts = set(allowed_hosts or [])
def check_file_access(self, path: str) -> bool:
"""检查文件访问权限
Args:
path: 文件路径
Returns:
是否允许访问
"""
# 如果有允许列表,只允许访问列表中的路径
if self._allowed_paths:
return path in self._allowed_paths or any(
path.startswith(allowed) for allowed in self._allowed_paths
)
# 如果有禁止列表,禁止访问列表中的路径
if self._denied_paths:
return not any(path.startswith(denied) for denied in self._denied_paths)
# 没有限制
return True
def check_network_access(self, host: str) -> bool:
"""检查网络访问权限
Args:
host: 主机地址
Returns:
是否允许访问
"""
if not self._network_allowed:
return False
if self._allowed_hosts:
return host in self._allowed_hosts or any(
host.endswith(allowed) for allowed in self._allowed_hosts
)
return True
def execute_in_sandbox(self, func: Any, *args, **kwargs) -> Any:
"""在沙箱中执行函数
Args:
func: 要执行的函数
*args: 位置参数
**kwargs: 关键字参数
Returns:
函数返回值
"""
# 保存当前状态
old_allowed_paths = self._allowed_paths.copy()
old_denied_paths = self._denied_paths.copy()
old_network_allowed = self._network_allowed
old_allowed_hosts = self._allowed_hosts.copy()
try:
return func(*args, **kwargs)
finally:
# 恢复状态
self._allowed_paths = old_allowed_paths
self._denied_paths = old_denied_paths
self._network_allowed = old_network_allowed
self._allowed_hosts = old_allowed_hosts

View File

@@ -324,6 +324,38 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
"""
COORDINATOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
你是 Jarvis 的协作协调官,负责把复杂请求收束成最小受控协作,而不是放任系统进入自由 swarm。
## 你的职责:
- 先判断当前请求是否真的需要拆解;不需要时应明确建议继续走 direct
- 只有在明显多步骤、跨领域、需要多角色配合时,才拆成 2~4 个子任务
- 每个子任务必须清晰写出 `title`、`role`、`goal`、`expected_evidence`
- 角色建议只能来自现有 top-level agent`schedule_planner`、`librarian`、`analyst`、`executor`
- 汇总时基于子任务结果回收,不依赖单点硬编码拼接
## 边界:
- 禁止无限递归拆分
- 禁止创建新的 runtime agent / worker
- 禁止把一个简单请求硬拆成多个空泛步骤
- 如果证据不足、子任务未闭环,必须把风险明确暴露出来
"""
VERIFIER_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
你是 Jarvis 的验证官,负责对执行结果做最小但明确的核验。
## 你的职责:
- 只输出 passed、failed、skipped 三种验证结论之一
- 用一句话总结验证判断
- 如有证据,保留关键证据点
- 当信息不足以证明成功或失败时,优先判定为 skipped
- 不重写执行方案,不扩展无关建议
"""
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
你的输出必须满足以下规则:

View File

@@ -1,11 +1,19 @@
"""Registry manifest models and validation helpers."""
from functools import lru_cache
from app.agents.registry.indexes import RegistryIndexes, build_registry_indexes
from app.agents.registry.loader import RegistryBundle, load_builtin_registry_bundle
@lru_cache(maxsize=1)
def load_builtin_registry_indexes() -> RegistryIndexes:
return build_registry_indexes(load_builtin_registry_bundle())
__all__ = [
"RegistryBundle",
"RegistryIndexes",
"build_registry_indexes",
"load_builtin_registry_bundle",
"load_builtin_registry_indexes",
]

View File

@@ -2,6 +2,8 @@ from app.agents.prompts import SUB_COMMANDER_PROMPTS_BY_KEY
from app.agents.registry.models import (
AgentManifest,
CapabilityManifest,
PermissionClass,
SideEffectScope,
SpecialistTemplateManifest,
SubCommanderManifest,
)
@@ -55,6 +57,19 @@ TOP_LEVEL_AGENT_ROUTING_HINTS: dict[str, tuple[str, ...]] = {
),
}
TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES: dict[str, tuple[str, ...]] = {
AgentRole.MASTER.value: (
AgentRole.SCHEDULE_PLANNER.value,
AgentRole.EXECUTOR.value,
AgentRole.LIBRARIAN.value,
AgentRole.ANALYST.value,
),
AgentRole.SCHEDULE_PLANNER.value: (AgentRole.SCHEDULE_PLANNER.value,),
AgentRole.EXECUTOR.value: (AgentRole.EXECUTOR.value,),
AgentRole.LIBRARIAN.value: (AgentRole.LIBRARIAN.value,),
AgentRole.ANALYST.value: (AgentRole.ANALYST.value,),
}
SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = {
"schedule_analysis": AgentRole.SCHEDULE_PLANNER.value,
"schedule_planning": AgentRole.SCHEDULE_PLANNER.value,
@@ -75,6 +90,8 @@ BUILTIN_AGENT_MANIFESTS: tuple[AgentManifest, ...] = tuple(
system_prompt_key=role.value,
routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]),
default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]),
can_spawn_children=bool(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
allowed_spawn_role_values=list(TOP_LEVEL_AGENT_ALLOWED_SPAWN_ROLES[role.value]),
skill_context_key=role.value.replace("agent_", ""),
)
for role in AgentRole
@@ -89,10 +106,150 @@ _capability_tool_names = tuple(
)
)
_CAPABILITY_METADATA_BY_TOOL_NAME: dict[str, dict[str, object]] = {
"get_tasks": {
"permission_class": PermissionClass.READ,
"side_effect_scope": SideEffectScope.NONE,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"get_schedule_day": {
"permission_class": PermissionClass.READ,
"side_effect_scope": SideEffectScope.NONE,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"resolve_time_expression": {
"permission_class": PermissionClass.READ,
"side_effect_scope": SideEffectScope.NONE,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"search_knowledge": {
"permission_class": PermissionClass.READ,
"side_effect_scope": SideEffectScope.NONE,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"hybrid_search": {
"permission_class": PermissionClass.READ,
"side_effect_scope": SideEffectScope.NONE,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"get_knowledge_graph_context": {
"permission_class": PermissionClass.READ,
"side_effect_scope": SideEffectScope.NONE,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"get_forum_posts": {
"permission_class": PermissionClass.READ,
"side_effect_scope": SideEffectScope.NONE,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"scan_forum_for_instructions": {
"permission_class": PermissionClass.READ,
"side_effect_scope": SideEffectScope.NONE,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"web_search": {
"permission_class": PermissionClass.EXTERNAL,
"side_effect_scope": SideEffectScope.NETWORK,
"supports_retry": True,
"idempotent": True,
"safe_for_parallel_use": True,
"requires_confirmation": False,
},
"create_task": {
"permission_class": PermissionClass.WRITE,
"side_effect_scope": SideEffectScope.LOCAL_STATE,
"supports_retry": False,
"idempotent": False,
"safe_for_parallel_use": False,
"requires_confirmation": True,
},
"update_task_status": {
"permission_class": PermissionClass.WRITE,
"side_effect_scope": SideEffectScope.LOCAL_STATE,
"supports_retry": False,
"idempotent": False,
"safe_for_parallel_use": False,
"requires_confirmation": True,
},
"create_todo": {
"permission_class": PermissionClass.WRITE,
"side_effect_scope": SideEffectScope.LOCAL_STATE,
"supports_retry": False,
"idempotent": False,
"safe_for_parallel_use": False,
"requires_confirmation": True,
},
"create_schedule_task": {
"permission_class": PermissionClass.WRITE,
"side_effect_scope": SideEffectScope.LOCAL_STATE,
"supports_retry": False,
"idempotent": False,
"safe_for_parallel_use": False,
"requires_confirmation": True,
},
"create_reminder": {
"permission_class": PermissionClass.WRITE,
"side_effect_scope": SideEffectScope.LOCAL_STATE,
"supports_retry": False,
"idempotent": False,
"safe_for_parallel_use": False,
"requires_confirmation": True,
},
"create_goal": {
"permission_class": PermissionClass.WRITE,
"side_effect_scope": SideEffectScope.LOCAL_STATE,
"supports_retry": False,
"idempotent": False,
"safe_for_parallel_use": False,
"requires_confirmation": True,
},
"create_forum_post": {
"permission_class": PermissionClass.WRITE,
"side_effect_scope": SideEffectScope.LOCAL_STATE,
"supports_retry": False,
"idempotent": False,
"safe_for_parallel_use": False,
"requires_confirmation": True,
},
"build_knowledge_graph": {
"permission_class": PermissionClass.WRITE,
"side_effect_scope": SideEffectScope.LOCAL_STATE,
"supports_retry": False,
"idempotent": False,
"safe_for_parallel_use": False,
"requires_confirmation": True,
},
}
BUILTIN_CAPABILITY_MANIFESTS: tuple[CapabilityManifest, ...] = tuple(
CapabilityManifest(
capability_id=tool_name,
tool_name=tool_name,
**dict(_CAPABILITY_METADATA_BY_TOOL_NAME.get(tool_name, {})),
)
for tool_name in _capability_tool_names
)

View File

@@ -16,6 +16,7 @@ from app.agents.registry.models import (
@dataclass(frozen=True)
class RegistryIndexes:
agent_by_id: Mapping[str, AgentManifest]
agent_by_role_value: Mapping[str, AgentManifest]
sub_commander_by_id: Mapping[str, SubCommanderManifest]
capability_by_id: Mapping[str, CapabilityManifest]
specialist_template_by_id: Mapping[str, SpecialistTemplateManifest]
@@ -24,6 +25,7 @@ class RegistryIndexes:
skill_context_key_by_agent_id: Mapping[str, str]
capability_id_by_tool_name: Mapping[str, str]
capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]]
spawnable_role_values_by_agent_id: Mapping[str, tuple[str, ...]]
def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]:
@@ -50,6 +52,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
return RegistryIndexes(
agent_by_id=MappingProxyType(agent_by_id),
agent_by_role_value=MappingProxyType({
agent.role_value: agent for agent in bundle.agents
}),
sub_commander_by_id=MappingProxyType(sub_commander_by_id),
capability_by_id=MappingProxyType(capability_by_id),
specialist_template_by_id=MappingProxyType(specialist_template_by_id),
@@ -73,4 +78,9 @@ def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
for sub_commander in bundle.sub_commanders
}),
spawnable_role_values_by_agent_id=MappingProxyType({
agent.agent_id: tuple(agent.allowed_spawn_role_values)
for agent in bundle.agents
if agent.can_spawn_children and agent.allowed_spawn_role_values
}),
)

View File

@@ -1,4 +1,19 @@
from pydantic import BaseModel
from enum import Enum
from pydantic import BaseModel, Field
class PermissionClass(str, Enum):
READ = "read"
WRITE = "write"
EXTERNAL = "external"
class SideEffectScope(str, Enum):
NONE = "none"
LOCAL_STATE = "local_state"
DB_WRITE = "db_write"
NETWORK = "network"
class AgentManifest(BaseModel):
@@ -8,6 +23,8 @@ class AgentManifest(BaseModel):
system_prompt_key: str
routing_hints: list[str]
default_sub_commanders: list[str]
can_spawn_children: bool = False
allowed_spawn_role_values: list[str] = Field(default_factory=list)
skill_context_key: str | None = None
continuity_policy: str | None = None
clarification_policy: str | None = None
@@ -23,6 +40,12 @@ class SubCommanderManifest(BaseModel):
class CapabilityManifest(BaseModel):
capability_id: str
tool_name: str
permission_class: PermissionClass = PermissionClass.READ
side_effect_scope: SideEffectScope = SideEffectScope.NONE
supports_retry: bool = False
idempotent: bool = False
safe_for_parallel_use: bool = False
requires_confirmation: bool = False
class SpecialistTemplateManifest(BaseModel):

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
from typing import Any
INPUT_TOKEN_USD_RATE = 0.000003
OUTPUT_TOKEN_USD_RATE = 0.000015
DEFAULT_COST_THRESHOLDS = {
"total_tokens": 4000,
"estimated_cost": 0.02,
}
def estimate_token_cost(input_tokens: int, output_tokens: int) -> float | None:
total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
if total_tokens <= 0:
return None
return round(
(max(input_tokens, 0) * INPUT_TOKEN_USD_RATE)
+ (max(output_tokens, 0) * OUTPUT_TOKEN_USD_RATE),
6,
)
def extract_token_usage(response: Any) -> tuple[int, int]:
usage_metadata = getattr(response, "usage_metadata", None) or {}
if isinstance(usage_metadata, dict):
input_tokens = int(
usage_metadata.get("input_tokens")
or usage_metadata.get("prompt_tokens")
or 0
)
output_tokens = int(
usage_metadata.get("output_tokens")
or usage_metadata.get("completion_tokens")
or 0
)
if input_tokens or output_tokens:
return input_tokens, output_tokens
response_metadata = getattr(response, "response_metadata", None) or {}
token_usage = {}
if isinstance(response_metadata, dict):
token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") or {}
if isinstance(token_usage, dict):
input_tokens = int(
token_usage.get("prompt_tokens")
or token_usage.get("input_tokens")
or 0
)
output_tokens = int(
token_usage.get("completion_tokens")
or token_usage.get("output_tokens")
or 0
)
if input_tokens or output_tokens:
return input_tokens, output_tokens
return 0, 0
def coerce_cost_thresholds(raw_thresholds: Any) -> dict[str, float]:
thresholds: dict[str, float] = dict(DEFAULT_COST_THRESHOLDS)
if not isinstance(raw_thresholds, dict):
return thresholds
for key in DEFAULT_COST_THRESHOLDS:
value = raw_thresholds.get(key)
if isinstance(value, (int, float)) and value > 0:
thresholds[key] = float(value)
return thresholds
def is_cost_budget_warning(
input_tokens: int,
output_tokens: int,
estimated_cost: float | None,
thresholds: dict[str, float] | None = None,
) -> bool:
effective_thresholds = thresholds or DEFAULT_COST_THRESHOLDS
total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
token_threshold = float(effective_thresholds.get("total_tokens") or 0)
cost_threshold = float(effective_thresholds.get("estimated_cost") or 0)
return (
(token_threshold > 0 and total_tokens >= token_threshold)
or (cost_threshold > 0 and estimated_cost is not None and estimated_cost >= cost_threshold)
)

View File

@@ -0,0 +1,25 @@
from app.agents.schemas.event import AgentEvent
from app.agents.schemas.message import AgentMessage
from app.agents.schemas.task import (
AgentTask,
CollaborationBudget,
InterruptRecord,
RecoveryRecord,
TaskLifecycleStatus,
TaskResult,
TaskResultStatus,
VerificationStatus,
)
__all__ = [
"AgentEvent",
"AgentMessage",
"AgentTask",
"CollaborationBudget",
"InterruptRecord",
"RecoveryRecord",
"TaskLifecycleStatus",
"TaskResult",
"TaskResultStatus",
"VerificationStatus",
]

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Literal
from pydantic import BaseModel, Field
AgentEventType = Literal[
"agent.tool.start",
"agent.tool.result",
"agent.verify.started",
"agent.verify.completed",
"agent.created",
"agent.spawn.blocked",
"agent.message.sent",
"agent.message.received",
"agent.interrupt.requested",
"agent.interrupt.completed",
"agent.recovery.started",
"agent.recovery.completed",
"agent.task.interrupted",
"agent.task.recovered",
"agent.task.reassigned",
"agent.collaboration.budget.updated",
"agent.isolation.selected",
"agent.isolation.fallback",
"agent.cost.updated",
"agent.cost.warning",
"agent.phase.changed",
"agent.checkpoint.recorded",
"agent.error",
]
AgentEventSeverity = Literal["info", "warning", "error"]
class AgentEvent(BaseModel):
event_id: str
event_type: AgentEventType
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
conversation_id: str | None = None
agent_id: str | None = None
sub_commander_id: str | None = None
task_id: str | None = None
parent_task_id: str | None = None
child_task_id: str | None = None
thread_id: str | None = None
message_id: str | None = None
interrupt_id: str | None = None
recovery_id: str | None = None
payload: dict[str, Any] = Field(default_factory=dict)
severity: AgentEventSeverity = "info"

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Literal
from pydantic import BaseModel, Field
AgentMessageType = Literal[
"task_request",
"task_update",
"handoff",
"verification_request",
"verification_feedback",
"interrupt_notice",
]
class AgentMessage(BaseModel):
message_id: str
thread_id: str
from_agent_id: str
to_agent_id: str
task_id: str | None = None
reply_to_message_id: str | None = None
message_type: AgentMessageType = "task_update"
content_summary: str
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
payload: dict[str, Any] = Field(default_factory=dict)

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Literal
from pydantic import BaseModel, Field
TaskLifecycleStatus = Literal["pending", "in_progress", "completed", "failed", "blocked"]
VerificationStatus = Literal["passed", "failed", "skipped"]
TaskResultStatus = Literal["completed", "failed", "blocked", "passed", "skipped"]
InterruptStatus = Literal["requested", "acknowledged", "resolved"]
BudgetMode = Literal["direct", "collaboration"]
class InterruptRecord(BaseModel):
interrupt_id: str
reason: str
status: InterruptStatus = "requested"
requested_by: str | None = None
source_event_id: str | None = None
requested_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
payload: dict[str, Any] = Field(default_factory=dict)
class RecoveryRecord(BaseModel):
recovery_id: str
source_interrupt_id: str | None = None
strategy: str | None = None
resumed_from_task_id: str | None = None
resumed_from_thread_id: str | None = None
recovered_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
payload: dict[str, Any] = Field(default_factory=dict)
class CollaborationBudget(BaseModel):
mode: BudgetMode = "direct"
max_parallel_tasks: int | None = None
remaining_parallel_tasks: int | None = None
max_tool_calls: int | None = None
remaining_tool_calls: int | None = None
max_iterations: int | None = None
remaining_iterations: int | None = None
escalation_threshold: int | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
class AgentTask(BaseModel):
task_id: str
title: str
status: TaskLifecycleStatus = "pending"
owner_agent_id: str | None = None
role: str | None = None
goal: str | None = None
parent_task_id: str | None = None
child_task_ids: list[str] = Field(default_factory=list)
thread_id: str | None = None
message_id: str | None = None
message_index: int | None = None
expected_evidence: list[dict[str, Any]] = Field(default_factory=list)
evidence: list[dict[str, Any]] = Field(default_factory=list)
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
collaboration_budget: CollaborationBudget | dict[str, Any] | None = None
result_summary: str | None = None
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class TaskResult(BaseModel):
task_id: str
status: TaskResultStatus
summary: str | None = None
evidence: list[dict[str, Any]] = Field(default_factory=list)
owner_agent_id: str | None = None
parent_task_id: str | None = None
child_task_ids: list[str] = Field(default_factory=list)
thread_id: str | None = None
message_id: str | None = None
message_index: int | None = None
interrupt_records: list[InterruptRecord | dict[str, Any]] = Field(default_factory=list)
recovery_records: list[RecoveryRecord | dict[str, Any]] = Field(default_factory=list)
budget_snapshot: CollaborationBudget | dict[str, Any] | None = None
next_action: str | None = None
output_data: dict[str, Any] | None = None

View File

@@ -0,0 +1,17 @@
"""Agent Session Management - Phase 10.3"""
from app.agents.session.manager import (
AgentSession,
SessionContext,
SessionPersistence,
create_agent_session,
get_agent_session,
)
__all__ = [
"AgentSession",
"SessionContext",
"SessionPersistence",
"create_agent_session",
"get_agent_session",
]

View File

@@ -0,0 +1,238 @@
"""Agent Session 管理 - Phase 10.3
支持会话层级管理和子会话创建。
"""
import json
import os
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class SessionContext:
"""会话上下文"""
session_id: str
parent_session_id: str | None = None
root_session_id: str | None = None
depth: int = 0
user_id: str | None = None
created_at: str | None = None
last_active: str | None = None
message_count: int = 0
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.now().isoformat()
if self.last_active is None:
self.last_active = self.created_at
@dataclass
class SessionPersistence:
"""会话持久化"""
def __init__(self, persistence_dir: str | None = None):
if persistence_dir is None:
persistence_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "data", "sessions"
)
self.persistence_dir = persistence_dir
def _get_session_path(self, session_id: str) -> str:
return os.path.join(self.persistence_dir, f"{session_id}.json")
def save(self, session: "AgentSession") -> bool:
"""保存会话"""
try:
os.makedirs(self.persistence_dir, exist_ok=True)
path = self._get_session_path(session.session_id)
data = {
"session_id": session.session_id,
"parent_session_id": session.context.parent_session_id,
"root_session_id": session.context.root_session_id,
"depth": session.context.depth,
"user_id": session.context.user_id,
"created_at": session.context.created_at,
"last_active": session.context.last_active,
"message_count": session.context.message_count,
"metadata": session.context.metadata,
"history": session._history,
}
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
return True
except Exception:
return False
def load(self, session_id: str) -> dict[str, Any] | None:
"""加载会话"""
try:
path = self._get_session_path(session_id)
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return None
def delete(self, session_id: str) -> bool:
"""删除会话"""
try:
path = self._get_session_path(session_id)
if os.path.exists(path):
os.remove(path)
return True
except Exception:
return False
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
"""列出所有会话"""
sessions = []
try:
os.makedirs(self.persistence_dir, exist_ok=True)
for filename in os.listdir(self.persistence_dir):
if filename.endswith(".json"):
path = os.path.join(self.persistence_dir, filename)
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if user_id is None or data.get("user_id") == user_id:
sessions.append(data)
except Exception:
pass
return sessions
class AgentSession:
"""Agent 会话管理器
支持:
- 会话层级parent/root/depth
- 子会话创建
- 会话摘要
- 持久化
"""
def __init__(
self,
session_id: str | None = None,
user_id: str | None = None,
parent_session_id: str | None = None,
):
self.session_id = session_id or str(uuid.uuid4())[:8]
self.context = SessionContext(
session_id=self.session_id,
user_id=user_id,
parent_session_id=parent_session_id,
depth=0 if parent_session_id is None else 1,
)
self._history: list[dict[str, Any]] = []
self._persistence = SessionPersistence()
# 如果有父会话,设置 root_session_id
if parent_session_id:
parent_data = self._persistence.load(parent_session_id)
if parent_data:
self.context.root_session_id = (
parent_data.get("root_session_id") or parent_session_id
)
self.context.depth = parent_data.get("depth", 0) + 1
async def initialize(self) -> dict[str, Any]:
"""初始化会话"""
self.context.last_active = datetime.now().isoformat()
self._persistence.save(self)
return {
"session_id": self.session_id,
"depth": self.context.depth,
"parent_session_id": self.context.parent_session_id,
"root_session_id": self.context.root_session_id,
}
async def process_message(self, message: str, response: str) -> None:
"""处理消息并记录到历史"""
self.context.message_count += 1
self.context.last_active = datetime.now().isoformat()
self._history.append(
{
"role": "user",
"content": message,
"timestamp": datetime.now().isoformat(),
}
)
self._history.append(
{
"role": "assistant",
"content": response,
"timestamp": datetime.now().isoformat(),
}
)
self._persistence.save(self)
async def spawn_child_session(self, user_id: str | None = None) -> "AgentSession":
"""创建子会话"""
child = AgentSession(
user_id=user_id or self.context.user_id,
parent_session_id=self.session_id,
)
child.context.root_session_id = self.context.root_session_id or self.session_id
await child.initialize()
return child
async def get_session_summary(self) -> dict[str, Any]:
"""获取会话摘要"""
return {
"session_id": self.session_id,
"parent_session_id": self.context.parent_session_id,
"root_session_id": self.context.root_session_id,
"depth": self.context.depth,
"user_id": self.context.user_id,
"created_at": self.context.created_at,
"last_active": self.context.last_active,
"message_count": self.context.message_count,
"history_length": len(self._history),
}
async def persist(self) -> bool:
"""持久化会话"""
return self._persistence.save(self)
def get_history(self) -> list[dict[str, Any]]:
"""获取会话历史"""
return self._history.copy()
def add_metadata(self, key: str, value: Any) -> None:
"""添加会话元数据"""
self.context.metadata[key] = value
def get_metadata(self, key: str) -> Any:
"""获取会话元数据"""
return self.context.metadata.get(key)
# 全局会话存储(内存中)
_sessions: dict[str, AgentSession] = {}
def get_agent_session(session_id: str) -> AgentSession | None:
"""获取会话"""
return _sessions.get(session_id)
def create_agent_session(
session_id: str | None = None,
user_id: str | None = None,
parent_session_id: str | None = None,
) -> AgentSession:
"""创建新会话"""
session = AgentSession(
session_id=session_id,
user_id=user_id,
parent_session_id=parent_session_id,
)
_sessions[session.session_id] = session
return session

View File

@@ -0,0 +1,16 @@
"""Skills 注册表 - Phase 9"""
from app.agents.skills.registry import SkillRegistry, get_skill_registry
from app.agents.skills.metadata import SkillMetadata
from app.agents.skills.loaders.local_loader import LocalSkillLoader
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
from app.agents.skills.mcp_builder import MCPSkillBuilder
__all__ = [
"SkillRegistry",
"SkillMetadata",
"LocalSkillLoader",
"PluginSkillLoader",
"MCPSkillBuilder",
"get_skill_registry",
]

View File

@@ -0,0 +1,72 @@
"""Built-in Skills - Phase 9.4
This module contains bundled skills that are always available
without requiring external skill loaders.
"""
from typing import Any
# SkillMetadata-compatible structure for bundled skills
BUNDLED_SKILLS: list[dict[str, Any]] = [
{
"id": "code-analysis",
"name": "Code Analysis",
"description": "Analyze code structure, patterns, and quality. Helps understand codebase architecture, find issues, and suggest improvements.",
"version": "1.0.0",
"prompts": [
"Analyze the code structure and identify key components, their relationships, and responsibilities.",
"Review the code for potential issues like bugs, security vulnerabilities, or performance problems.",
"Explain how the code works and what it does in simple terms.",
],
"tools": ["grep", "read", "glob", "lsp_symbols", "lsp_find_references"],
},
{
"id": "git-helper",
"name": "Git Helper",
"description": "Assists with Git operations including commit, branch management, merge conflicts, and repository exploration.",
"version": "1.0.0",
"prompts": [
"Show me the current git status and any uncommitted changes.",
"Help me create a meaningful commit message for these changes.",
"Explain the git history and branch structure of this repository.",
],
"tools": ["bash"],
},
{
"id": "web-research",
"name": "Web Research",
"description": "Search the web for information, documentation, and resources. Helps find answers and learn about technologies.",
"version": "1.0.0",
"prompts": [
"Search the web for information about {topic} and summarize the key findings.",
"Find official documentation or reliable resources about {topic}.",
"Look up the latest news or developments in {topic}.",
],
"tools": ["search_brave_web_search", "websearch_web_search_exa", "webfetch"],
},
{
"id": "file-management",
"name": "File Management",
"description": "Helps with file operations like creating, editing, organizing, and managing project files and directories.",
"version": "1.0.0",
"prompts": [
"Create a new file at {path} with the following content: {content}",
"Organize the files in the project structure and suggest improvements.",
"Find all files related to {topic} or matching {pattern}.",
],
"tools": ["read", "write", "glob", "bash"],
},
{
"id": "task-planning",
"name": "Task Planning",
"description": "Helps break down complex tasks into smaller steps, create implementation plans, and track progress.",
"version": "1.0.0",
"prompts": [
"Break down this task into smaller, manageable steps: {task}",
"Create an implementation plan for building {feature} with clear phases.",
"Review the current progress and suggest next steps for completing {goal}.",
],
"tools": ["todowrite", "read", "write"],
},
]

View File

@@ -0,0 +1,12 @@
"""Skills 加载器包"""
from app.agents.skills.loaders.local_loader import LocalSkillLoader
from app.agents.skills.loaders.plugin_loader import PluginSkillLoader
from app.agents.skills.loaders.mcp_loader import MCPSkillLoader, get_mcp_skill_loader
__all__ = [
"LocalSkillLoader",
"PluginSkillLoader",
"MCPSkillLoader",
"get_mcp_skill_loader",
]

View File

@@ -0,0 +1,100 @@
"""本地 Skills 加载器 - Phase 9.2"""
import os
import re
from typing import Any
from app.agents.skills.metadata import SkillMetadata
class LocalSkillLoader:
"""本地 Skills 加载器
从 skills_dir 目录加载 SKILL.md 文件。
"""
def __init__(self, skills_dir: str):
self.skills_dir = skills_dir
def load_all(self) -> list[SkillMetadata]:
"""加载所有本地 Skills
Returns:
Skill 元数据列表
"""
skills = []
if not os.path.exists(self.skills_dir):
return skills
for root, dirs, files in os.walk(self.skills_dir):
# 跳过隐藏目录
dirs[:] = [d for d in dirs if not d.startswith(".")]
if "SKILL.md" in files:
skill = self._load_skill_from_dir(root)
if skill:
skills.append(skill)
return skills
def _load_skill_from_dir(self, skill_dir: str) -> SkillMetadata | None:
"""从目录加载 Skill
Args:
skill_dir: Skill 目录
Returns:
Skill 元数据
"""
skill_path = os.path.join(skill_dir, "SKILL.md")
try:
with open(skill_path, "r", encoding="utf-8") as f:
content = f.read()
# 解析 frontmatter
metadata = self._parse_frontmatter(content)
# 获取 Skill 名称(目录名)
name = os.path.basename(skill_dir)
return SkillMetadata(
name=metadata.get("name", name),
description=metadata.get("description", ""),
version=metadata.get("version", "1.0.0"),
author=metadata.get("author", ""),
tags=metadata.get("tags", []),
triggers=metadata.get("triggers", []),
content=content,
source="local",
source_id=skill_dir,
)
except Exception:
return None
def _parse_frontmatter(self, content: str) -> dict[str, Any]:
"""解析 frontmatter"""
metadata = {}
# 匹配 --- 包裹的 frontmatter
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
if match:
frontmatter = match.group(1)
for line in frontmatter.split("\n"):
if ":" in line:
key, value = line.split(":", 1)
key = key.strip()
value = value.strip()
# 处理列表
if value.startswith("[") and value.endswith("]"):
value = [v.strip().strip('"').strip("'") for v in value[1:-1].split(",")]
elif value.lower() in ("true", "false"):
value = value.lower() == "true"
metadata[key] = value
return metadata

View File

@@ -0,0 +1,169 @@
"""MCP Skill 加载器 - Phase 9.2
从 MCP (Model Context Protocol) 服务器发现和加载 Skills。
"""
import os
from typing import Any
from app.agents.skills.metadata import SkillMetadata
class MCPSkillLoader:
"""MCP Skill 加载器
从 MCP 服务器发现可用的 Skills。
"""
def __init__(self, mcp_servers: list[dict[str, Any]] | None = None):
"""
Args:
mcp_servers: MCP 服务器列表,每项包含 name, command, env 等
"""
self.mcp_servers = mcp_servers or []
self._discovered_skills: dict[str, SkillMetadata] = {}
def discover_skills(self) -> list[SkillMetadata]:
"""从所有配置的 MCP 服务器发现 Skills
Returns:
发现的 Skill 列表
"""
skills = []
for server in self.mcp_servers:
server_skills = self._discover_from_server(server)
skills.extend(server_skills)
return skills
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
"""从单个 MCP 服务器发现 Skills
Args:
server: 服务器配置
Returns:
Skill 列表
"""
skills = []
server_name = server.get("name", "unknown")
# 模拟从 MCP 服务器获取工具列表
# 实际实现时,这里会调用 MCP 服务器的 list_tools 接口
try:
tools = self._call_mcp_list_tools(server)
for tool in tools:
skill = self._tool_to_skill(tool, server_name)
if skill:
skills.append(skill)
self._discovered_skills[skill.name] = skill
except Exception:
pass
return skills
def _call_mcp_list_tools(self, server: dict[str, Any]) -> list[dict[str, Any]]:
"""调用 MCP 服务器的 list_tools 接口
Args:
server: 服务器配置
Returns:
工具列表
"""
# TODO: 实现实际的 MCP 协议调用
# 目前返回空列表,实际使用时需要实现 MCP 客户端
return []
def _tool_to_skill(self, tool: dict[str, Any], server: str) -> SkillMetadata | None:
"""将 MCP 工具转换为 Skill
Args:
tool: MCP 工具定义
server: 服务器名
Returns:
Skill 元数据或 None
"""
tool_name = tool.get("name")
if not tool_name:
return None
return SkillMetadata(
id=f"mcp_{server}_{tool_name}",
name=f"{server}:{tool_name}",
description=tool.get("description", f"MCP tool: {tool_name}"),
version="1.0.0",
content=self._generate_skill_content(tool),
triggers=[f"@{server}", f"/{tool_name}"],
tools=[tool_name],
tags=["mcp", server],
enabled=True,
)
def _generate_skill_content(self, tool: dict[str, Any]) -> str:
"""生成 Skill 内容
Args:
tool: MCP 工具定义
Returns:
Skill 内容字符串
"""
name = tool.get("name", "unknown")
description = tool.get("description", "No description")
input_schema = tool.get("inputSchema", {})
content = f"""# MCP Tool: {name}
**Description**: {description}
**Server**: {tool.get("server", "unknown")}
**Input Schema**:
```json
{input_schema}
```
**Usage**:
Use the `/{name}` command or `@{tool.get("server", "server")}` to invoke this tool.
**Examples**:
```
/{name} arg1=value1 arg2=value2
@{tool.get("server", "server")} {name} --arg1 value1
```
"""
return content
def get_skill(self, name: str) -> SkillMetadata | None:
"""获取已发现的 Skill
Args:
name: Skill 名称
Returns:
Skill 元数据或 None
"""
return self._discovered_skills.get(name)
def list_skills(self) -> list[SkillMetadata]:
"""列出所有已发现的 Skills
Returns:
Skill 列表
"""
return list(self._discovered_skills.values())
# 全局加载器
_loader: MCPSkillLoader | None = None
def get_mcp_skill_loader() -> MCPSkillLoader:
"""获取全局 MCP Skill 加载器"""
global _loader
if _loader is None:
_loader = MCPSkillLoader()
return _loader

View File

@@ -0,0 +1,53 @@
"""插件 Skills 加载器 - Phase 9.2"""
from typing import Any
from app.agents.skills.metadata import SkillMetadata
from app.agents.plugins.manager import get_plugin_manager
class PluginSkillLoader:
"""插件 Skills 加载器
从已安装的插件中加载 Skills。
"""
def __init__(self):
self.plugin_manager = get_plugin_manager()
def load_all(self) -> list[SkillMetadata]:
"""从所有已启用的插件加载 Skills
Returns:
Skill 元数据列表
"""
skills = []
for plugin in self.plugin_manager.list_plugins():
if not self.plugin_manager.is_enabled(plugin.id):
continue
# 从插件加载 Skills
plugin_skills = self._load_from_plugin(plugin)
skills.extend(plugin_skills)
return skills
def _load_from_plugin(self, plugin: Any) -> list[SkillMetadata]:
"""从单个插件加载 Skills"""
skills = []
for skill_name in plugin.skills:
skill = SkillMetadata(
name=f"{plugin.id}/{skill_name}",
description=f"Skill from plugin: {plugin.name}",
version=plugin.version,
author=plugin.author,
tags=["plugin", plugin.id],
content=f"# {skill_name}\n\nFrom plugin: {plugin.name}",
source="plugin",
source_id=plugin.id,
)
skills.append(skill)
return skills

View File

@@ -0,0 +1,100 @@
"""MCP Skill Builder - Phase 9.3"""
from typing import Any
from app.agents.skills.metadata import SkillMetadata
class MCPSkillBuilder:
"""MCP Skill Builder
从 MCP 服务器发现和构建 Skills。
"""
def __init__(self):
self._skills: dict[str, SkillMetadata] = {}
def discover_skills_from_mcp(self, mcp_servers: list[dict[str, Any]]) -> list[SkillMetadata]:
"""从 MCP 服务器发现 Skills
Args:
mcp_servers: MCP 服务器配置列表
Returns:
发现的 Skill 元数据列表
"""
skills = []
for server in mcp_servers:
server_skills = self._discover_from_server(server)
skills.extend(server_skills)
return skills
def _discover_from_server(self, server: dict[str, Any]) -> list[SkillMetadata]:
"""从单个 MCP 服务器发现 Skills"""
skills = []
server_name = server.get("name", "unknown")
tools = server.get("tools", [])
# 按工具分组
tool_groups: dict[str, list[str]] = {}
for tool in tools:
group = tool.get("group", "default")
if group not in tool_groups:
tool_groups[group] = []
tool_groups[group].append(tool)
# 为每个组创建一个 Skill
for group_name, group_tools in tool_groups.items():
skill = self._tool_to_skill(group_name, group_tools, server_name)
skills.append(skill)
return skills
def _tool_to_skill(self, group: str, tools: list[dict[str, Any]], server: str) -> SkillMetadata:
"""将 MCP 工具转换为 Skill"""
tool_summaries = []
for tool in tools:
name = tool.get("name", "unknown")
description = tool.get("description", "")
input_schema = tool.get("inputSchema", {})
tool_summaries.append(f"### {name}\n{description}\n\nInput: {input_schema}")
content = f"""# MCP Skill: {group}
来自 MCP 服务器: {server}
## 工具列表
{chr(10).join(tool_summaries)}
## 使用说明
使用这些工具前请确保理解每个工具的输入输出格式。
"""
return SkillMetadata(
name=f"mcp-{server}-{group}",
description=f"MCP skill from {server}: {group}",
version="1.0.0",
tags=["mcp", server, group],
triggers=[group, server],
content=content,
source="mcp",
source_id=f"{server}:{group}",
)
def _group_to_skill(self, group: str, tools: list[str], server: str) -> SkillMetadata:
"""将 MCP 工具组转换为 Skill"""
return SkillMetadata(
name=f"mcp-{server}-{group}",
description=f"MCP skill from {server}: {group}",
version="1.0.0",
tags=["mcp", server, group],
triggers=[group, server],
content=f"# {group}\n\nTools: {', '.join(tools)}",
source="mcp",
source_id=f"{server}:{group}",
)

View File

@@ -0,0 +1,42 @@
"""Skill 元数据定义 - Phase 9.1"""
from dataclasses import dataclass, field
from typing import Any
@dataclass
class SkillMetadata:
"""Skill 元数据"""
id: str = "" # Skill ID
name: str = "" # Skill 名称
description: str = "" # 描述
version: str = "1.0.0" # 版本
author: str = "" # 作者
tags: list[str] = field(default_factory=list) # 标签
triggers: list[str] = field(default_factory=list) # 触发关键词
content: str = "" # Skill 内容markdown
source: str = "local" # 来源local, plugin, mcp, bundled
source_id: str = "" # 来源 ID
enabled: bool = True # 是否启用
tools: list[str] = field(default_factory=list) # 关联的工具
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"description": self.description,
"version": self.version,
"author": self.author,
"tags": self.tags,
"triggers": self.triggers,
"content": self.content,
"source": self.source,
"source_id": self.source_id,
"enabled": self.enabled,
"tools": self.tools,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SkillMetadata":
return cls(**data)

View File

@@ -0,0 +1,133 @@
"""Skills 注册表 - Phase 9.1"""
import os
from typing import Any
from app.agents.skills.metadata import SkillMetadata
from app.agents.skills.loaders.local_loader import LocalSkillLoader
class SkillRegistry:
"""Skills 注册表
管理所有 Skills 的注册、发现和加载。
"""
def __init__(self):
self._skills: dict[str, SkillMetadata] = {}
self._loaders: list[Any] = []
def load_all(self, skills_dir: str | None = None) -> int:
"""加载所有 Skills
Args:
skills_dir: Skills 目录None 则使用默认目录
Returns:
加载的 Skill 数量
"""
if skills_dir is None:
skills_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", ".claude", "skills"
)
count = 0
# 本地加载器
local_loader = LocalSkillLoader(skills_dir)
local_skills = local_loader.load_all()
for skill in local_skills:
self.register(skill)
count += 1
# 插件加载器
for loader in self._loaders:
try:
external_skills = loader.load_all()
for skill in external_skills:
self.register(skill)
count += 1
except Exception:
pass
return count
def register(self, skill: SkillMetadata) -> None:
"""注册 Skill"""
self._skills[skill.name] = skill
def unregister(self, name: str) -> bool:
"""注销 Skill"""
if name in self._skills:
del self._skills[name]
return True
return False
def get_skill(self, name: str) -> SkillMetadata | None:
"""获取 Skill"""
return self._skills.get(name)
def search(self, query: str) -> list[SkillMetadata]:
"""搜索 Skills
Args:
query: 搜索关键词
Returns:
匹配的 Skills 列表
"""
query_lower = query.lower()
results = []
for skill in self._skills.values():
if not skill.enabled:
continue
# 匹配名称、描述、标签
if (
query_lower in skill.name.lower()
or query_lower in skill.description.lower()
or any(query_lower in tag.lower() for tag in skill.tags)
or any(query_lower in trigger.lower() for trigger in skill.triggers)
):
results.append(skill)
return results
def get_skill_context(self, names: list[str]) -> str:
"""获取 Skill 上下文
Args:
names: Skill 名称列表
Returns:
拼接的 Skill 内容
"""
contexts = []
for name in names:
skill = self._skills.get(name)
if skill and skill.enabled:
contexts.append(f"# {skill.name}\n\n{skill.content}")
return "\n\n---\n\n".join(contexts)
def add_loader(self, loader: Any) -> None:
"""添加加载器"""
self._loaders.append(loader)
def list_all(self) -> list[SkillMetadata]:
"""列出所有 Skills"""
return list(self._skills.values())
# 全局单例
_registry: SkillRegistry | None = None
def get_skill_registry() -> SkillRegistry:
"""获取全局 Skills 注册表"""
global _registry
if _registry is None:
_registry = SkillRegistry()
return _registry

View File

@@ -0,0 +1,140 @@
"""Skill 触发检测器 - Phase 9.5
检测消息中的 Skill 触发条件。
"""
import re
from typing import Any
from app.agents.skills.metadata import SkillMetadata
class SkillTriggerDetector:
"""Skill 触发检测器
检测用户消息中是否触发了某个 Skill。
"""
def __init__(self):
self._skills: dict[str, SkillMetadata] = {}
def register_skill(self, skill: SkillMetadata) -> None:
"""注册 Skill
Args:
skill: Skill 元数据
"""
self._skills[skill.name] = skill
def unregister_skill(self, name: str) -> bool:
"""注销 Skill
Args:
name: Skill 名称
Returns:
是否成功
"""
if name in self._skills:
del self._skills[name]
return True
return False
def detect_triggered_skills(self, message: str) -> list[str]:
"""检测触发的 Skills
Args:
message: 用户消息
Returns:
触发的 Skill 名称列表
"""
triggered = []
message_lower = message.lower()
for skill in self._skills.values():
if not skill.enabled:
continue
if self._matches_triggers(message, message_lower, skill):
triggered.append(skill.name)
return triggered
def _matches_triggers(self, message: str, message_lower: str, skill: SkillMetadata) -> bool:
"""检查消息是否匹配 Skill 触发条件
Args:
message: 原始消息
message_lower: 小写消息
skill: Skill 元数据
Returns:
是否匹配
"""
for trigger in skill.triggers:
trigger_lower = trigger.lower()
# 前缀匹配,如 "/code" 或 "@git"
if trigger_lower.startswith("/") or trigger_lower.startswith("@"):
if message_lower.startswith(trigger_lower):
return True
# 命令格式,如 "//analyze"
if trigger_lower.startswith("//"):
pattern = trigger_lower[2:]
if re.search(rf"\b{re.escape(pattern)}\b", message_lower):
return True
# 关键词匹配
if trigger_lower in message_lower:
return True
return False
def get_skill_prompt(self, skill_name: str) -> str | None:
"""获取 Skill 的提示词
Args:
skill_name: Skill 名称
Returns:
Skill 内容或 None
"""
skill = self._skills.get(skill_name)
if skill:
return skill.content
return None
def get_triggered_skill_context(self, message: str) -> str:
"""获取触发的 Skills 上下文
Args:
message: 用户消息
Returns:
拼接的 Skill 上下文
"""
triggered = self.detect_triggered_skills(message)
if not triggered:
return ""
contexts = []
for skill_name in triggered:
skill = self._skills.get(skill_name)
if skill:
contexts.append(f"# {skill.name}\n\n{skill.content}")
return "\n\n---\n\n".join(contexts)
# 全局检测器
_detector: SkillTriggerDetector | None = None
def get_skill_trigger_detector() -> SkillTriggerDetector:
"""获取全局 Skill 触发检测器"""
global _detector
if _detector is None:
_detector = SkillTriggerDetector()
return _detector

View File

@@ -1,10 +1,21 @@
from dataclasses import dataclass, field
from typing import TypedDict, Annotated, Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Annotated, Any, Literal, TypedDict
from langchain_core.messages import BaseMessage
from app.agents.schemas.event import AgentEvent
from app.agents.schemas.message import AgentMessage
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult, VerificationStatus
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.graph.message import add_messages
AgentPhase = Literal[
"phase_0_bootstrap",
"phase_1_routing",
"phase_2_controlled_collaboration",
"phase_3_dynamic_collaboration",
"phase_4_visibility_and_verification",
]
class AgentRole(str, Enum):
MASTER = "master"
@@ -22,41 +33,113 @@ class ConversationTurn:
model: str | None = None
def turn_to_message(turn: ConversationTurn) -> BaseMessage:
if turn.role == "user":
return HumanMessage(content=turn.content)
return AIMessage(content=turn.content)
class AgentState(TypedDict):
# Core message history with add_messages reducer
messages: Annotated[list[BaseMessage], add_messages]
# Session identifiers
user_id: str
conversation_id: str
parent_conversation_id: str | None
thread_id: str | None
last_message_id: str | None
message_sequence: int
agent_id: str | None
parent_agent_id: str | None
root_agent_id: str | None
collaboration_depth: int
spawned_agent_ids: list[str]
# Agent routing state
execution_mode: Literal["direct", "collaboration", "delegated", "verified"]
current_agent: str | None
next_step: str | None # For explicit graph routing
# Traceability
next_step: str | None
active_agents: list[AgentRole]
current_sub_commander: str | None
active_sub_commanders: list[str]
sub_commander_trace: list[dict[str, Any]]
agent_trace: list[str]
# Task & Entity Tracking (Business Logic)
pending_tasks: list[dict]
completed_tasks: list[dict]
created_entities: list[dict]
event_trace: list[AgentEvent | dict[str, Any]]
message_trace: list[AgentMessage | dict[str, Any]]
pending_tasks: list[dict[str, Any]]
completed_tasks: list[dict[str, Any]]
active_tasks: list[AgentTask | dict[str, Any]]
task_results: list[TaskResult | dict[str, Any]]
task_hierarchy: dict[str, list[str]]
interrupted_tasks: list[InterruptRecord | dict[str, Any]]
recovery_trace: list[RecoveryRecord | dict[str, Any]]
recovery_points: list[dict[str, Any]]
tool_calls: list[dict[str, Any]]
last_tool_result: str | None
action_results: list[dict[str, Any]]
created_entities: list[dict[str, Any]]
tool_outcomes: list[dict[str, Any]]
task_result_summary: dict[str, Any] | None
verifier_hints: dict[str, Any] | None
verification_status: VerificationStatus | None
verification_summary: str | None
verification_evidence: list[dict[str, Any]]
isolation_mode: str
isolation_id: str | None
isolation_workspace_path: str | None
isolation_parent_conversation_id: str | None
isolation_metadata: dict[str, Any]
input_tokens: int
output_tokens: int
estimated_cost: float | None
budget_warning: bool
cost_by_agent: dict[str, dict[str, Any]]
cost_thresholds: dict[str, Any]
budget_state: CollaborationBudget | dict[str, Any] | None
collaboration_budget_history: list[CollaborationBudget | dict[str, Any]]
current_phase: AgentPhase
phase_history: list[dict[str, Any]]
current_checkpoint: str | None
checkpoint_history: list[dict[str, Any]]
tool_strategy_used: str | None
tool_round_count: int
max_tool_rounds: int
retry_count: int
max_retries: int
iteration_count: int
max_iterations: int
routing_hops: int
max_routing_hops: int
terminated_due_to_loop_guard: bool
retrieval_trace: list[dict[str, Any]]
stop_reason: str | None
clarification_needed: bool
clarification_question: str | None
fallback_parse_error: str | None
should_respond: bool
# Context summaries (for long-term or cross-agent context)
knowledge_context: str | None
graph_context: str | None
schedule_context_summary: str | None
plan: str | None
plan_steps: list[dict[str, Any]]
analysis_report: str | None
# Output control
final_response: str | None
# Memory & Environment
memory_context: str | None
current_datetime_context: str | None
# Configuration
user_llm_config: dict | None
provider_capabilities: dict | None
current_datetime_reference: dict[str, str] | None
turn_context: dict[str, Any] | None
routing_decision: dict[str, Any] | None
continuity_state: dict[str, Any] | None
pending_action: dict[str, Any] | None
last_completed_action: dict[str, Any] | None
clarification_context: dict[str, Any] | None
user_llm_config: dict[str, Any] | None
provider_capabilities: dict[str, Any] | None
def initial_state(user_id: str, conversation_id: str) -> AgentState:
@@ -64,18 +147,103 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
messages=[],
user_id=user_id,
conversation_id=conversation_id,
parent_conversation_id=None,
thread_id=None,
last_message_id=None,
message_sequence=0,
agent_id=AgentRole.MASTER.value,
parent_agent_id=None,
root_agent_id=AgentRole.MASTER.value,
collaboration_depth=0,
spawned_agent_ids=[],
execution_mode="direct",
current_agent=AgentRole.MASTER.value,
next_step=None,
active_agents=[AgentRole.MASTER],
current_sub_commander=None,
active_sub_commanders=[],
sub_commander_trace=[],
agent_trace=[AgentRole.MASTER.value],
event_trace=[],
message_trace=[],
pending_tasks=[],
completed_tasks=[],
active_tasks=[],
task_results=[],
task_hierarchy={},
interrupted_tasks=[],
recovery_trace=[],
recovery_points=[],
tool_calls=[],
last_tool_result=None,
action_results=[],
created_entities=[],
tool_outcomes=[],
task_result_summary=None,
verifier_hints=None,
verification_status=None,
verification_summary=None,
verification_evidence=[],
isolation_mode="none",
isolation_id=None,
isolation_workspace_path=None,
isolation_parent_conversation_id=None,
isolation_metadata={},
input_tokens=0,
output_tokens=0,
estimated_cost=None,
budget_warning=False,
cost_by_agent={},
cost_thresholds={},
budget_state=None,
collaboration_budget_history=[],
current_phase="phase_0_bootstrap",
phase_history=[
{
"phase": "phase_0_bootstrap",
"reason": "initial_state_created",
}
],
current_checkpoint="bootstrap.initialized",
checkpoint_history=[
{
"checkpoint": "bootstrap.initialized",
"phase": "phase_0_bootstrap",
"reason": "initial_state_created",
}
],
tool_strategy_used=None,
tool_round_count=0,
max_tool_rounds=2,
retry_count=0,
max_retries=1,
iteration_count=0,
max_iterations=3,
routing_hops=0,
max_routing_hops=2,
terminated_due_to_loop_guard=False,
retrieval_trace=[],
stop_reason=None,
clarification_needed=False,
clarification_question=None,
fallback_parse_error=None,
should_respond=True,
knowledge_context=None,
graph_context=None,
schedule_context_summary=None,
plan=None,
plan_steps=[],
analysis_report=None,
final_response=None,
memory_context=None,
current_datetime_context=None,
current_datetime_reference=None,
turn_context=None,
routing_decision=None,
continuity_state=None,
pending_action=None,
last_completed_action=None,
clarification_context=None,
user_llm_config=None,
provider_capabilities=None,
)

View File

@@ -0,0 +1,13 @@
"""Team 多 Agent 协作"""
from app.agents.team.leader import TeamLeader, TeamTask, TaskStatus
from app.agents.team.member import TeamMember, MemberStatus, MemberTask
__all__ = [
"TeamLeader",
"TeamTask",
"TaskStatus",
"TeamMember",
"MemberStatus",
"MemberTask",
]

View File

@@ -0,0 +1,121 @@
"""Team 多 Agent 协作 - Phase 10.1"""
from dataclasses import dataclass, field
from typing import Any
from enum import Enum
class TaskStatus(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class TeamTask:
"""团队任务"""
id: str
description: str
assignee: str | None = None
status: TaskStatus = TaskStatus.PENDING
result: Any = None
error: str | None = None
class TeamLeader:
"""团队领导者
协调多个 Agent 成员执行任务。
"""
def __init__(self, team_id: str, members: list[str]):
"""
Args:
team_id: 团队 ID
members: 成员 ID 列表
"""
self.team_id = team_id
self.members = members
self._tasks: dict[str, TeamTask] = {}
def create_task(self, description: str) -> str:
"""创建任务
Args:
description: 任务描述
Returns:
任务 ID
"""
import uuid
task_id = str(uuid.uuid4())[:8]
self._tasks[task_id] = TeamTask(
id=task_id,
description=description,
)
return task_id
def assign_task(self, task_id: str, member: str) -> bool:
"""分配任务
Args:
task_id: 任务 ID
member: 成员 ID
Returns:
是否成功
"""
if task_id not in self._tasks:
return False
if member not in self.members:
return False
self._tasks[task_id].assignee = member
self._tasks[task_id].status = TaskStatus.IN_PROGRESS
return True
def broadcast_task(self, description: str) -> list[str]:
"""广播任务给所有成员
Args:
description: 任务描述
Returns:
创建的任务 ID 列表
"""
task_ids = []
for member in self.members:
task_id = self.create_task(description)
self.assign_task(task_id, member)
task_ids.append(task_id)
return task_ids
def collect_results(self) -> dict[str, Any]:
"""收集所有任务结果
Returns:
任务 ID -> 结果的映射
"""
return {
task_id: task.result
for task_id, task in self._tasks.items()
if task.status == TaskStatus.COMPLETED
}
def get_team_status(self) -> dict[str, Any]:
"""获取团队状态
Returns:
团队状态摘要
"""
return {
"team_id": self.team_id,
"members": self.members,
"task_count": len(self._tasks),
"completed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.COMPLETED),
"failed": sum(1 for t in self._tasks.values() if t.status == TaskStatus.FAILED),
}

View File

@@ -0,0 +1,166 @@
"""TeamMember 实现 - Phase 10.1
团队成员实现,负责执行分配的任务。
"""
from dataclasses import dataclass, field
from typing import Any
from enum import Enum
class MemberStatus(Enum):
"""成员状态"""
IDLE = "idle"
BUSY = "busy"
OFFLINE = "offline"
@dataclass
class MemberTask:
"""成员任务"""
task_id: str
description: str
status: str = "pending" # pending, in_progress, completed, failed
result: Any = None
error: str | None = None
class TeamMember:
"""团队成员
代表团队中的一个 Agent 成员,负责执行分配的任务。
"""
def __init__(self, member_id: str, name: str, capabilities: list[str] | None = None):
"""
Args:
member_id: 成员 ID
name: 成员名称
capabilities: 成员能力列表
"""
self.member_id = member_id
self.name = name
self.capabilities = capabilities or []
self.status = MemberStatus.IDLE
self._tasks: dict[str, MemberTask] = {}
self._metadata: dict[str, Any] = {}
def assign_task(self, task_id: str, description: str) -> MemberTask:
"""接收任务分配
Args:
task_id: 任务 ID
description: 任务描述
Returns:
创建的任务对象
"""
task = MemberTask(task_id=task_id, description=description)
self._tasks[task_id] = task
self.status = MemberStatus.BUSY
return task
def update_task_status(
self, task_id: str, status: str, result: Any = None, error: str | None = None
) -> bool:
"""更新任务状态
Args:
task_id: 任务 ID
status: 新状态
result: 任务结果
error: 错误信息
Returns:
是否更新成功
"""
if task_id not in self._tasks:
return False
task = self._tasks[task_id]
task.status = status
if result is not None:
task.result = result
if error is not None:
task.error = error
if status in ("completed", "failed"):
self.status = MemberStatus.IDLE
return True
def get_task(self, task_id: str) -> MemberTask | None:
"""获取任务
Args:
task_id: 任务 ID
Returns:
任务对象或 None
"""
return self._tasks.get(task_id)
def get_pending_tasks(self) -> list[MemberTask]:
"""获取待处理任务
Returns:
待处理任务列表
"""
return [t for t in self._tasks.values() if t.status == "pending"]
def get_active_task(self) -> MemberTask | None:
"""获取当前执行中的任务
Returns:
当前任务或 None
"""
for task in self._tasks.values():
if task.status == "in_progress":
return task
return None
def get_completed_tasks(self) -> list[MemberTask]:
"""获取已完成任务
Returns:
已完成任务列表
"""
return [t for t in self._tasks.values() if t.status == "completed"]
def set_metadata(self, key: str, value: Any) -> None:
"""设置元数据
Args:
key: 元数据键
value: 元数据值
"""
self._metadata[key] = value
def get_metadata(self, key: str) -> Any:
"""获取元数据
Args:
key: 元数据键
Returns:
元数据值或 None
"""
return self._metadata.get(key)
def get_status(self) -> dict[str, Any]:
"""获取成员状态
Returns:
状态字典
"""
return {
"member_id": self.member_id,
"name": self.name,
"status": self.status.value,
"capabilities": self.capabilities,
"task_count": len(self._tasks),
"pending_count": len(self.get_pending_tasks()),
"active_task": self.get_active_task().__dict__ if self.get_active_task() else None,
}

View File

@@ -1,6 +1,9 @@
from app.agents.tools.search import (
search_knowledge, get_knowledge_graph_context,
build_knowledge_graph, hybrid_search, web_search,
search_knowledge,
get_knowledge_graph_context,
build_knowledge_graph,
hybrid_search,
web_search,
)
from app.agents.tools.task import get_tasks, create_task, update_task_status
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
@@ -13,6 +16,58 @@ from app.agents.tools.schedule import (
)
from app.agents.tools.time_reasoning import resolve_time_expression
# Phase 6.1: Tool Registry exports
from app.agents.tools.registry import (
ToolRegistry,
get_tool_registry,
reset_tool_registry,
)
from app.agents.tools.manifest import (
HookConfig,
PermissionClass,
SideEffectScope,
ToolCategory,
ToolManifest,
)
from app.agents.tools.migration import (
migrate_tool,
migrate_all_tools,
get_tool_executor,
BackwardCompatTool,
)
# Phase 6.2: Hook System exports
from app.agents.tools.hooks import (
HookManager,
HookExecutor,
HookType,
HookDefinition,
HookResult,
ExecutionContext,
get_hook_manager,
get_hook_executor,
)
# Phase 6.3: Streaming Executor exports
from app.agents.tools.streaming import (
StreamingToolExecutor,
get_streaming_executor,
)
# Phase 6.4: Builtin Tools exports
from app.agents.tools.builtins import (
GlobTool,
GrepTool,
ReadFileTool,
WriteFileTool,
BashTool,
PowerShellTool,
LSPTools,
GitTool,
TeamAgentTool,
TaskBroadcastTool,
)
TASK_TOOLS = [
get_tasks,
create_task,

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any
_executor = ThreadPoolExecutor(max_workers=4)
def run_async(coro: Any, timeout: int = 30):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
__all__ = ["run_async"]

View File

@@ -0,0 +1,161 @@
"""工具基类 - 工具系统重构 Phase 6.1"""
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
from app.agents.tools.manifest import (
PermissionClass,
SideEffectScope,
ToolCategory,
ToolManifest,
)
T = TypeVar("T")
class BaseTool(ABC, Generic[T]):
"""工具基类
提供工具的标准接口和默认实现。
所有自定义工具应继承此类。
"""
def __init__(
self,
name: str,
description: str,
category: ToolCategory,
permission_class: PermissionClass,
side_effect_scope: SideEffectScope = SideEffectScope.NONE,
requires_confirmation: bool = False,
is_streaming: bool = False,
tags: list[str] | None = None,
):
self.name = name
self.description = description
self.category = category
self.permission_class = permission_class
self.side_effect_scope = side_effect_scope
self.requires_confirmation = requires_confirmation
self.is_streaming = is_streaming
self.tags = tags or []
def get_manifest(self) -> ToolManifest:
"""获取工具元数据
Returns:
工具元数据
"""
return ToolManifest(
name=self.name,
description=self.description,
category=self.category,
parameters=self.get_parameters(),
return_schema=self.get_return_schema(),
permission_class=self.permission_class,
side_effect_scope=self.side_effect_scope,
requires_confirmation=self.requires_confirmation,
is_streaming=self.is_streaming,
tags=self.tags,
)
@abstractmethod
def get_parameters(self) -> dict[str, Any]:
"""获取参数 SchemaJSON Schema 格式)
Returns:
参数 schema
"""
pass
@abstractmethod
def get_return_schema(self) -> dict[str, Any]:
"""获取返回值 Schema
Returns:
返回值 schema
"""
pass
@abstractmethod
async def execute(self, **kwargs) -> T:
"""执行工具
Args:
**kwargs: 工具参数
Returns:
执行结果
"""
pass
async def execute_safe(self, **kwargs) -> dict[str, Any]:
"""安全执行工具,捕获异常
Args:
**kwargs: 工具参数
Returns:
包含 success 和 result/error 的字典
"""
try:
result = await self.execute(**kwargs)
return {"success": True, "result": result}
except Exception as e:
return {"success": False, "error": str(e)}
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self.name!r})>"
class ReadTool(BaseTool):
"""只读工具基类"""
def __init__(self, **kwargs):
kwargs.setdefault("category", ToolCategory.READ)
kwargs.setdefault("permission_class", PermissionClass.READ)
kwargs.setdefault("side_effect_scope", SideEffectScope.NONE)
super().__init__(**kwargs)
class WriteTool(BaseTool):
"""写入工具基类"""
def __init__(self, **kwargs):
kwargs.setdefault("category", ToolCategory.WRITE)
kwargs.setdefault("permission_class", PermissionClass.WRITE)
kwargs.setdefault("side_effect_scope", SideEffectScope.LOCAL_STATE)
super().__init__(**kwargs)
class DBWriteTool(BaseTool):
"""数据库写入工具基类"""
def __init__(self, **kwargs):
kwargs.setdefault("category", ToolCategory.DB_WRITE)
kwargs.setdefault("permission_class", PermissionClass.WRITE)
kwargs.setdefault("side_effect_scope", SideEffectScope.DB_WRITE)
kwargs.setdefault("requires_confirmation", True)
super().__init__(**kwargs)
class ExternalTool(BaseTool):
"""外部工具基类(执行外部命令等)"""
def __init__(self, **kwargs):
kwargs.setdefault("category", ToolCategory.EXTERNAL)
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
kwargs.setdefault("requires_confirmation", True)
super().__init__(**kwargs)
class NetworkTool(BaseTool):
"""网络工具基类"""
def __init__(self, **kwargs):
kwargs.setdefault("category", ToolCategory.NETWORK)
kwargs.setdefault("permission_class", PermissionClass.EXTERNAL)
kwargs.setdefault("side_effect_scope", SideEffectScope.NETWORK)
super().__init__(**kwargs)

View File

@@ -0,0 +1,43 @@
"""内置工具集 - Phase 6.4
新的内置工具,使用 BaseTool 基类。
"""
from app.agents.tools.builtins.file_tools import (
GlobTool,
GrepTool,
ReadFileTool,
WriteFileTool,
)
from app.agents.tools.builtins.system_tools import (
BashTool,
PowerShellTool,
)
from app.agents.tools.builtins.dev_tools import (
LSPTools,
GitTool,
)
from app.agents.tools.builtins.collaboration_tools import (
TeamAgentTool,
TaskBroadcastTool,
)
__all__ = [
# File tools
"GlobTool",
"GrepTool",
"ReadFileTool",
"WriteFileTool",
# System tools
"BashTool",
"PowerShellTool",
# Dev tools
"LSPTools",
"GitTool",
# Collaboration tools
"TeamAgentTool",
"TaskBroadcastTool",
]

View File

@@ -0,0 +1,129 @@
"""协作工具 - Phase 6.4"""
from typing import Any
from app.agents.tools.base import WriteTool
from app.agents.tools.manifest import (
PermissionClass,
SideEffectScope,
)
class TeamAgentTool(WriteTool):
"""团队 Agent 通信工具
用于与其他 Agent 进行消息传递和协作。
"""
def __init__(self):
super().__init__(
name="team_agent",
description="向团队 Agent 发送消息或请求协作",
permission_class=PermissionClass.WRITE,
side_effect_scope=SideEffectScope.LOCAL_STATE,
tags=["collaboration", "team", "agent"],
)
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_name": {
"type": "string",
"description": "目标 Agent 名称",
},
"message": {
"type": "string",
"description": "要发送的消息",
},
"action": {
"type": "string",
"enum": ["send", "request", "delegate"],
"description": "操作类型",
},
},
"required": ["agent_name", "message"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"success": {"type": "boolean"},
"response": {"type": "string"},
},
}
async def execute(self, agent_name: str, message: str, action: str = "send") -> dict[str, Any]:
# 注意:实际实现需要通过 Agent 通信协议
# 这里只是一个框架实现
return {
"success": True,
"response": f"Message '{action}' to agent '{agent_name}': {message}",
"agent_name": agent_name,
"action": action,
}
class TaskBroadcastTool(WriteTool):
"""任务广播工具
向多个 Agent 广播任务。
"""
def __init__(self):
super().__init__(
name="task_broadcast",
description="向多个 Agent 广播任务",
permission_class=PermissionClass.WRITE,
side_effect_scope=SideEffectScope.LOCAL_STATE,
tags=["collaboration", "broadcast", "task"],
)
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_names": {
"type": "array",
"items": {"type": "string"},
"description": "目标 Agent 列表",
},
"task": {
"type": "string",
"description": "要广播的任务描述",
},
"priority": {
"type": "string",
"enum": ["low", "normal", "high", "urgent"],
"description": "任务优先级",
},
},
"required": ["agent_names", "task"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"success": {"type": "boolean"},
"broadcast_to": {"type": "array", "items": {"type": "string"}},
"responses": {"type": "array"},
},
}
async def execute(
self,
agent_names: list[str],
task: str,
priority: str = "normal",
) -> dict[str, Any]:
# 注意:实际实现需要通过 Agent 通信协议
# 这里只是一个框架实现
return {
"success": True,
"broadcast_to": agent_names,
"task": task,
"priority": priority,
"responses": [f"Acknowledged by {agent}" for agent in agent_names],
}

View File

@@ -0,0 +1,155 @@
"""开发工具 - Phase 6.4"""
from typing import Any
from app.agents.tools.base import ReadTool, WriteTool
from app.agents.tools.manifest import (
PermissionClass,
SideEffectScope,
)
class LSPTools(ReadTool):
"""语言服务器协议工具集
提供代码导航、查找引用等 LSP 功能。
"""
def __init__(self):
super().__init__(
name="lsp_tools",
description="LSP 代码导航和查找引用",
permission_class=PermissionClass.READ,
side_effect_scope=SideEffectScope.NONE,
tags=["development", "lsp", "code"],
)
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["goto_definition", "find_references", "document_symbols"],
"description": "LSP 操作类型",
},
"file": {
"type": "string",
"description": "文件路径",
},
"line": {
"type": "integer",
"description": "行号1-based",
},
"character": {
"type": "integer",
"description": "列号0-based",
},
},
"required": ["action", "file"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"success": {"type": "boolean"},
"results": {"type": "array"},
},
}
async def execute(
self,
action: str,
file: str,
line: int = 1,
character: int = 0,
) -> dict[str, Any]:
# 注意:实际 LSP 调用需要通过 lsp-utils 或类似库
# 这里只是一个框架实现
return {
"success": False,
"error": f"LSP action '{action}' not fully implemented - requires LSP server integration",
"action": action,
"file": file,
"position": {"line": line, "character": character},
}
class GitTool(ReadTool):
"""Git 操作工具
提供常用的 Git 操作。
"""
def __init__(self, repo_path: str = "."):
super().__init__(
name="git",
description="执行 Git 命令",
permission_class=PermissionClass.EXTERNAL,
side_effect_scope=SideEffectScope.LOCAL_STATE,
requires_confirmation=True,
tags=["development", "git", "version-control"],
)
self.repo_path = repo_path
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Git 子命令和参数,如 'status''log --oneline -10'",
},
"repo_path": {
"type": "string",
"description": "仓库路径(可选)",
},
},
"required": ["command"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"stdout": {"type": "string"},
"stderr": {"type": "string"},
"returncode": {"type": "integer"},
},
}
async def execute(self, command: str, repo_path: str | None = None) -> dict[str, Any]:
import asyncio
import os
import platform
repo = repo_path or self.repo_path
# 构建完整的 git 命令
if platform.system() == "Windows":
full_command = f'git -C "{repo}" {command}'
else:
full_command = f"git -C '{repo}' {command}"
try:
process = await asyncio.create_subprocess_shell(
full_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
return {
"stdout": stdout.decode("utf-8", errors="replace"),
"stderr": stderr.decode("utf-8", errors="replace"),
"returncode": process.returncode,
}
except Exception as e:
return {
"stdout": "",
"stderr": str(e),
"returncode": -1,
}

View File

@@ -0,0 +1,255 @@
"""文件操作工具 - Phase 6.4"""
import os
from typing import Any
from app.agents.tools.base import ExternalTool, ReadTool, WriteTool
from app.agents.tools.manifest import (
PermissionClass,
SideEffectScope,
ToolCategory,
)
class GlobTool(ReadTool):
"""文件路径匹配工具
使用 glob 模式查找文件。
"""
def __init__(self, root_dir: str = "."):
super().__init__(
name="glob",
description="使用 glob 模式查找文件路径",
permission_class=PermissionClass.READ,
side_effect_scope=SideEffectScope.NONE,
tags=["file", "search", "glob"],
)
self.root_dir = root_dir
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "Glob 模式,如 **/*.py",
},
"root_dir": {
"type": "string",
"description": "搜索根目录(可选)",
},
},
"required": ["pattern"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "array",
"items": {"type": "string"},
}
async def execute(self, pattern: str, root_dir: str | None = None) -> list[str]:
import glob as glob_module
root = root_dir or self.root_dir
return glob_module.glob(pattern, root_dir=root, recursive=True)
class GrepTool(ReadTool):
"""文件内容搜索工具
在文件中搜索匹配的行。
"""
def __init__(self):
super().__init__(
name="grep",
description="在文件中搜索匹配的文本行",
permission_class=PermissionClass.READ,
side_effect_scope=SideEffectScope.NONE,
tags=["file", "search", "text"],
)
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "正则表达式模式",
},
"paths": {
"type": "array",
"items": {"type": "string"},
"description": "要搜索的文件路径列表",
},
"case_sensitive": {
"type": "boolean",
"description": "是否区分大小写",
},
},
"required": ["pattern", "paths"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "array",
"items": {
"type": "object",
"properties": {
"file": {"type": "string"},
"line": {"type": "integer"},
"content": {"type": "string"},
},
},
}
async def execute(
self, pattern: str, paths: list[str], case_sensitive: bool = True
) -> list[dict[str, Any]]:
import re
flags = 0 if case_sensitive else re.IGNORECASE
regex = re.compile(pattern, flags)
results = []
for path in paths:
if not os.path.isfile(path):
continue
try:
with open(path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if regex.search(line):
results.append(
{
"file": path,
"line": line_num,
"content": line.rstrip(),
}
)
except (UnicodeDecodeError, PermissionError):
continue
return results
class ReadFileTool(ReadTool):
"""文件读取工具"""
def __init__(self):
super().__init__(
name="read_file",
description="读取文件内容",
permission_class=PermissionClass.READ,
side_effect_scope=SideEffectScope.NONE,
tags=["file", "read"],
)
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "文件路径",
},
"limit": {
"type": "integer",
"description": "最大行数",
},
"offset": {
"type": "integer",
"description": "起始行号",
},
},
"required": ["path"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"content": {"type": "string"},
"lines": {"type": "integer"},
},
}
async def execute(self, path: str, limit: int | None = None, offset: int = 0) -> dict[str, Any]:
if not os.path.isfile(path):
raise FileNotFoundError(f"File not found: {path}")
with open(path, "r", encoding="utf-8") as f:
lines = f.readlines()
total_lines = len(lines)
start = max(0, offset)
end = len(lines) if limit is None else min(start + limit, len(lines))
content = "".join(lines[start:end])
return {
"content": content,
"lines": total_lines,
"truncated": limit is not None and end < len(lines),
}
class WriteFileTool(WriteTool):
"""文件写入工具"""
def __init__(self):
super().__init__(
name="write_file",
description="写入文件内容",
permission_class=PermissionClass.WRITE,
side_effect_scope=SideEffectScope.LOCAL_STATE,
requires_confirmation=True,
tags=["file", "write"],
)
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "文件路径",
},
"content": {
"type": "string",
"description": "文件内容",
},
"append": {
"type": "boolean",
"description": "是否追加模式",
},
},
"required": ["path", "content"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"success": {"type": "boolean"},
"bytes_written": {"type": "integer"},
},
}
async def execute(self, path: str, content: str, append: bool = False) -> dict[str, Any]:
mode = "a" if append else "w"
# 确保目录存在
directory = os.path.dirname(path)
if directory and not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
with open(path, mode, encoding="utf-8") as f:
bytes_written = f.write(content)
return {
"success": True,
"bytes_written": bytes_written,
}

View File

@@ -0,0 +1,193 @@
"""系统工具 - Phase 6.4"""
import asyncio
import shlex
from typing import Any
from app.agents.tools.base import ExternalTool
from app.agents.tools.manifest import (
PermissionClass,
SideEffectScope,
)
class BashTool(ExternalTool):
"""Bash 命令执行工具"""
def __init__(self, working_dir: str = "."):
super().__init__(
name="bash",
description="执行 Bash 命令",
permission_class=PermissionClass.EXTERNAL,
side_effect_scope=SideEffectScope.LOCAL_STATE,
requires_confirmation=True,
tags=["system", "bash", "shell"],
)
self.working_dir = working_dir
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "要执行的 Bash 命令",
},
"timeout": {
"type": "integer",
"description": "超时时间(秒)",
},
"working_dir": {
"type": "string",
"description": "工作目录(可选)",
},
},
"required": ["command"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"stdout": {"type": "string"},
"stderr": {"type": "string"},
"returncode": {"type": "integer"},
},
}
async def execute(
self, command: str, timeout: int = 30, working_dir: str | None = None
) -> dict[str, Any]:
import os
cwd = working_dir or self.working_dir
try:
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
)
try:
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
except asyncio.TimeoutError:
process.kill()
await process.wait()
return {
"stdout": "",
"stderr": f"Command timed out after {timeout} seconds",
"returncode": -1,
}
return {
"stdout": stdout.decode("utf-8", errors="replace"),
"stderr": stderr.decode("utf-8", errors="replace"),
"returncode": process.returncode,
}
except Exception as e:
return {
"stdout": "",
"stderr": str(e),
"returncode": -1,
}
class PowerShellTool(ExternalTool):
"""PowerShell 命令执行工具"""
def __init__(self, working_dir: str = "."):
super().__init__(
name="powershell",
description="执行 PowerShell 命令",
permission_class=PermissionClass.EXTERNAL,
side_effect_scope=SideEffectScope.LOCAL_STATE,
requires_confirmation=True,
tags=["system", "powershell", "shell"],
)
self.working_dir = working_dir
def get_parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "要执行的 PowerShell 命令",
},
"timeout": {
"type": "integer",
"description": "超时时间(秒)",
},
"working_dir": {
"type": "string",
"description": "工作目录(可选)",
},
},
"required": ["command"],
}
def get_return_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"stdout": {"type": "string"},
"stderr": {"type": "string"},
"returncode": {"type": "integer"},
},
}
async def execute(
self, command: str, timeout: int = 30, working_dir: str | None = None
) -> dict[str, Any]:
import platform
# 检测是否是 Windows 平台
is_windows = platform.system() == "Windows"
if not is_windows:
# 非 Windows 平台,可能没有 PowerShell
return {
"stdout": "",
"stderr": "PowerShell is not available on this platform",
"returncode": -1,
}
cwd = working_dir or self.working_dir
try:
process = await asyncio.create_subprocess_exec(
"powershell.exe",
"-NoProfile",
"-Command",
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
)
try:
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
except asyncio.TimeoutError:
process.kill()
await process.wait()
return {
"stdout": "",
"stderr": f"Command timed out after {timeout} seconds",
"returncode": -1,
}
return {
"stdout": stdout.decode("utf-8", errors="replace"),
"stderr": stderr.decode("utf-8", errors="replace"),
"returncode": process.returncode,
}
except Exception as e:
return {
"stdout": "",
"stderr": str(e),
"returncode": -1,
}

View File

@@ -4,19 +4,12 @@ from langchain_core.tools import tool
from app.database import async_session
from app.models.forum import ForumPost, ForumReply
from app.agents.context import get_current_user
from app.agents.tools.async_bridge import run_async
from sqlalchemy import select
import asyncio
from concurrent.futures import ThreadPoolExecutor
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
return run_async(coro, timeout=timeout)
@tool

View File

@@ -0,0 +1,46 @@
"""Hook 系统 - Phase 6.2"""
from app.agents.tools.hooks.types import (
HookDefinition,
HookResult,
HookStage,
HookTrigger,
HookType,
ExecutionContext,
HookHandler,
PreToolHook,
PostToolHook,
ErrorToolHook,
SkipToolHook,
)
from app.agents.tools.hooks.manager import (
HookManager,
get_hook_manager,
reset_hook_manager,
)
from app.agents.tools.hooks.executor import (
HookExecutor,
get_hook_executor,
)
__all__ = [
# Types
"HookType",
"HookStage",
"HookTrigger",
"HookDefinition",
"HookResult",
"ExecutionContext",
"HookHandler",
"PreToolHook",
"PostToolHook",
"ErrorToolHook",
"SkipToolHook",
# Manager
"HookManager",
"get_hook_manager",
"reset_hook_manager",
# Executor
"HookExecutor",
"get_hook_executor",
]

View File

@@ -0,0 +1,11 @@
"""内置 Hook 集合 - Phase 7"""
from app.agents.tools.hooks.builtins.audit_log import AuditLogHook
from app.agents.tools.hooks.builtins.dangerous_confirmation import DangerousConfirmationHook
from app.agents.tools.hooks.builtins.security_scan import SecurityScanHook
__all__ = [
"AuditLogHook",
"DangerousConfirmationHook",
"SecurityScanHook",
]

View File

@@ -0,0 +1,115 @@
"""审计日志 Hook - Phase 7.2
记录所有工具调用到审计日志。
"""
from typing import Any
from app.agents.tools.hooks.types import (
ExecutionContext,
HookResult,
HookType,
)
from app.agents.tools.manifest import ToolCategory
class AuditLogHook:
"""审计日志 Hook
记录所有工具调用的详细信息,包括:
- 调用时间
- 工具名称
- 输入参数
- 执行结果
- 执行时长
- 用户 ID
"""
def __init__(self, log_path: str | None = None):
"""
Args:
log_path: 日志文件路径None 则输出到 stdout
"""
self.log_path = log_path
self._logs: list[dict[str, Any]] = []
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
"""工具执行前记录"""
log_entry = {
"event": "pre_tool",
"tool_name": context.tool_name,
"input": context.tool_input,
"user_id": context.user_id,
"session_id": context.session_id,
}
self._logs.append(log_entry)
self._write_log(log_entry)
return HookResult(
hook_name="audit_log",
success=True,
continue_execution=True,
)
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
"""工具执行后记录"""
log_entry = {
"event": "post_tool",
"tool_name": context.tool_name,
"result": str(result)[:500] if result else None,
"duration_ms": (
(context.end_time - context.start_time) * 1000
if context.start_time and context.end_time
else None
),
}
self._logs.append(log_entry)
self._write_log(log_entry)
return HookResult(
hook_name="audit_log",
success=True,
continue_execution=True,
modified_output=result,
)
async def tool_error(self, context: ExecutionContext, error: Exception) -> HookResult:
"""工具出错时记录"""
log_entry = {
"event": "tool_error",
"tool_name": context.tool_name,
"error": str(error),
"error_type": type(error).__name__,
}
self._logs.append(log_entry)
self._write_log(log_entry)
return HookResult(
hook_name="audit_log",
success=False,
continue_execution=True,
error=str(error),
)
def _write_log(self, entry: dict[str, Any]) -> None:
"""写入日志"""
import json
import datetime
entry["timestamp"] = datetime.datetime.now().isoformat()
if self.log_path:
try:
with open(self.log_path, "a", encoding="utf-8") as f:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
except Exception:
# 日志写入失败不影响主流程
pass
else:
# 输出到 stdout
print(f"[AUDIT] {json.dumps(entry, ensure_ascii=False)}")
def get_logs(self) -> list[dict[str, Any]]:
"""获取所有日志"""
return self._logs.copy()
def clear_logs(self) -> None:
"""清空日志"""
self._logs.clear()

View File

@@ -0,0 +1,142 @@
"""危险操作确认 Hook - Phase 7.2
对危险操作要求用户确认。
"""
from typing import Any
from app.agents.tools.hooks.types import (
ExecutionContext,
HookResult,
)
from app.agents.tools.manifest import SideEffectScope
# 危险操作关键词
DANGEROUS_PATTERNS = [
# 文件操作
"delete",
"remove",
"rm ",
"rmdir",
"unlink",
"format",
"truncate",
# 系统操作
"shutdown",
"reboot",
"kill",
"pkill",
"sudo",
"chmod",
"chown",
# 数据操作
"drop",
"truncate",
"delete from",
"delete.*where",
"insert into.*select",
"update.*set",
# 网络操作
"curl",
"wget",
"nc ",
"netcat",
"ssh ",
"scp ",
"sftp ",
# 环境变量
"export.*secret",
"export.*key",
"export.*token",
]
class DangerousConfirmationHook:
"""危险操作确认 Hook
检查工具调用是否包含危险操作,如是则要求确认。
"""
def __init__(self, auto_block: bool = False):
"""
Args:
auto_block: True 表示自动拦截危险操作False 表示仅警告
"""
self.auto_block = auto_block
self._pending_confirmations: dict[str, bool] = {}
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
"""检查是否为危险操作"""
is_dangerous = self._check_dangerous(context.tool_name, context.tool_input)
if is_dangerous:
if self.auto_block:
return HookResult(
hook_name="dangerous_confirmation",
success=False,
continue_execution=False,
error=f"危险操作被自动拦截: {context.tool_name}",
metadata={"dangerous": True, "auto_blocked": True},
)
else:
# 标记需要确认
context.metadata["requires_confirmation"] = True
context.metadata["dangerous_operation"] = True
return HookResult(
hook_name="dangerous_confirmation",
success=True,
continue_execution=True,
metadata={"dangerous": True, "requires_confirmation": True},
)
return HookResult(
hook_name="dangerous_confirmation",
success=True,
continue_execution=True,
)
def _check_dangerous(self, tool_name: str, tool_input: dict[str, Any]) -> bool:
"""检查是否为危险操作"""
# 检查工具名称
dangerous_tools = [
"delete",
"remove",
"drop",
"truncate",
"kill",
"shutdown",
"reboot",
"bash",
"powershell",
"shell",
]
if tool_name.lower() in dangerous_tools:
return True
# 检查输入参数
input_str = str(tool_input).lower()
for pattern in DANGEROUS_PATTERNS:
if pattern.lower() in input_str:
return True
return False
def confirm(self, session_id: str, confirmed: bool) -> None:
"""确认危险操作
Args:
session_id: 会话 ID
confirmed: True 表示用户确认False 表示取消
"""
self._pending_confirmations[session_id] = confirmed
def is_confirmed(self, session_id: str) -> bool:
"""检查是否已确认"""
return self._pending_confirmations.get(session_id, False)
def clear_confirmation(self, session_id: str) -> None:
"""清除确认状态"""
self._pending_confirmations.pop(session_id, None)

View File

@@ -0,0 +1,183 @@
"""安全扫描 Hook - Phase 7.2
扫描工具调用和结果中的敏感信息。
"""
import re
from typing import Any
from app.agents.tools.hooks.types import (
ExecutionContext,
HookResult,
)
# 敏感信息模式
SENSITIVE_PATTERNS = {
"api_key": [
r"api[_-]?key['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
r"apikey['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
],
"password": [
r"password['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
r"passwd['\"]?\s*[:=]\s*['\"]?[^\s'\"]{8,}",
r"secret['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-]{20,}",
],
"token": [
r"token['\"]?\s*[:=]\s*['\"]?[a-zA-Z0-9_\-\.]{20,}",
r"bearer\s+[a-zA-Z0-9_\-\.]+",
r"ghp_[a-zA-Z0-9]{36}",
r"sk-[a-zA-Z0-9]{48}",
],
"private_key": [
r"-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
r"-----END (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----",
],
"ip_address": [
r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
],
"email": [
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
],
}
class SecurityScanHook:
"""安全扫描 Hook
扫描工具输入和输出中的敏感信息,进行脱敏处理。
"""
def __init__(
self,
redact: bool = True,
block_on_detect: bool = False,
):
"""
Args:
redact: 是否对敏感信息进行脱敏
block_on_detect: 检测到敏感信息时是否阻止执行
"""
self.redact = redact
self.block_on_detect = block_on_detect
self._compiled_patterns = {
name: [re.compile(p, re.IGNORECASE) for p in patterns]
for name, patterns in SENSITIVE_PATTERNS.items()
}
async def pre_tool_use(self, context: ExecutionContext) -> HookResult:
"""扫描输入参数"""
detected = self._scan_dict(context.tool_input)
if detected:
context.metadata["security_detected"] = detected
if self.block_on_detect:
return HookResult(
hook_name="security_scan",
success=False,
continue_execution=False,
error=f"检测到敏感信息: {', '.join(detected.keys())}",
metadata={"detected": detected, "blocked": True},
)
if self.redact:
redacted_input = self._redact_dict(context.tool_input.copy())
return HookResult(
hook_name="security_scan",
success=True,
continue_execution=True,
modified_input=redacted_input,
metadata={"detected": detected, "redacted": True},
)
return HookResult(
hook_name="security_scan",
success=True,
continue_execution=True,
)
async def post_tool_use(self, context: ExecutionContext, result: Any) -> HookResult:
"""扫描输出结果"""
if isinstance(result, dict):
detected = self._scan_dict(result)
if detected:
context.metadata["security_detected_output"] = detected
if self.redact:
redacted_result = self._redact_dict(result.copy())
return HookResult(
hook_name="security_scan",
success=True,
continue_execution=True,
modified_output=redacted_result,
metadata={"detected": detected, "redacted": True},
)
elif isinstance(result, str):
detected = self._scan_string(result)
if detected:
context.metadata["security_detected_output"] = detected
if self.redact:
redacted_result = self._redact_string(result)
return HookResult(
hook_name="security_scan",
success=True,
continue_execution=True,
modified_output=redacted_result,
metadata={"detected": detected, "redacted": True},
)
return HookResult(
hook_name="security_scan",
success=True,
continue_execution=True,
modified_output=result,
)
def _scan_dict(self, data: dict[str, Any]) -> dict[str, list[str]]:
"""扫描字典中的敏感信息"""
result: dict[str, list[str]] = {}
for key, value in data.items():
if isinstance(value, str):
found = self._scan_string(value)
if found:
result[key] = found
return result
def _scan_string(self, text: str) -> list[str]:
"""扫描字符串中的敏感信息"""
found_types = []
for name, patterns in self._compiled_patterns.items():
for pattern in patterns:
if pattern.search(text):
if name not in found_types:
found_types.append(name)
break
return found_types
def _redact_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""脱敏字典中的敏感信息"""
for key, value in data.items():
if isinstance(value, str):
data[key] = self._redact_string(value)
elif isinstance(value, dict):
data[key] = self._redact_dict(value)
elif isinstance(value, list):
data[key] = [self._redact_string(v) if isinstance(v, str) else v for v in value]
return data
def _redact_string(self, text: str) -> str:
"""脱敏字符串中的敏感信息"""
for name, patterns in self._compiled_patterns.items():
for pattern in patterns:
text = pattern.sub(f"[REDACTED:{name}]", text)
return text

View File

@@ -0,0 +1,105 @@
"""Hook 配置持久化 - Phase 7.3"""
import json
import os
from dataclasses import asdict, dataclass
from typing import Any
from app.agents.tools.hooks.manager import get_hook_manager
@dataclass
class HookConfigEntry:
"""Hook 配置条目"""
name: str
hook_type: str
enabled: bool
tool_names: list[str] | None = None
categories: list[str] | None = None
priority: int = 0
class HookConfigPersistence:
"""Hook 配置持久化"""
def __init__(self, config_path: str | None = None):
"""
Args:
config_path: 配置文件路径None 则使用默认路径
"""
if config_path is None:
config_path = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "..", "config", "hooks.json"
)
self.config_path = config_path
def load_config(self) -> list[HookConfigEntry]:
"""从文件加载 Hook 配置"""
if not os.path.exists(self.config_path):
return []
try:
with open(self.config_path, "r", encoding="utf-8") as f:
data = json.load(f)
return [HookConfigEntry(**entry) for entry in data]
except Exception:
return []
def save_config(self, entries: list[HookConfigEntry]) -> bool:
"""保存 Hook 配置到文件"""
try:
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
with open(self.config_path, "w", encoding="utf-8") as f:
json.dump([asdict(e) for e in entries], f, indent=2, ensure_ascii=False)
return True
except Exception:
return False
def apply_config(self) -> int:
"""应用配置到 HookManager
Returns:
应用的 Hook 数量
"""
from app.agents.tools.hooks.types import HookType
manager = get_hook_manager()
entries = self.load_config()
count = 0
for entry in entries:
if entry.enabled:
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
trigger = HookTrigger(
tool_names=entry.tool_names,
categories=entry.categories,
)
# 创建空的 handler只是注册配置
hook_def = HookDefinition(
name=entry.name,
hook_type=HookType(entry.hook_type),
trigger=trigger,
handler=lambda ctx, *args: ctx,
priority=entry.priority,
enabled=True,
)
manager.register(hook_def)
count += 1
return count
# 全局单例
_persistence: HookConfigPersistence | None = None
def get_hook_config_persistence() -> HookConfigPersistence:
"""获取全局 Hook 配置持久化实例"""
global _persistence
if _persistence is None:
_persistence = HookConfigPersistence()
return _persistence

View File

@@ -0,0 +1,5 @@
"""自定义 Hook 加载器包"""
from app.agents.tools.hooks.custom.loader import CustomHookLoader, get_custom_hook_loader
__all__ = ["CustomHookLoader", "get_custom_hook_loader"]

View File

@@ -0,0 +1,153 @@
"""自定义 Hook 加载器 - Phase 7.4
支持动态加载用户自定义的 Hook。
"""
import importlib.util
import os
from typing import Any
from app.agents.tools.hooks.types import HookDefinition, HookType, HookTrigger, HookResult
class CustomHookLoader:
"""自定义 Hook 加载器
从指定目录动态加载自定义 Hook 模块。
"""
def __init__(self, hooks_dir: str | None = None):
"""
Args:
hooks_dir: Hook 目录None 则使用默认目录
"""
if hooks_dir is None:
hooks_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "data", "custom_hooks"
)
self.hooks_dir = hooks_dir
self._loaded_hooks: dict[str, HookDefinition] = {}
def load_all(self) -> list[HookDefinition]:
"""加载所有自定义 Hook
Returns:
Hook 定义列表
"""
hooks = []
if not os.path.exists(self.hooks_dir):
return hooks
for filename in os.listdir(self.hooks_dir):
if filename.endswith(".py") and not filename.startswith("_"):
hook_path = os.path.join(self.hooks_dir, filename)
hook_def = self._load_hook_from_file(hook_path, filename[:-3])
if hook_def:
hooks.append(hook_def)
self._loaded_hooks[hook_def.name] = hook_def
return hooks
def _load_hook_from_file(self, hook_path: str, module_name: str) -> HookDefinition | None:
"""从文件加载 Hook
Args:
hook_path: Hook 文件路径
module_name: 模块名
Returns:
Hook 定义或 None
"""
try:
spec = importlib.util.spec_from_file_location(module_name, hook_path)
if not spec or not spec.loader:
return None
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# 查找 HOOK_DEFINITION 或 hook_definition
hook_def = getattr(module, "HOOK_DEFINITION", None) or getattr(
module, "hook_definition", None
)
if hook_def and isinstance(hook_def, HookDefinition):
return hook_def
# 如果没有定义,尝试从函数自动推断
if hasattr(module, "pre_tool_hook") or hasattr(module, "post_tool_hook"):
return self._infer_hook_definition(module, module_name)
except Exception:
pass
return None
def _infer_hook_definition(self, module: Any, module_name: str) -> HookDefinition | None:
"""从模块函数推断 Hook 定义
Args:
module: 模块对象
module_name: 模块名
Returns:
Hook 定义或 None
"""
hook_type = None
handler = None
if hasattr(module, "pre_tool_hook"):
handler = module.pre_tool_hook
hook_type = HookType.PRE_TOOL_USE
elif hasattr(module, "post_tool_hook"):
handler = module.post_tool_hook
hook_type = HookType.POST_TOOL_USE
elif hasattr(module, "error_tool_hook"):
handler = module.error_tool_hook
hook_type = HookType.TOOL_ERROR
if not handler or not hook_type:
return None
return HookDefinition(
name=module_name,
hook_type=hook_type,
trigger=HookTrigger(),
handler=handler,
priority=0,
enabled=True,
description=f"Auto-loaded hook from {module_name}",
)
def get_hook(self, name: str) -> HookDefinition | None:
"""获取已加载的 Hook
Args:
name: Hook 名称
Returns:
Hook 定义或 None
"""
return self._loaded_hooks.get(name)
def reload(self) -> list[HookDefinition]:
"""重新加载所有 Hook
Returns:
重新加载的 Hook 列表
"""
self._loaded_hooks.clear()
return self.load_all()
# 全局加载器
_loader: CustomHookLoader | None = None
def get_custom_hook_loader() -> CustomHookLoader:
"""获取全局自定义 Hook 加载器"""
global _loader
if _loader is None:
_loader = CustomHookLoader()
return _loader

View File

@@ -0,0 +1,170 @@
"""Hook 执行器 - Phase 6.2
执行 Hook 拦截逻辑。
"""
import time
from typing import Any
from app.agents.tools.hooks.manager import get_hook_manager
from app.agents.tools.hooks.types import (
HookDefinition,
HookResult,
HookType,
ExecutionContext,
)
class HookExecutor:
"""Hook 执行器
负责在工具执行前后执行 Hook 逻辑。
"""
def __init__(self):
self._manager = get_hook_manager()
async def execute_pre_hooks(
self, context: ExecutionContext
) -> tuple[bool, dict[str, Any] | None]:
"""执行 pre-tool Hook
Args:
context: 执行上下文
Returns:
(是否继续执行, 修改后的输入)
"""
hooks = self._manager.get_hooks(HookType.PRE_TOOL_USE, context.tool_name)
modified_input = context.tool_input
for hook in hooks:
try:
# 调用 hook handler
handler = hook.handler
if callable(handler):
result = await self._call_hook(handler, context)
if result and not result.continue_execution:
# Hook 决定中断执行
return False, modified_input
if result.modified_input is not None:
modified_input = result.modified_input
except Exception as e:
# Hook 出错,默认继续执行
pass
return True, modified_input
async def execute_post_hooks(self, context: ExecutionContext, result: Any) -> Any:
"""执行 post-tool Hook
Args:
context: 执行上下文
result: 工具执行结果
Returns:
修改后的结果
"""
hooks = self._manager.get_hooks(HookType.POST_TOOL_USE, context.tool_name)
modified_result = result
for hook in hooks:
try:
handler = hook.handler
if callable(handler):
hook_result = await self._call_hook(handler, context, modified_result)
if hook_result and hook_result.modified_output is not None:
modified_result = hook_result.modified_output
except Exception:
# Hook 出错,默认保留原结果
pass
return modified_result
async def execute_error_hooks(
self, context: ExecutionContext, error: Exception
) -> HookResult | None:
"""执行 error Hook
Args:
context: 执行上下文
error: 异常
Returns:
Hook 结果,如果返回 None 则继续传播错误
"""
hooks = self._manager.get_hooks(HookType.TOOL_ERROR, context.tool_name)
for hook in hooks:
try:
handler = hook.handler
if callable(handler):
result = await self._call_hook(handler, context, error)
if result is not None and result.continue_execution:
return result
except Exception:
# Hook 出错,继续执行其他 error hooks
pass
return None
async def execute_skip_check(self, context: ExecutionContext) -> bool:
"""检查是否应跳过工具执行
Args:
context: 执行上下文
Returns:
True 表示跳过False 表示执行
"""
hooks = self._manager.get_hooks(HookType.TOOL_SKIP, context.tool_name)
for hook in hooks:
try:
handler = hook.handler
if callable(handler):
result = await self._call_hook(handler, context)
if result is not None and isinstance(result, bool):
return result
except Exception:
# Hook 出错,默认不跳过
pass
return False
async def _call_hook(
self, handler: Any, context: ExecutionContext, *args: Any
) -> HookResult | None:
"""调用 Hook 处理函数
Args:
handler: Hook 处理函数
context: 执行上下文
*args: 额外参数
Returns:
Hook 结果
"""
import asyncio
# 如果是普通函数,直接调用
if asyncio.iscoroutinefunction(handler):
return await handler(context, *args)
else:
return handler(context, *args)
# 全局单例
_executor: HookExecutor | None = None
def get_hook_executor() -> HookExecutor:
"""获取全局 Hook 执行器
Returns:
全局 HookExecutor 实例
"""
global _executor
if _executor is None:
_executor = HookExecutor()
return _executor

View File

@@ -0,0 +1,174 @@
"""Hook 管理器 - Phase 6.2
管理 Hook 的注册、查找和配置。
"""
from typing import Any
from app.agents.tools.hooks.types import (
HookDefinition,
HookResult,
HookTrigger,
HookType,
ExecutionContext,
)
class HookManager:
"""Hook 管理器
管理全局 Hook 的注册和配置。
"""
def __init__(self):
self._hooks: dict[HookType, list[HookDefinition]] = {
HookType.PRE_TOOL_USE: [],
HookType.POST_TOOL_USE: [],
HookType.TOOL_ERROR: [],
HookType.TOOL_SKIP: [],
}
self._global_hooks: list[HookDefinition] = [] # 全局 Hook对所有工具生效
def register(self, definition: HookDefinition) -> None:
"""注册 Hook
Args:
definition: Hook 定义
"""
if definition.trigger.tool_names is None and definition.trigger.categories is None:
# 全局 Hook
self._global_hooks.append(definition)
else:
# 特定工具 Hook
self._hooks[definition.hook_type].append(definition)
# 按优先级排序
self._hooks[definition.hook_type].sort(key=lambda h: h.priority, reverse=True)
self._global_hooks.sort(key=lambda h: h.priority, reverse=True)
def unregister(self, name: str) -> bool:
"""注销 Hook
Args:
name: Hook 名称
Returns:
是否成功注销
"""
# 从特定工具 Hook 中移除
for hooks in self._hooks.values():
for i, hook in enumerate(hooks):
if hook.name == name:
hooks.pop(i)
return True
# 从全局 Hook 中移除
for i, hook in enumerate(self._global_hooks):
if hook.name == name:
self._global_hooks.pop(i)
return True
return False
def get_hooks(self, hook_type: HookType, tool_name: str | None = None) -> list[HookDefinition]:
"""获取指定类型和工具的 Hook
Args:
hook_type: Hook 类型
tool_name: 工具名称(可选)
Returns:
匹配的 Hook 列表
"""
result: list[HookDefinition] = []
# 添加全局 Hook
for hook in self._global_hooks:
if hook.hook_type == hook_type and hook.enabled:
result.append(hook)
# 添加特定工具 Hook
for hook in self._hooks[hook_type]:
if not hook.enabled:
continue
if hook.trigger.tool_names is None and hook.trigger.categories is None:
continue
# 检查是否匹配
if hook.trigger.tool_names and tool_name not in hook.trigger.tool_names:
continue
result.append(hook)
return result
def list_all(self) -> list[HookDefinition]:
"""列出所有已注册的 Hook
Returns:
Hook 列表
"""
all_hooks = list(self._global_hooks)
for hooks in self._hooks.values():
all_hooks.extend(hooks)
return all_hooks
def enable(self, name: str) -> bool:
"""启用 Hook
Args:
name: Hook 名称
Returns:
是否成功启用
"""
for hook in self.list_all():
if hook.name == name:
hook.enabled = True
return True
return False
def disable(self, name: str) -> bool:
"""禁用 Hook
Args:
name: Hook 名称
Returns:
是否成功禁用
"""
for hook in self.list_all():
if hook.name == name:
hook.enabled = False
return True
return False
def clear(self) -> None:
"""清除所有 Hook"""
self._hooks = {ht: [] for ht in HookType}
self._global_hooks = []
# 全局单例
_global_hook_manager: HookManager | None = None
def get_hook_manager() -> HookManager:
"""获取全局 Hook 管理器
Returns:
全局 HookManager 实例
"""
global _global_hook_manager
if _global_hook_manager is None:
_global_hook_manager = HookManager()
return _global_hook_manager
def reset_hook_manager() -> None:
"""重置全局 Hook 管理器(用于测试)"""
global _global_hook_manager
if _global_hook_manager is not None:
_global_hook_manager.clear()
_global_hook_manager = None

View File

@@ -0,0 +1,90 @@
"""Hook 类型定义 - Phase 6.2
Hook 拦截系统类型定义。
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable
class HookType(Enum):
"""Hook 类型"""
PRE_TOOL_USE = "pre_tool_use" # 工具执行前
POST_TOOL_USE = "post_tool_use" # 工具执行后
TOOL_ERROR = "tool_error" # 工具执行出错
TOOL_SKIP = "tool_skip" # 工具跳过(条件执行)
class HookStage(Enum):
"""Hook 执行阶段"""
BEFORE = "before"
AFTER = "after"
ON_ERROR = "on_error"
@dataclass
class HookTrigger:
"""Hook 触发条件"""
tool_names: list[str] | None = None # 只对特定工具生效None 表示全部
categories: list[str] | None = None # 只对特定类别生效
conditions: dict[str, Any] | None = None # 自定义条件
@dataclass
class HookDefinition:
"""Hook 定义"""
name: str
hook_type: HookType
trigger: HookTrigger
handler: Callable[..., Any] # Hook 处理函数
priority: int = 0 # 优先级,数字越大越先执行
enabled: bool = True
description: str = ""
@dataclass
class HookResult:
"""Hook 执行结果"""
hook_name: str
success: bool
continue_execution: bool = True # False 表示中断执行
modified_input: Any = None # 修改后的输入
modified_output: Any = None # 修改后的输出
error: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ExecutionContext:
"""工具执行上下文"""
tool_name: str
tool_input: dict[str, Any]
user_id: str | None = None
session_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
# 执行结果(由 HookExecutor 填充)
result: Any = None
error: Exception | None = None
start_time: float | None = None
end_time: float | None = None
# Hook 处理函数类型
HookHandler = Callable[[ExecutionContext, HookDefinition], HookResult]
# Pre-hook: 在工具执行前调用,可以修改输入或决定是否跳过
PreToolHook = Callable[[ExecutionContext], tuple[bool, dict[str, Any] | None]]
# post-hook: 在工具执行后调用,可以修改输出
PostToolHook = Callable[[ExecutionContext, Any], Any]
# Error hook: 在工具出错时调用
ErrorToolHook = Callable[[ExecutionContext, Exception], HookResult | None]
# Skip hook: 决定是否跳过工具执行
SkipToolHook = Callable[[ExecutionContext], bool]

View File

@@ -0,0 +1,77 @@
"""工具元数据和数据类型定义"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class ToolCategory(Enum):
"""工具类别"""
READ = "read"
WRITE = "write"
EXTERNAL = "external"
DB_WRITE = "db_write"
NETWORK = "network"
class SideEffectScope(Enum):
"""副作用范围"""
NONE = "none"
LOCAL_STATE = "local_state"
DB_WRITE = "db_write"
NETWORK = "network"
class PermissionClass(Enum):
"""权限级别"""
READ = "read"
WRITE = "write"
EXTERNAL = "external"
@dataclass
class ToolManifest:
"""工具元数据"""
name: str
description: str
category: ToolCategory
parameters: dict[str, Any] # JSON Schema
return_schema: dict[str, Any]
permission_class: PermissionClass
side_effect_scope: SideEffectScope
requires_confirmation: bool = False
is_streaming: bool = False
tags: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"category": self.category.value,
"parameters": self.parameters,
"return_schema": self.return_schema,
"permission_class": self.permission_class.value,
"side_effect_scope": self.side_effect_scope.value,
"requires_confirmation": self.requires_confirmation,
"is_streaming": self.is_streaming,
"tags": self.tags,
}
@dataclass
class HookConfig:
"""Hook 配置"""
name: str
hook_type: str # "pre_tool_use", "post_tool_use", "tool_error", "tool_skip"
filter_names: list[str] | None = None # 只对特定工具生效None 表示全部
def matches_tool(self, tool_name: str) -> bool:
"""检查 Hook 是否对指定工具生效"""
if self.filter_names is None:
return True
return tool_name in self.filter_names

View File

@@ -0,0 +1,251 @@
"""工具迁移和向后兼容层 - Phase 6.1
将现有 @tool 装饰的工具迁移到 ToolRegistry同时保持向后兼容。
"""
from functools import wraps
from typing import Any, Callable
from app.agents.tools.manifest import (
PermissionClass,
SideEffectScope,
ToolCategory,
ToolManifest,
)
from app.agents.tools.registry import get_tool_registry
# 现有工具的类别映射
_TOOL_CATEGORY_MAP: dict[str, tuple[ToolCategory, PermissionClass, SideEffectScope]] = {
# 知识检索 - 只读
"search_knowledge": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
"get_knowledge_graph_context": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
"hybrid_search": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
"web_search": (ToolCategory.NETWORK, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
# 知识构建 - 写入
"build_knowledge_graph": (
ToolCategory.WRITE,
PermissionClass.WRITE,
SideEffectScope.LOCAL_STATE,
),
# 任务工具
"get_tasks": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
"create_task": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
"update_task_status": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
# 日程工具
"get_schedule_day": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
"create_todo": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
"create_schedule_task": (
ToolCategory.WRITE,
PermissionClass.WRITE,
SideEffectScope.LOCAL_STATE,
),
"create_reminder": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
"create_goal": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
"resolve_time_expression": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
# 论坛工具
"get_forum_posts": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
"create_forum_post": (ToolCategory.WRITE, PermissionClass.WRITE, SideEffectScope.LOCAL_STATE),
"scan_forum_for_instructions": (ToolCategory.READ, PermissionClass.READ, SideEffectScope.NONE),
}
def get_tool_category(name: str) -> tuple[ToolCategory, PermissionClass, SideEffectScope]:
"""获取工具的类别信息"""
return _TOOL_CATEGORY_MAP.get(
name,
(ToolCategory.EXTERNAL, PermissionClass.EXTERNAL, SideEffectScope.NETWORK),
)
def infer_tags_from_docstring(docstring: str | None) -> list[str]:
"""从 docstring 推断工具标签"""
if not docstring:
return []
tags = []
doc_lower = docstring.lower()
if "搜索" in docstring or "查询" in docstring or "search" in doc_lower:
tags.append("search")
if "创建" in docstring or "新建" in docstring or "create" in doc_lower:
tags.append("create")
if "获取" in docstring or "读取" in docstring or "get" in doc_lower:
tags.append("read")
if "更新" in docstring or "修改" in docstring or "update" in doc_lower:
tags.append("update")
return tags
def migrate_tool(tool_func: Callable) -> Callable:
"""将现有 @tool 装饰的函数迁移到 ToolRegistry
Args:
tool_func: LangChain @tool 装饰的函数
Returns:
原函数(已注册到 registry
"""
registry = get_tool_registry()
# 如果已经注册,跳过
if registry.get(tool_func.name):
return tool_func
# 获取类别信息
category, permission, side_effect = get_tool_category(tool_func.name)
# 从 docstring 提取 description
description = tool_func.description if hasattr(tool_func, "description") else ""
# 推断 tags
tags = infer_tags_from_docstring(description)
tags.append("migrated")
# 创建 manifest
manifest = ToolManifest(
name=tool_func.name,
description=description,
category=category,
parameters={}, # LangChain @tool 动态处理参数
return_schema={},
permission_class=permission,
side_effect_scope=side_effect,
requires_confirmation=side_effect != SideEffectScope.NONE,
is_streaming=False,
tags=tags,
)
# 注册到 registry
registry.register(manifest, tool_func)
return tool_func
def migrate_all_tools() -> int:
"""迁移所有现有工具到 ToolRegistry
Returns:
迁移的工具数量
"""
from app.agents.tools import (
ALL_TOOLS,
KNOWLEDGE_GRAPH_TOOLS,
KNOWLEDGE_RETRIEVAL_TOOLS,
SCHEDULE_READ_TOOLS,
SCHEDULE_WRITE_TOOLS,
TASK_TOOLS,
FORUM_TOOLS,
)
all_tools = (
KNOWLEDGE_RETRIEVAL_TOOLS
+ KNOWLEDGE_GRAPH_TOOLS
+ TASK_TOOLS
+ SCHEDULE_READ_TOOLS
+ SCHEDULE_WRITE_TOOLS
+ FORUM_TOOLS
)
count = 0
for tool in all_tools:
try:
migrate_tool(tool)
count += 1
except Exception as e:
print(f"Failed to migrate tool {getattr(tool, 'name', 'unknown')}: {e}")
return count
class BackwardCompatTool:
"""向后兼容工具包装器
确保现有代码通过 registry.get_executor() 仍能正常调用工具。
"""
def __init__(self, name: str):
self.name = name
self._registry = get_tool_registry()
def __call__(self, *args, **kwargs) -> Any:
executor = self._registry.get_executor(self.name)
if executor is None:
raise ValueError(f"Tool not found in registry: {self.name}")
return executor(*args, **kwargs)
def invoke(self, tool_input: dict[str, Any]) -> Any:
"""LangChain 风格的 invoke 调用"""
executor = self._registry.get_executor(self.name)
if executor is None:
raise ValueError(f"Tool not found in registry: {self.name}")
# 处理位置参数
if isinstance(tool_input, dict):
return executor(**tool_input)
return executor(tool_input)
def create_compat_layer() -> dict[str, BackwardCompatTool]:
"""创建向后兼容层
返回一个字典,允许通过名称访问兼容的工具包装器。
"""
registry = get_tool_registry()
tools = registry.list_all()
return {tool.name: BackwardCompatTool(tool.name) for tool in tools}
# 自动迁移装饰器
def auto_migrate(func: Callable) -> Callable:
"""自动迁移装饰器
用于装饰新的 @tool 函数,自动注册到 registry。
"""
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
# 迁移到 registry
migrate_tool(wrapper)
return wrapper
# 便捷函数:获取兼容的工具执行器
def get_tool_executor(name: str) -> Callable | None:
"""获取工具执行器(兼容层)
优先从 registry 获取fallback 到直接导入。
"""
registry = get_tool_registry()
executor = registry.get_executor(name)
if executor is not None:
return executor
# Fallback: 直接从模块导入(仅用于迁移期间)
try:
from app.agents.tools import (
TASK_TOOLS,
SCHEDULE_READ_TOOLS,
SCHEDULE_WRITE_TOOLS,
FORUM_TOOLS,
KNOWLEDGE_RETRIEVAL_TOOLS,
)
all_tools = (
KNOWLEDGE_RETRIEVAL_TOOLS
+ TASK_TOOLS
+ SCHEDULE_READ_TOOLS
+ SCHEDULE_WRITE_TOOLS
+ FORUM_TOOLS
)
for tool in all_tools:
if hasattr(tool, "name") and tool.name == name:
return tool
except ImportError:
pass
return None

View File

@@ -0,0 +1,206 @@
"""工具注册表 - 工具系统重构 Phase 6.1"""
from collections import defaultdict
from typing import Any, Callable
from app.agents.tools.manifest import HookConfig, ToolManifest
class ToolRegistry:
"""工具注册表
统一管理所有工具的注册、发现和调用。
支持工具元数据、权限分类、Hook 拦截。
"""
def __init__(self):
self._tools: dict[str, ToolManifest] = {}
self._executors: dict[str, Callable] = {}
self._hooks: dict[str, list[HookConfig]] = defaultdict(list)
def register(
self, manifest: ToolManifest, executor: Callable, hooks: list[HookConfig] | None = None
) -> None:
"""注册工具
Args:
manifest: 工具元数据
executor: 工具执行函数
hooks: 可选的 Hook 配置列表
"""
if manifest.name in self._tools:
raise ValueError(f"Tool already registered: {manifest.name}")
self._tools[manifest.name] = manifest
self._executors[manifest.name] = executor
if hooks:
for hook in hooks:
self._hooks[manifest.name].append(hook)
def unregister(self, name: str) -> bool:
"""注销工具
Args:
name: 工具名称
Returns:
是否成功注销
"""
if name not in self._tools:
return False
del self._tools[name]
del self._executors[name]
if name in self._hooks:
del self._hooks[name]
return True
def get(self, name: str) -> ToolManifest | None:
"""获取工具元数据
Args:
name: 工具名称
Returns:
工具元数据,不存在返回 None
"""
return self._tools.get(name)
def get_executor(self, name: str) -> Callable | None:
"""获取工具执行器
Args:
name: 工具名称
Returns:
工具执行函数,不存在返回 None
"""
return self._executors.get(name)
def get_hooks(self, name: str) -> list[HookConfig]:
"""获取工具的 Hook 配置
Args:
name: 工具名称
Returns:
Hook 配置列表
"""
return self._hooks.get(name, [])
def list_all(self) -> list[ToolManifest]:
"""列出所有已注册的工具
Returns:
工具元数据列表
"""
return list(self._tools.values())
def list_by_category(self, category: Any) -> list[ToolManifest]:
"""按类别列出工具
Args:
category: 工具类别
Returns:
该类别下的所有工具
"""
return [t for t in self._tools.values() if t.category == category]
def list_by_permission(self, permission: Any) -> list[ToolManifest]:
"""按权限级别列出工具
Args:
permission: 权限级别
Returns:
该权限级别下的所有工具
"""
return [t for t in self._tools.values() if t.permission_class == permission]
def search_by_tag(self, tag: str) -> list[ToolManifest]:
"""按标签搜索工具
Args:
tag: 标签
Returns:
包含该标签的工具
"""
return [t for t in self._tools.values() if tag in t.tags]
def search_by_name(self, keyword: str) -> list[ToolManifest]:
"""按名称关键词搜索工具
Args:
keyword: 关键词
Returns:
名称包含关键词的工具
"""
keyword = keyword.lower()
return [t for t in self._tools.values() if keyword in t.name.lower()]
def get_requires_confirmation(self, name: str) -> bool:
"""检查工具是否需要确认
Args:
name: 工具名称
Returns:
是否需要确认
"""
manifest = self._tools.get(name)
return manifest.requires_confirmation if manifest else False
def get_is_streaming(self, name: str) -> bool:
"""检查工具是否支持流式执行
Args:
name: 工具名称
Returns:
是否支持流式
"""
manifest = self._tools.get(name)
return manifest.is_streaming if manifest else False
def clear(self) -> None:
"""清空注册表"""
self._tools.clear()
self._executors.clear()
self._hooks.clear()
def __len__(self) -> int:
return len(self._tools)
def __contains__(self, name: str) -> bool:
return name in self._tools
def __iter__(self):
return iter(self._tools.values())
# 全局单例实例
_global_registry: ToolRegistry | None = None
def get_tool_registry() -> ToolRegistry:
"""获取全局工具注册表单例
Returns:
全局 ToolRegistry 实例
"""
global _global_registry
if _global_registry is None:
_global_registry = ToolRegistry()
return _global_registry
def reset_tool_registry() -> None:
"""重置全局工具注册表(用于测试)"""
global _global_registry
if _global_registry is not None:
_global_registry.clear()
_global_registry = None

View File

@@ -2,8 +2,6 @@
from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import date, datetime
from zoneinfo import ZoneInfo
@@ -11,21 +9,16 @@ from langchain_core.tools import tool
from sqlalchemy import select
from app.agents.context import get_current_user
from app.agents.tools.async_bridge import run_async
from app.database import async_session
from app.models.goal import Goal, GoalStatus
from app.models.reminder import Reminder
from app.models.task import Task, TaskPriority, TaskStatus
from app.models.todo import DailyTodo, TodoSource
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
return run_async(coro, timeout=timeout)
def _parse_date(value: str | None) -> date:

View File

@@ -5,25 +5,16 @@ Agent 工具集 - 知识库 & 图谱相关
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
"""
from concurrent.futures import ThreadPoolExecutor
import asyncio
from langchain_core.tools import tool
from app.agents.context import get_current_user
from app.agents.tools.async_bridge import run_async
from app.database import async_session
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
"""在同步上下文中运行 async 代码"""
try:
loop = asyncio.get_running_loop()
future = loop.run_in_executor(_executor, lambda: asyncio.run(coro))
return future.result(timeout=timeout)
except RuntimeError:
return asyncio.run(coro)
return run_async(coro, timeout=timeout)
@tool

View File

@@ -0,0 +1,210 @@
"""流式工具执行器 - Phase 6.3
支持流式输出的工具执行器。
"""
import asyncio
import json
from typing import Any, AsyncGenerator
from app.agents.tools.hooks.executor import get_hook_executor
from app.agents.tools.hooks.types import ExecutionContext
from app.agents.tools.registry import get_tool_registry
class StreamingToolExecutor:
"""流式工具执行器
支持:
- 普通工具的同步/异步执行
- 流式工具的流式输出
- Hook 拦截pre/post/error
"""
def __init__(self):
self._registry = get_tool_registry()
self._hook_executor = get_hook_executor()
async def execute(
self,
tool_name: str,
tool_input: dict[str, Any],
user_id: str | None = None,
) -> Any:
"""执行工具(非流式)
Args:
tool_name: 工具名称
tool_input: 工具输入参数
user_id: 用户 ID可选
Returns:
工具执行结果
"""
# 创建执行上下文
context = ExecutionContext(
tool_name=tool_name,
tool_input=tool_input,
user_id=user_id,
)
# 获取工具和执行器
manifest = self._registry.get(tool_name)
if manifest is None:
raise ValueError(f"Tool not found: {tool_name}")
executor = self._registry.get_executor(tool_name)
if executor is None:
raise ValueError(f"Executor not found for tool: {tool_name}")
# 检查是否跳过
if await self._hook_executor.execute_skip_check(context):
return {"skipped": True, "tool": tool_name}
# 执行 pre-hooks
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
if not continue_execution:
return {"pre_hook_aborted": True, "tool": tool_name}
# 执行工具
try:
context.start_time = asyncio.get_event_loop().time()
# 判断是同步还是异步执行
if asyncio.iscoroutinefunction(executor):
result = await executor(**modified_input)
else:
# 同步函数在线程池中执行
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(None, lambda: executor(**modified_input))
context.result = result
context.end_time = asyncio.get_event_loop().time()
# 执行 post-hooks
result = await self._hook_executor.execute_post_hooks(context, result)
return result
except Exception as e:
context.error = e
context.end_time = asyncio.get_event_loop().time()
# 执行 error-hooks
error_result = await self._hook_executor.execute_error_hooks(context, e)
if error_result is not None:
return error_result
raise
async def execute_streaming(
self,
tool_name: str,
tool_input: dict[str, Any],
user_id: str | None = None,
) -> AsyncGenerator[dict[str, Any], None]:
"""执行流式工具
Args:
tool_name: 工具名称
tool_input: 工具输入参数
user_id: 用户 ID可选
Yields:
流式输出片段
"""
# 创建执行上下文
context = ExecutionContext(
tool_name=tool_name,
tool_input=tool_input,
user_id=user_id,
)
# 获取工具和执行器
manifest = self._registry.get(tool_name)
if manifest is None:
raise ValueError(f"Tool not found: {tool_name}")
if not manifest.is_streaming:
raise ValueError(f"Tool is not streaming: {tool_name}")
executor = self._registry.get_executor(tool_name)
if executor is None:
raise ValueError(f"Executor not found for tool: {tool_name}")
# 检查是否跳过
if await self._hook_executor.execute_skip_check(context):
yield {"type": "skipped", "tool": tool_name}
return
# 执行 pre-hooks
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
if not continue_execution:
yield {"type": "pre_hook_aborted", "tool": tool_name}
return
# 执行流式工具
try:
context.start_time = asyncio.get_event_loop().time()
# 调用执行器(应该返回 AsyncGenerator
result = executor(**modified_input)
# 如果是 async generator
if asyncio.isasyncgen(result):
async for chunk in result:
yield {"type": "chunk", "data": chunk}
else:
# 普通协程
data = await result
yield {"type": "chunk", "data": data}
context.end_time = asyncio.get_event_loop().time()
yield {"type": "done", "tool": tool_name}
except Exception as e:
context.error = e
context.end_time = asyncio.get_event_loop().time()
yield {"type": "error", "error": str(e), "tool": tool_name}
# 执行 error-hooks
await self._hook_executor.execute_error_hooks(context, e)
async def execute_batch(
self,
tool_calls: list[dict[str, Any]],
user_id: str | None = None,
) -> list[Any]:
"""批量执行工具
Args:
tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input
user_id: 用户 ID可选
Returns:
执行结果列表
"""
tasks = []
for call in tool_calls:
tool_name = call.get("tool_name") or call.get("name")
tool_input = call.get("tool_input") or call.get("input") or {}
tasks.append(self.execute(tool_name, tool_input, user_id))
return await asyncio.gather(*tasks, return_exceptions=True)
# 全局单例
_executor: StreamingToolExecutor | None = None
def get_streaming_executor() -> StreamingToolExecutor:
"""获取全局流式执行器
Returns:
全局 StreamingToolExecutor 实例
"""
global _executor
if _executor is None:
_executor = StreamingToolExecutor()
return _executor

View File

@@ -8,21 +8,13 @@ from langchain_core.tools import tool
from sqlalchemy import select
from app.agents.context import get_current_user
from app.agents.tools.async_bridge import run_async
from app.database import async_session
from app.models.task import Task, TaskPriority, TaskStatus
import asyncio
from concurrent.futures import ThreadPoolExecutor
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
return run_async(coro, timeout=timeout)
def _normalize_title(title: str | None, content: str | None) -> str:

View File

@@ -241,6 +241,10 @@ def normalize_tool_time_arguments(tool_name: str, args: dict, current_datetime_c
if raw_value and not _is_iso_datetime(raw_value):
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
normalized["reminder_at"] = payload["resolved_datetime"]
raw_date = normalized.get("date")
if isinstance(raw_date, str) and raw_date.strip() and not _is_iso_date(raw_date):
payload = resolve_time_expression_data(raw_date, current_datetime_context=current_datetime_context, prefer="date")
normalized["date"] = payload["resolved_date"]
return normalized
if tool_name in {"create_schedule_task", "create_task"}:

View File

@@ -0,0 +1,113 @@
"""远程传输层 - Phase 10.2"""
import asyncio
import json
from typing import Any
from dataclasses import dataclass
@dataclass
class StructuredMessage:
"""结构化消息"""
type: str # response, event, tool_call, error
data: dict[str, Any]
session_id: str | None = None
class RemoteTransport:
"""远程传输层
处理与远程 Agent 的通信。
"""
def __init__(self):
self._connections: dict[str, Any] = {}
self._handlers: dict[str, Any] = {}
async def send_response(self, session_id: str, response: dict[str, Any]) -> bool:
"""发送响应
Args:
session_id: 会话 ID
response: 响应数据
Returns:
是否发送成功
"""
message = StructuredMessage(
type="response",
data=response,
session_id=session_id,
)
return await self._send(session_id, message)
async def send_event(self, session_id: str, event: dict[str, Any]) -> bool:
"""发送事件
Args:
session_id: 会话 ID
event: 事件数据
Returns:
是否发送成功
"""
message = StructuredMessage(
type="event",
data=event,
session_id=session_id,
)
return await self._send(session_id, message)
async def send_tool_call(self, session_id: str, tool_call: dict[str, Any]) -> bool:
"""发送工具调用
Args:
session_id: 会话 ID
tool_call: 工具调用数据
Returns:
是否发送成功
"""
message = StructuredMessage(
type="tool_call",
data=tool_call,
session_id=session_id,
)
return await self._send(session_id, message)
async def _send(self, session_id: str, message: StructuredMessage) -> bool:
"""内部发送方法"""
if session_id not in self._connections:
return False
try:
connection = self._connections[session_id]
if hasattr(connection, "send"):
await connection.send(json.dumps(message.__dict__))
return True
except Exception:
pass
return False
def register_handler(self, event_type: str, handler: Any) -> None:
"""注册消息处理器
Args:
event_type: 事件类型
handler: 处理函数
"""
self._handlers[event_type] = handler
async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
"""处理收到的消息
Args:
session_id: 会话 ID
message: 消息数据
"""
msg_type = message.get("type")
handler = self._handlers.get(msg_type)
if handler:
await handler(session_id, message.get("data"))

View File

@@ -0,0 +1,86 @@
"""Structured IO for typed input/output - Phase 10.2"""
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
T = TypeVar("T")
@dataclass
class StructuredInput:
"""Structured input wrapper"""
skill_name: str
parameters: dict[str, Any]
metadata: dict[str, Any]
@dataclass
class StructuredOutput:
"""Structured output wrapper"""
skill_name: str
result: Any
success: bool
error: str | None = None
metadata: dict[str, Any] | None = None
class StructuredIO:
"""Handles structured input/output for agent communication"""
def parse_input(self, data: dict[str, Any]) -> StructuredInput:
"""Parse structured input from dictionary.
Args:
data: Dictionary containing skill_name, parameters, and metadata
Returns:
StructuredInput instance
Raises:
ValueError: If required fields are missing
"""
if not isinstance(data, dict):
raise ValueError("Input data must be a dictionary")
skill_name = data.get("skill_name")
if not skill_name:
raise ValueError("Missing required field: skill_name")
if not isinstance(skill_name, str):
raise ValueError("skill_name must be a string")
parameters = data.get("parameters")
if parameters is None:
raise ValueError("Missing required field: parameters")
if not isinstance(parameters, dict):
raise ValueError("parameters must be a dictionary")
metadata = data.get("metadata", {})
if not isinstance(metadata, dict):
raise ValueError("metadata must be a dictionary")
return StructuredInput(skill_name=skill_name, parameters=parameters, metadata=metadata)
def format_output(self, output: StructuredOutput) -> dict[str, Any]:
"""Format structured output to dictionary.
Args:
output: StructuredOutput instance
Returns:
Dictionary representation of the output
"""
result = {
"skill_name": output.skill_name,
"result": output.result,
"success": output.success,
}
if output.error is not None:
result["error"] = output.error
if output.metadata is not None:
result["metadata"] = output.metadata
return result

View File

@@ -0,0 +1,207 @@
"""WebSocket 连接管理 - Phase 10.2
管理 WebSocket 连接的生命周期。
"""
import asyncio
import json
from typing import Any, Callable
from dataclasses import dataclass
@dataclass
class WSConnection:
"""WebSocket 连接"""
session_id: str
websocket: Any # WebSocket 连接
user_id: str | None = None
created_at: float | None = None
last_ping: float | None = None
class WebSocketManager:
"""WebSocket 连接管理器
管理所有 WebSocket 连接的生命周期。
"""
def __init__(self, ping_interval: float = 30.0):
"""
Args:
ping_interval: 心跳间隔(秒)
"""
self._connections: dict[str, WSConnection] = {}
self._handlers: dict[str, Callable] = {}
self._ping_interval = ping_interval
self._ping_tasks: dict[str, asyncio.Task] = {}
async def connect(self, session_id: str, websocket: Any, user_id: str | None = None) -> bool:
"""建立连接
Args:
session_id: 会话 ID
websocket: WebSocket 连接
user_id: 用户 ID
Returns:
是否连接成功
"""
import time
if session_id in self._connections:
return False
conn = WSConnection(
session_id=session_id,
websocket=websocket,
user_id=user_id,
created_at=time.time(),
last_ping=time.time(),
)
self._connections[session_id] = conn
# 启动心跳
self._ping_tasks[session_id] = asyncio.create_task(self._ping_loop(session_id))
return True
async def disconnect(self, session_id: str) -> bool:
"""断开连接
Args:
session_id: 会话 ID
Returns:
是否断开成功
"""
if session_id not in self._connections:
return False
# 停止心跳
if session_id in self._ping_tasks:
self._ping_tasks[session_id].cancel()
del self._ping_tasks[session_id]
del self._connections[session_id]
return True
async def send(self, session_id: str, message: dict[str, Any]) -> bool:
"""发送消息
Args:
session_id: 会话 ID
message: 消息内容
Returns:
是否发送成功
"""
if session_id not in self._connections:
return False
try:
conn = self._connections[session_id]
await conn.websocket.send_json(message)
return True
except Exception:
return False
async def broadcast(self, message: dict[str, Any]) -> int:
"""广播消息
Args:
message: 消息内容
Returns:
发送成功的数量
"""
count = 0
for session_id in list(self._connections.keys()):
if await self.send(session_id, message):
count += 1
return count
async def _ping_loop(self, session_id: str) -> None:
"""心跳循环
Args:
session_id: 会话 ID
"""
import time
while session_id in self._connections:
await asyncio.sleep(self._ping_interval)
if session_id not in self._connections:
break
try:
conn = self._connections[session_id]
await conn.websocket.send_json({"type": "ping", "timestamp": time.time()})
conn.last_ping = time.time()
except Exception:
await self.disconnect(session_id)
break
def register_handler(self, event_type: str, handler: Callable) -> None:
"""注册消息处理器
Args:
event_type: 事件类型
handler: 处理函数
"""
self._handlers[event_type] = handler
async def handle_message(self, session_id: str, message: dict[str, Any]) -> None:
"""处理消息
Args:
session_id: 会话 ID
message: 消息内容
"""
msg_type = message.get("type")
handler = self._handlers.get(msg_type)
if handler:
await handler(session_id, message.get("data"))
def get_connection(self, session_id: str) -> WSConnection | None:
"""获取连接
Args:
session_id: 会话 ID
Returns:
连接信息或 None
"""
return self._connections.get(session_id)
def list_connections(self) -> list[WSConnection]:
"""列出所有连接
Returns:
连接列表
"""
return list(self._connections.values())
def is_connected(self, session_id: str) -> bool:
"""检查是否连接
Args:
session_id: 会话 ID
Returns:
是否已连接
"""
return session_id in self._connections
# 全局单例
_ws_manager: WebSocketManager | None = None
def get_websocket_manager() -> WebSocketManager:
"""获取全局 WebSocket 管理器"""
global _ws_manager
if _ws_manager is None:
_ws_manager = WebSocketManager()
return _ws_manager

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
from typing import Any, cast
from pydantic import BaseModel, Field
from app.agents.schemas.task import AgentTask, TaskResult, TaskResultStatus, VerificationStatus
from app.agents.state import AgentState
class VerificationVerdict(BaseModel):
status: VerificationStatus
summary: str | None = None
evidence: list[dict[str, Any]] = Field(default_factory=list)
def normalize_task_result(
task_result: TaskResult | dict[str, Any],
*,
default_task_id: str | None = None,
) -> TaskResult:
payload = task_result.model_dump(mode="json") if isinstance(task_result, TaskResult) else dict(task_result or {})
normalized_status = payload.get("status")
if normalized_status not in {"completed", "failed", "blocked", "passed", "skipped"}:
normalized_status = "failed"
return TaskResult(
task_id=str(payload.get("task_id") or default_task_id or "unknown-task"),
status=cast(TaskResultStatus, normalized_status),
summary=payload.get("summary"),
evidence=list(payload.get("evidence") or []),
owner_agent_id=payload.get("owner_agent_id"),
parent_task_id=payload.get("parent_task_id"),
child_task_ids=list(payload.get("child_task_ids") or []),
thread_id=payload.get("thread_id"),
message_id=payload.get("message_id"),
message_index=payload.get("message_index") if isinstance(payload.get("message_index"), int) else None,
interrupt_records=list(payload.get("interrupt_records") or []),
recovery_records=list(payload.get("recovery_records") or []),
budget_snapshot=payload.get("budget_snapshot") if isinstance(payload.get("budget_snapshot"), dict) else None,
next_action=payload.get("next_action"),
output_data=payload.get("output_data") if isinstance(payload.get("output_data"), dict) else None,
)
def verify_task_result(
*,
task: AgentTask | dict[str, Any] | None = None,
result: TaskResult | dict[str, Any] | None = None,
summary: str | None = None,
evidence: list[dict[str, Any]] | None = None,
status: VerificationStatus | None = None,
) -> VerificationVerdict:
normalized_result = result.model_dump() if isinstance(result, TaskResult) else dict(result or {})
normalized_task = task.model_dump() if isinstance(task, AgentTask) else dict(task or {})
normalized_summary = summary or normalized_result.get("summary") or normalized_task.get("result_summary")
normalized_evidence = list(evidence or normalized_result.get("evidence") or normalized_task.get("evidence") or [])
if status is not None:
return VerificationVerdict(status=status, summary=normalized_summary, evidence=normalized_evidence)
normalized_status = normalized_result.get("status")
if normalized_status in {"passed", "failed", "skipped"}:
inferred_status = normalized_status
elif normalized_status == "completed":
inferred_status = "passed"
elif normalized_status == "blocked":
inferred_status = "skipped"
elif normalized_result.get("success") is True:
inferred_status = "passed"
elif normalized_result.get("success") is False:
inferred_status = "failed"
elif normalized_summary or normalized_evidence:
inferred_status = "skipped"
else:
inferred_status = "failed"
normalized_summary = "No verification input available."
return VerificationVerdict(
status=inferred_status,
summary=normalized_summary,
evidence=normalized_evidence,
)
def apply_verification_verdict(state: AgentState, verdict: VerificationVerdict) -> AgentState:
next_state = dict(state)
next_state["verification_status"] = verdict.status
next_state["verification_summary"] = verdict.summary
next_state["verification_evidence"] = list(verdict.evidence)
return AgentState(**next_state)
__all__ = ["VerificationVerdict", "apply_verification_verdict", "normalize_task_result", "verify_task_result"]

View File

@@ -1,10 +1,13 @@
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
from collections.abc import AsyncGenerator
import os
import re
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
os.makedirs(settings.DATA_DIR, exist_ok=True)
engine = create_async_engine(
@@ -24,12 +27,9 @@ class Base(DeclarativeBase):
pass
async def get_db() -> AsyncSession:
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async with async_session() as session:
try:
yield session
finally:
await session.close()
yield session
async def init_db():
@@ -37,6 +37,7 @@ async def init_db():
await conn.run_sync(Base.metadata.create_all)
await ensure_log_columns(conn)
await ensure_message_columns(conn)
await ensure_conversation_columns(conn)
await ensure_document_columns(conn)
await ensure_user_columns(conn)
await ensure_forum_columns(conn)
@@ -79,6 +80,20 @@ async def ensure_message_columns(conn):
await conn.execute(text(ddl))
async def ensure_conversation_columns(conn):
rows = await _get_table_info(conn, 'conversations')
if not rows:
return
columns = {row[1] for row in rows}
required_columns = {
'agent_state': "ALTER TABLE conversations ADD COLUMN agent_state JSON",
}
for column, ddl in required_columns.items():
if column not in columns:
await conn.execute(text(ddl))
async def ensure_document_columns(conn):
result = await conn.execute(text("PRAGMA table_info(documents)"))
rows = result.fetchall()

View File

@@ -23,6 +23,11 @@ from app.routers import (
log_router,
system_router,
brain_router,
hooks_router,
plugins_router,
marketplace_router,
agent_skills_router,
agent_sessions_router,
)
from app.routers.scheduler import router as scheduler_router
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
@@ -40,15 +45,15 @@ import os
INSECURE_SECRET_KEYS = {
'change-me-in-production',
'change-me-to-a-random-secret-key',
'jarvis-secret-key-change-in-production',
"change-me-in-production",
"change-me-to-a-random-secret-key",
"jarvis-secret-key-change-in-production",
}
def validate_startup_security() -> None:
if not settings.DEBUG and settings.SECRET_KEY in INSECURE_SECRET_KEYS:
raise RuntimeError('SECRET_KEY must be changed before running with DEBUG disabled')
raise RuntimeError("SECRET_KEY must be changed before running with DEBUG disabled")
async def run_startup() -> None:
@@ -117,6 +122,11 @@ app.include_router(log_router)
app.include_router(system_router)
app.include_router(brain_router)
app.include_router(scheduler_router)
app.include_router(hooks_router)
app.include_router(plugins_router)
app.include_router(marketplace_router)
app.include_router(agent_skills_router)
app.include_router(agent_sessions_router)
@app.get("/api/health")

View File

@@ -9,6 +9,7 @@ class Conversation(BaseModel):
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
title = Column(String(500), nullable=True)
message_count = Column(Integer, default=0)
agent_state = Column(JSON, nullable=True)
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")

View File

@@ -15,3 +15,8 @@ from app.routers.skill import router as skill_router
from app.routers.log import router as log_router
from app.routers.system import router as system_router
from app.routers.brain import router as brain_router
from app.routers.hooks import router as hooks_router
from app.routers.plugins import router as plugins_router
from app.routers.plugins import _marketplace_router as marketplace_router
from app.routers.agent_skills import router as agent_skills_router
from app.routers.agent_sessions import router as agent_sessions_router

View File

@@ -1,12 +1,42 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.agents.registry import load_builtin_registry_indexes
from app.agents.runtime_metrics import coerce_cost_thresholds, estimate_token_cost, is_cost_budget_warning
from app.models.agent import Agent
from app.models.conversation import Conversation
from app.models.skill import Skill
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
from app.schemas.agent import (
AgentConfigOut,
AgentConfigUpdate,
AgentCreate,
AgentOut,
AgentStats,
AgentVisibilityCostByAgentOut,
AgentVisibilityCostOut,
AgentVisibilityCostSummaryOut,
AgentVisibilityEvidenceOut,
AgentVisibilityEventsResponse,
AgentVisibilityEventOut,
AgentVisibilityIsolationOut,
AgentVisibilityRuntimeSummaryOut,
AgentVisibilityTaskSummaryOut,
AgentVisibilityThreadMessageOut,
AgentVisibilityThreadOut,
AgentVisibilityTopologyNodeOut,
AgentVisibilityTopologyOut,
AgentVisibilityToolGovernanceItemOut,
AgentVisibilityToolGovernanceOut,
AgentVisibilityVerifierOut,
)
from app.services.agent_service import _extract_continuity_snapshot
router = APIRouter(prefix="/api/agents", tags=["Agent"])
@@ -21,6 +51,295 @@ SUB_COMMANDERS_BY_ROLE = {
"librarian": ["librarian_retrieval", "librarian_graph"],
"analyst": ["analyst_progress", "analyst_insights"],
}
ALLOWED_AGENT_ROLES = set(DEFAULT_AGENT_ROLES) | {
role
for sub_roles in SUB_COMMANDERS_BY_ROLE.values()
for role in sub_roles
}
def _parse_visibility_datetime(value: str | None) -> datetime | None:
if value is None:
return None
try:
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError as exc:
raise HTTPException(status_code=400, detail="时间参数必须是 ISO 8601 格式") from exc
async def _get_visibility_state(
conversation_id: str,
*,
current_user: User,
db: AsyncSession,
) -> dict[str, Any]:
result = await db.execute(
select(Conversation).where(
Conversation.id == conversation_id,
Conversation.user_id == current_user.id,
)
)
conversation = result.scalar_one_or_none()
if conversation is None:
raise HTTPException(status_code=404, detail="对话不存在")
snapshot = _extract_continuity_snapshot(conversation.agent_state)
if snapshot is None:
raise HTTPException(status_code=404, detail="当前会话暂无可视化运行时数据")
return snapshot
def _coerce_event_payload(event: dict[str, Any]) -> AgentVisibilityEventOut:
return AgentVisibilityEventOut.model_validate(event)
def _filter_events(
events: list[dict[str, Any]],
*,
agent_id: str | None,
thread_id: str | None,
event_type: str | None,
started_after: datetime | None,
ended_before: datetime | None,
) -> list[dict[str, Any]]:
filtered: list[dict[str, Any]] = []
for event in events:
if agent_id and event.get("agent_id") != agent_id:
continue
if thread_id and event.get("thread_id") != thread_id:
continue
if event_type and event.get("event_type") != event_type:
continue
timestamp_raw = event.get("timestamp")
timestamp = None
if isinstance(timestamp_raw, str):
try:
timestamp = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00"))
except ValueError:
timestamp = None
if started_after and timestamp and timestamp < started_after:
continue
if ended_before and timestamp and timestamp > ended_before:
continue
filtered.append(event)
return filtered
def _summarize_tasks(tasks: list[dict[str, Any]], task_results: list[dict[str, Any]]) -> list[AgentVisibilityTaskSummaryOut]:
result_by_task_id = {item.get("task_id"): item for item in task_results}
summaries: list[AgentVisibilityTaskSummaryOut] = []
for task in tasks:
task_id = str(task.get("task_id") or "")
result = result_by_task_id.get(task_id) or {}
evidence = result.get("evidence") or task.get("evidence") or []
summaries.append(
AgentVisibilityTaskSummaryOut(
task_id=task_id,
role=task.get("role"),
owner_agent_id=task.get("owner_agent_id") or result.get("owner_agent_id"),
status=result.get("status") or task.get("status"),
summary=result.get("summary") or task.get("result_summary"),
evidence_count=len(evidence),
)
)
return summaries
def _build_topology_nodes(
state: dict[str, Any],
tasks: list[dict[str, Any]],
task_results: list[dict[str, Any]],
) -> list[AgentVisibilityTopologyNodeOut]:
task_counts: dict[str, int] = {}
completed_counts: dict[str, int] = {}
for task in tasks:
owner = str(task.get("owner_agent_id") or "")
if owner:
task_counts[owner] = task_counts.get(owner, 0) + 1
for result in task_results:
owner = str(result.get("owner_agent_id") or "")
if owner and result.get("status") == "completed":
completed_counts[owner] = completed_counts.get(owner, 0) + 1
root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None
current_agent = str(state.get("current_agent") or "") or None
parent_agent_id = str(state.get("parent_agent_id") or "") or None
nodes: dict[str, AgentVisibilityTopologyNodeOut] = {}
if root_agent_id:
nodes[root_agent_id] = AgentVisibilityTopologyNodeOut(
agent_id=root_agent_id,
role=root_agent_id.split("-")[0],
parent_agent_id=parent_agent_id if root_agent_id != state.get("agent_id") else None,
source="root",
task_count=task_counts.get(root_agent_id, 0),
completed_task_count=completed_counts.get(root_agent_id, 0),
)
for agent_id in state.get("spawned_agent_ids") or []:
agent_id = str(agent_id)
nodes[agent_id] = AgentVisibilityTopologyNodeOut(
agent_id=agent_id,
role=agent_id.split("-")[0],
parent_agent_id=root_agent_id,
source="spawned",
task_count=task_counts.get(agent_id, 0),
completed_task_count=completed_counts.get(agent_id, 0),
)
if current_agent and current_agent not in nodes:
nodes[current_agent] = AgentVisibilityTopologyNodeOut(
agent_id=current_agent,
role=current_agent.split("-")[0],
parent_agent_id=None if current_agent == root_agent_id else root_agent_id,
source="current",
task_count=task_counts.get(current_agent, 0),
completed_task_count=completed_counts.get(current_agent, 0),
)
return list(nodes.values())
def _estimate_runtime_cost(input_tokens: int, output_tokens: int) -> float | None:
return estimate_token_cost(input_tokens, output_tokens)
def _build_cost_summary(
state: dict[str, Any],
*,
conversation_id: str,
) -> AgentVisibilityCostSummaryOut:
input_tokens = int(state.get("input_tokens") or 0)
output_tokens = int(state.get("output_tokens") or 0)
estimated_cost = _estimate_runtime_cost(input_tokens, output_tokens)
thresholds = coerce_cost_thresholds(state.get("cost_thresholds"))
total_budget_warning = bool(state.get("budget_warning") or False) or is_cost_budget_warning(
input_tokens,
output_tokens,
estimated_cost,
thresholds,
)
by_agent_items: list[AgentVisibilityCostByAgentOut] = []
for agent_id, payload in dict(state.get("cost_by_agent") or {}).items():
payload_dict = dict(payload or {})
agent_input_tokens = int(payload_dict.get("input_tokens") or 0)
agent_output_tokens = int(payload_dict.get("output_tokens") or 0)
agent_estimated_cost = payload_dict.get("estimated_cost")
if agent_estimated_cost is None:
agent_estimated_cost = _estimate_runtime_cost(agent_input_tokens, agent_output_tokens)
by_agent_items.append(
AgentVisibilityCostByAgentOut(
agent_id=str(payload_dict.get("agent_id") or agent_id),
input_tokens=agent_input_tokens,
output_tokens=agent_output_tokens,
total_tokens=int(payload_dict.get("total_tokens") or (agent_input_tokens + agent_output_tokens)),
estimated_cost=agent_estimated_cost,
budget_warning=bool(payload_dict.get("budget_warning") or False),
)
)
by_agent_items.sort(key=lambda item: item.total_tokens, reverse=True)
return AgentVisibilityCostSummaryOut(
conversation_id=conversation_id,
total=AgentVisibilityCostOut(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
estimated_cost=estimated_cost,
budget_warning=total_budget_warning,
),
thresholds=thresholds,
by_agent=by_agent_items,
)
def _build_tool_governance(
state: dict[str, Any],
*,
conversation_id: str,
) -> AgentVisibilityToolGovernanceOut:
indexes = load_builtin_registry_indexes()
tool_outcomes = [dict(item) for item in state.get("tool_outcomes") or [] if isinstance(item, dict)]
usage_count_by_tool: dict[str, int] = {}
last_result_preview_by_tool: dict[str, str | None] = {}
for item in tool_outcomes:
tool_name = str(item.get("tool_name") or "")
if tool_name == "search_web":
tool_name = "web_search"
if not tool_name:
continue
usage_count_by_tool[tool_name] = usage_count_by_tool.get(tool_name, 0) + 1
preview = item.get("result_preview")
if isinstance(preview, str) and preview:
last_result_preview_by_tool[tool_name] = preview
items = [
AgentVisibilityToolGovernanceItemOut(
capability_id=capability.capability_id,
tool_name=capability.tool_name,
permission_class=capability.permission_class.value,
side_effect_scope=capability.side_effect_scope.value,
supports_retry=capability.supports_retry,
idempotent=capability.idempotent,
safe_for_parallel_use=capability.safe_for_parallel_use,
requires_confirmation=capability.requires_confirmation,
usage_count=usage_count_by_tool.get(capability.tool_name, 0),
last_result_preview=last_result_preview_by_tool.get(capability.tool_name),
)
for capability in indexes.capability_by_id.values()
]
items.sort(key=lambda item: (-item.usage_count, item.tool_name))
return AgentVisibilityToolGovernanceOut(
conversation_id=conversation_id,
total_tools=len(items),
used_tools=sum(1 for item in items if item.usage_count > 0),
items=items,
upgrade_candidates=[
"worktree_manager",
"cost_inspector",
"runtime_event_drilldown",
"tool_policy_explorer",
],
)
def _build_runtime_summary(
state: dict[str, Any],
*,
conversation_id: str,
) -> AgentVisibilityRuntimeSummaryOut:
tasks = [dict(item) for item in state.get("active_tasks") or []]
task_results = [dict(item) for item in state.get("task_results") or []]
topology_nodes = _build_topology_nodes(state, tasks, task_results)
cost_summary = _build_cost_summary(state, conversation_id=conversation_id)
input_tokens = cost_summary.total.input_tokens
output_tokens = cost_summary.total.output_tokens
recent_events_raw = [dict(item) for item in (state.get("event_trace") or [])[-10:]]
isolation_mode = str(state.get("isolation_mode") or "none")
return AgentVisibilityRuntimeSummaryOut(
conversation_id=conversation_id,
execution_mode=state.get("execution_mode"),
current_phase=state.get("current_phase"),
current_checkpoint=state.get("current_checkpoint"),
phase_history=list(state.get("phase_history") or []),
checkpoint_history=list(state.get("checkpoint_history") or []),
verifier=AgentVisibilityVerifierOut(
conversation_id=conversation_id,
status=state.get("verification_status"),
summary=state.get("verification_summary"),
evidence=list(state.get("verification_evidence") or []),
),
isolation=AgentVisibilityIsolationOut(
mode=isolation_mode,
isolation_id=state.get("isolation_id"),
workspace_path=state.get("isolation_workspace_path"),
parent_conversation_id=state.get("isolation_parent_conversation_id") or state.get("parent_conversation_id"),
metadata=dict(state.get("isolation_metadata") or {}),
),
cost=cost_summary.total,
topology_node_count=len(topology_nodes),
active_task_count=len(tasks),
completed_task_count=sum(1 for item in task_results if item.get("status") == "completed"),
recent_events=[_coerce_event_payload(item) for item in recent_events_raw],
)
def record_agent_call(agent_id: str):
@@ -83,6 +402,7 @@ async def get_agent_hierarchy_stats(
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
async def get_agent_config(
agent_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(select(Agent).where(Agent.role == agent_id))
@@ -172,12 +492,189 @@ async def update_agent_config(
)
@router.get("/visibility/events", response_model=AgentVisibilityEventsResponse)
async def get_visibility_events(
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
agent_id: str | None = None,
thread_id: str | None = None,
event_type: str | None = None,
started_after: str | None = None,
ended_before: str | None = None,
limit: int = Query(default=50, ge=1, le=200),
offset: int = Query(default=0, ge=0),
):
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
events = [dict(item) for item in state.get("event_trace") or []]
filtered = _filter_events(
events,
agent_id=agent_id,
thread_id=thread_id,
event_type=event_type,
started_after=_parse_visibility_datetime(started_after),
ended_before=_parse_visibility_datetime(ended_before),
)
paged = filtered[offset:offset + limit]
return AgentVisibilityEventsResponse(
conversation_id=conversation_id,
total=len(filtered),
limit=limit,
offset=offset,
items=[_coerce_event_payload(item) for item in paged],
)
@router.get("/visibility/topology", response_model=AgentVisibilityTopologyOut)
async def get_visibility_topology(
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
tasks = [dict(item) for item in state.get("active_tasks") or []]
task_results = [dict(item) for item in state.get("task_results") or []]
nodes = _build_topology_nodes(state, tasks, task_results)
root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or "") or None
edges = [
{"parent_agent_id": root_agent_id, "child_agent_id": node.agent_id}
for node in nodes
if node.parent_agent_id and root_agent_id and node.agent_id != root_agent_id
]
return AgentVisibilityTopologyOut(
conversation_id=conversation_id,
root_agent_id=root_agent_id,
current_agent=str(state.get("current_agent") or "") or None,
nodes=nodes,
edges=edges,
tasks=_summarize_tasks(tasks, task_results),
task_hierarchy=dict(state.get("task_hierarchy") or {}),
)
@router.get("/visibility/tasks/{task_id}/evidence", response_model=AgentVisibilityEvidenceOut)
async def get_visibility_task_evidence(
task_id: str,
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
tasks = [dict(item) for item in state.get("active_tasks") or []]
task = next((item for item in tasks if item.get("task_id") == task_id), None)
task_results = [dict(item) for item in state.get("task_results") or []]
result = next((item for item in task_results if item.get("task_id") == task_id), None)
if task is None and result is None:
raise HTTPException(status_code=404, detail="任务不存在")
tool_outcomes = [
dict(evidence)
for evidence in (result or {}).get("evidence") or []
if isinstance(evidence, dict) and evidence.get("tool_name")
]
verification_entry = next(
(
dict(evidence)
for evidence in (result or {}).get("evidence") or []
if isinstance(evidence, dict) and evidence.get("type") == "verification"
),
None,
)
verifier = {
"status": (verification_entry or {}).get("status"),
"summary": (verification_entry or {}).get("summary"),
"evidence": [dict(item) for item in state.get("verification_evidence") or [] if item.get("task_id") == task_id],
}
return AgentVisibilityEvidenceOut(
conversation_id=conversation_id,
task_id=task_id,
task=task,
result=result,
tool_outcomes=tool_outcomes,
verifier=verifier,
)
@router.get("/visibility/threads/{thread_id}/messages", response_model=AgentVisibilityThreadOut)
async def get_visibility_thread_messages(
thread_id: str,
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
items = [
AgentVisibilityThreadMessageOut.model_validate(item)
for item in state.get("message_trace") or []
if isinstance(item, dict) and item.get("thread_id") == thread_id
]
if not items:
raise HTTPException(status_code=404, detail="线程不存在")
return AgentVisibilityThreadOut(
conversation_id=conversation_id,
thread_id=thread_id,
total=len(items),
items=items,
)
@router.get("/visibility/verifier", response_model=AgentVisibilityVerifierOut)
async def get_visibility_verifier(
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
return AgentVisibilityVerifierOut(
conversation_id=conversation_id,
status=state.get("verification_status"),
summary=state.get("verification_summary"),
evidence=list(state.get("verification_evidence") or []),
)
@router.get("/visibility/runtime-summary", response_model=AgentVisibilityRuntimeSummaryOut)
async def get_visibility_runtime_summary(
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
return _build_runtime_summary(state, conversation_id=conversation_id)
@router.get("/visibility/cost", response_model=AgentVisibilityCostSummaryOut)
async def get_visibility_cost(
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
return _build_cost_summary(state, conversation_id=conversation_id)
@router.get("/visibility/tools", response_model=AgentVisibilityToolGovernanceOut)
async def get_visibility_tools(
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
state = await _get_visibility_state(conversation_id, current_user=current_user, db=db)
return _build_tool_governance(state, conversation_id=conversation_id)
@router.post("", response_model=AgentOut, status_code=201)
async def create_agent(
data: AgentCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="仅管理员可创建 Agent")
if not data.spawn_permission:
raise HTTPException(status_code=400, detail="缺少 spawn_permission禁止直接创建 runtime agent")
if data.role not in ALLOWED_AGENT_ROLES:
raise HTTPException(status_code=400, detail="不支持的 Agent 角色")
agent = Agent(
name=data.name,
role=data.role,
@@ -193,6 +690,7 @@ async def create_agent(
@router.get("/{agent_id}", response_model=AgentOut)
async def get_agent(
agent_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(select(Agent).where(Agent.id == agent_id))

View File

@@ -0,0 +1,113 @@
"""Agent Session API 路由 - Phase 10.3"""
from typing import Any
from fastapi import APIRouter, HTTPException
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
router = APIRouter(prefix="/api/agent/sessions", tags=["Agent Sessions"])
@router.post("", response_model=dict[str, Any])
async def create_session(
user_id: str | None = None,
parent_session_id: str | None = None,
) -> dict[str, Any]:
"""创建新会话"""
session = create_agent_session(
user_id=user_id,
parent_session_id=parent_session_id,
)
return await session.initialize()
@router.get("/{session_id}", response_model=dict[str, Any])
async def get_session(session_id: str) -> dict[str, Any]:
"""获取会话信息"""
session = get_agent_session(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
return await session.get_session_summary()
@router.post("/{session_id}/message", response_model=dict[str, str])
async def process_message(
session_id: str,
message: str,
response: str,
) -> dict[str, str]:
"""处理消息"""
session = get_agent_session(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
await session.process_message(message, response)
return {"status": "recorded", "session_id": session_id}
@router.post("/{session_id}/spawn", response_model=dict[str, Any])
async def spawn_child_session(
session_id: str,
user_id: str | None = None,
) -> dict[str, Any]:
"""创建子会话"""
session = get_agent_session(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
child = await session.spawn_child_session(user_id=user_id)
return await child.get_session_summary()
@router.get("/{session_id}/history", response_model=dict[str, Any])
async def get_session_history(session_id: str) -> dict[str, Any]:
"""获取会话历史"""
session = get_agent_session(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
return {
"session_id": session_id,
"history": session.get_history(),
"count": len(session._history),
}
@router.post("/{session_id}/persist", response_model=dict[str, str])
async def persist_session(session_id: str) -> dict[str, str]:
"""持久化会话"""
session = get_agent_session(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
success = await session.persist()
if success:
return {"status": "persisted", "session_id": session_id}
raise HTTPException(status_code=500, detail="Failed to persist session")
@router.post("/{session_id}/metadata", response_model=dict[str, Any])
async def set_session_metadata(
session_id: str,
key: str,
value: Any,
) -> dict[str, Any]:
"""设置会话元数据"""
session = get_agent_session(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
session.add_metadata(key, value)
await session.persist()
return {"key": key, "value": value}
@router.get("/{session_id}/metadata/{key}", response_model=dict[str, Any])
async def get_session_metadata(
session_id: str,
key: str,
) -> dict[str, Any]:
"""获取会话元数据"""
session = get_agent_session(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
value = session.get_metadata(key)
if value is None:
raise HTTPException(status_code=404, detail=f"Metadata key '{key}' not found")
return {"key": key, "value": value}

View File

@@ -0,0 +1,126 @@
"""Agent Skills API 路由 - Phase 9.6
使用新的 SkillRegistry (file-based) 而不是 DB-based skill 系统。
"""
from typing import Any
from fastapi import APIRouter, HTTPException
from app.agents.skills.registry import get_skill_registry, SkillRegistry
router = APIRouter(prefix="/api/agent/skills", tags=["Agent Skills"])
def _skill_to_dict(skill) -> dict[str, Any]:
"""将 SkillMetadata 转换为字典"""
return {
"name": skill.name,
"description": skill.description,
"tags": skill.tags,
"triggers": skill.triggers,
"enabled": skill.enabled,
"content_preview": skill.content[:200] + "..."
if len(skill.content) > 200
else skill.content,
}
@router.get("", response_model=dict[str, Any])
async def list_agent_skills() -> dict[str, Any]:
"""列出所有已加载的 Agent Skills"""
registry = get_skill_registry()
skills = registry.list_all()
return {
"skills": [_skill_to_dict(s) for s in skills],
"count": len(skills),
}
@router.get("/search", response_model=dict[str, Any])
async def search_agent_skills(
query: str,
) -> dict[str, Any]:
"""搜索 Skills"""
registry = get_skill_registry()
results = registry.search(query)
return {
"skills": [_skill_to_dict(s) for s in results],
"count": len(results),
"query": query,
}
@router.get("/{skill_name}", response_model=dict[str, Any])
async def get_agent_skill(skill_name: str) -> dict[str, Any]:
"""获取指定 Skill 详情"""
registry = get_skill_registry()
skill = registry.get_skill(skill_name)
if not skill:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
return {
"name": skill.name,
"description": skill.description,
"tags": skill.tags,
"triggers": skill.triggers,
"enabled": skill.enabled,
"content": skill.content,
}
@router.get("/{skill_name}/context", response_model=dict[str, str])
async def get_skill_context(skill_name: str) -> dict[str, str]:
"""获取 Skill 上下文字符串"""
registry = get_skill_registry()
context = registry.get_skill_context([skill_name])
if not context:
raise HTTPException(
status_code=404, detail=f"Skill '{skill_name}' not found or not enabled"
)
return {"skill_name": skill_name, "context": context}
@router.post("/context/batch", response_model=dict[str, str])
async def get_batch_skill_context(
skill_names: list[str],
) -> dict[str, str]:
"""批量获取多个 Skill 的上下文"""
registry = get_skill_registry()
context = registry.get_skill_context(skill_names)
return {"skills": skill_names, "context": context}
@router.post("/reload", response_model=dict[str, Any])
async def reload_skills(
skills_dir: str | None = None,
) -> dict[str, Any]:
"""重新加载所有 Skills"""
registry = get_skill_registry()
# 清除旧 skills
for name in list(registry._skills.keys()):
registry.unregister(name)
# 重新加载
count = registry.load_all(skills_dir)
return {"loaded": count, "message": f"Loaded {count} skills"}
@router.post("/{skill_name}/enable", response_model=dict[str, str])
async def enable_skill(skill_name: str) -> dict[str, str]:
"""启用 Skill"""
registry = get_skill_registry()
skill = registry.get_skill(skill_name)
if not skill:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
skill.enabled = True
return {"status": "enabled", "skill_name": skill_name}
@router.post("/{skill_name}/disable", response_model=dict[str, str])
async def disable_skill(skill_name: str) -> dict[str, str]:
"""禁用 Skill"""
registry = get_skill_registry()
skill = registry.get_skill(skill_name)
if not skill:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
skill.enabled = False
return {"status": "disabled", "skill_name": skill_name}

View File

@@ -0,0 +1,241 @@
"""Hook API 路由 - Phase 7.5"""
from typing import Any
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from app.agents.tools.hooks import HookType
from app.agents.tools.hooks.builtins import (
AuditLogHook,
DangerousConfirmationHook,
SecurityScanHook,
)
from app.agents.tools.hooks.config import (
HookConfigEntry,
get_hook_config_persistence,
)
from app.agents.tools.hooks.manager import get_hook_manager
router = APIRouter(prefix="/api/hooks", tags=["Hooks"])
class HookInfo(BaseModel):
"""Hook 信息"""
name: str
hook_type: str
description: str
builtin: bool
class HookConfigUpdate(BaseModel):
"""更新 Hook 配置"""
entries: list[HookConfigEntry]
class HookConfigResponse(BaseModel):
"""Hook 配置响应"""
entries: list[dict[str, Any]]
count: int
class HookStatusResponse(BaseModel):
"""Hook 状态响应"""
name: str
enabled: bool
hook_type: str
registered: bool
# 内置 Hook 注册表
BUILTIN_HOOKS: dict[str, dict[str, str]] = {
"audit_log": {
"name": "audit_log",
"hook_type": "pre_tool_use,post_tool_use,tool_error",
"description": "审计日志 Hook - 记录所有工具调用",
"class": "AuditLogHook",
},
"dangerous_confirmation": {
"name": "dangerous_confirmation",
"hook_type": "pre_tool_use",
"description": "危险操作确认 Hook - 拦截危险工具调用",
"class": "DangerousConfirmationHook",
},
"security_scan": {
"name": "security_scan",
"hook_type": "post_tool_use",
"description": "安全扫描 Hook - 检测敏感信息泄露",
"class": "SecurityScanHook",
},
}
@router.get("/available", response_model=list[HookInfo])
async def list_available_hooks() -> list[HookInfo]:
"""列出所有可用的内置 Hook"""
return [
HookInfo(
name=info["name"],
hook_type=info["hook_type"],
description=info["description"],
builtin=True,
)
for info in BUILTIN_HOOKS.values()
]
@router.get("/config", response_model=HookConfigResponse)
async def get_hook_config() -> HookConfigResponse:
"""获取当前 Hook 配置"""
persistence = get_hook_config_persistence()
entries = persistence.load_config()
return HookConfigResponse(
entries=[vars(e) if isinstance(e, HookConfigEntry) else e for e in entries],
count=len(entries),
)
@router.post("/config", response_model=HookConfigResponse)
async def update_hook_config(
entries: list[HookConfigEntry],
) -> HookConfigResponse:
"""更新 Hook 配置"""
persistence = get_hook_config_persistence()
success = persistence.save_config(entries)
if not success:
raise HTTPException(status_code=500, detail="Failed to save hook config")
# 应用配置到 HookManager
manager = get_hook_manager()
manager.clear() # 清除旧配置
persistence.apply_config() # 应用新配置
return HookConfigResponse(
entries=[vars(e) if isinstance(e, HookConfigEntry) else e for e in entries],
count=len(entries),
)
@router.post("/apply-config", response_model=dict[str, Any])
async def apply_hook_config() -> dict[str, Any]:
"""应用配置文件到 HookManager"""
persistence = get_hook_config_persistence()
manager = get_hook_manager()
manager.clear()
count = persistence.apply_config()
return {"applied": count, "message": f"Applied {count} hook configurations"}
@router.get("/status", response_model=list[HookStatusResponse])
async def get_hook_status() -> list[HookStatusResponse]:
"""获取所有已注册 Hook 的状态"""
manager = get_hook_manager()
all_hooks = manager.list_all()
# 按名称索引已注册的 hooks
registered: dict[str, dict[str, Any]] = {}
for hook in all_hooks:
registered[hook.name] = {
"name": hook.name,
"enabled": hook.enabled,
"hook_type": hook.hook_type.value,
"registered": True,
}
# 合并内置 Hook 信息
result: list[HookStatusResponse] = []
seen: set[str] = set()
# 先添加已注册的
for hook in all_hooks:
result.append(
HookStatusResponse(
name=hook.name,
enabled=hook.enabled,
hook_type=hook.hook_type.value,
registered=True,
)
)
seen.add(hook.name)
# 再添加内置但未注册的
for name, info in BUILTIN_HOOKS.items():
if name not in seen:
result.append(
HookStatusResponse(
name=name,
enabled=False,
hook_type=info["hook_type"],
registered=False,
)
)
return result
@router.post("/{name}/enable", response_model=dict[str, str])
async def enable_hook(name: str) -> dict[str, str]:
"""启用指定 Hook"""
manager = get_hook_manager()
if manager.enable(name):
return {"status": "enabled", "name": name}
raise HTTPException(status_code=404, detail=f"Hook '{name}' not found")
@router.post("/{name}/disable", response_model=dict[str, str])
async def disable_hook(name: str) -> dict[str, str]:
"""禁用指定 Hook"""
manager = get_hook_manager()
if manager.disable(name):
return {"status": "disabled", "name": name}
raise HTTPException(status_code=404, detail=f"Hook '{name}' not found")
@router.post("/register-builtin", response_model=dict[str, str])
async def register_builtin_hook(
name: str,
hook_type: str = "pre_tool_use",
) -> dict[str, str]:
"""注册内置 Hook 到 HookManager"""
from app.agents.tools.hooks.types import HookDefinition, HookTrigger
manager = get_hook_manager()
if name == "audit_log":
hook_instance = AuditLogHook()
handler = hook_instance.pre_tool_use
hook_types = [HookType.PRE_TOOL_USE, HookType.POST_TOOL_USE, HookType.TOOL_ERROR]
elif name == "dangerous_confirmation":
hook_instance = DangerousConfirmationHook()
handler = hook_instance.pre_tool_use
hook_types = [HookType.PRE_TOOL_USE]
elif name == "security_scan":
hook_instance = SecurityScanHook()
handler = hook_instance.post_tool_use
hook_types = [HookType.POST_TOOL_USE]
else:
raise HTTPException(status_code=404, detail=f"Unknown builtin hook: {name}")
registered = []
for ht in hook_types:
hook_def = HookDefinition(
name=f"{name}_{ht.value}",
hook_type=ht,
trigger=HookTrigger(),
handler=handler,
priority=0,
enabled=True,
description=f"Built-in {name} hook",
)
manager.register(hook_def)
registered.append(ht.value)
return {
"status": "registered",
"name": name,
"hook_types": ", ".join(registered),
}

View File

@@ -0,0 +1,222 @@
"""Plugin API 路由 - Phase 8.6"""
import os
import tempfile
import zipfile
from typing import Any
import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from app.agents.plugins import get_plugin_manager, PluginManifest
router = APIRouter(prefix="/api/plugins", tags=["Plugins"])
class PluginInfo(BaseModel):
"""插件信息"""
id: str
name: str
version: str
description: str
author: str
enabled: bool
main: str
class PluginInstallRequest(BaseModel):
"""插件安装请求"""
plugin_path: str
class PluginListResponse(BaseModel):
"""插件列表响应"""
plugins: list[dict[str, Any]]
count: int
# 全局插件市场(简单内存实现)
_plugin_marketplace: list[dict[str, str]] = []
def _manifest_to_dict(manifest: PluginManifest, enabled: bool) -> dict[str, Any]:
"""将 PluginManifest 转换为字典"""
return {
"id": manifest.id,
"name": manifest.name,
"version": manifest.version,
"description": manifest.description,
"author": manifest.author,
"enabled": enabled,
"main": manifest.main,
}
@router.get("", response_model=PluginListResponse)
async def list_plugins() -> PluginListResponse:
"""列出所有已安装的插件"""
manager = get_plugin_manager()
plugins = manager.list_plugins()
result = []
for p in plugins:
enabled = manager.is_enabled(p.id)
result.append(_manifest_to_dict(p, enabled))
return PluginListResponse(plugins=result, count=len(result))
@router.get("/{plugin_id}", response_model=dict[str, Any])
async def get_plugin(plugin_id: str) -> dict[str, Any]:
"""获取指定插件信息"""
manager = get_plugin_manager()
manifest = manager.get_plugin(plugin_id)
if not manifest:
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
enabled = manager.is_enabled(plugin_id)
return _manifest_to_dict(manifest, enabled)
@router.post("/install", response_model=dict[str, str])
async def install_plugin(request: PluginInstallRequest) -> dict[str, str]:
"""安装插件"""
manager = get_plugin_manager()
if not os.path.exists(request.plugin_path):
raise HTTPException(status_code=400, detail="Plugin path does not exist")
if manager.install(request.plugin_path):
return {"status": "installed", "path": request.plugin_path}
raise HTTPException(status_code=500, detail="Failed to install plugin")
@router.post("/{plugin_id}/enable", response_model=dict[str, str])
async def enable_plugin(plugin_id: str) -> dict[str, str]:
"""启用插件"""
manager = get_plugin_manager()
if manager.enable(plugin_id):
return {"status": "enabled", "plugin_id": plugin_id}
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
@router.post("/{plugin_id}/disable", response_model=dict[str, str])
async def disable_plugin(plugin_id: str) -> dict[str, str]:
"""禁用插件"""
manager = get_plugin_manager()
if manager.disable(plugin_id):
return {"status": "disabled", "plugin_id": plugin_id}
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
@router.delete("/{plugin_id}", response_model=dict[str, str])
async def uninstall_plugin(plugin_id: str) -> dict[str, str]:
"""卸载插件"""
manager = get_plugin_manager()
if manager.uninstall(plugin_id):
return {"status": "uninstalled", "plugin_id": plugin_id}
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
@router.post("/{plugin_id}/reload", response_model=dict[str, str])
async def reload_plugin(plugin_id: str) -> dict[str, str]:
"""重新加载插件"""
manager = get_plugin_manager()
if manager.reload(plugin_id):
return {"status": "reloaded", "plugin_id": plugin_id}
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
# === Plugin Marketplace ===
_marketplace_router = APIRouter(prefix="/api/marketplace", tags=["Plugin Marketplace"])
@_marketplace_router.get("/plugins", response_model=dict[str, Any])
async def search_marketplace_plugins(
query: str | None = None,
category: str | None = None,
) -> dict[str, Any]:
"""搜索插件市场"""
results = _plugin_marketplace
if query:
results = [
p
for p in results
if query.lower() in p.get("name", "").lower()
or query.lower() in p.get("description", "").lower()
]
if category:
results = [p for p in results if p.get("category") == category]
return {"plugins": results, "count": len(results)}
@_marketplace_router.get("/plugins/{plugin_id}", response_model=dict[str, Any])
async def get_marketplace_plugin(plugin_id: str) -> dict[str, Any]:
"""获取市场中的插件详情"""
for plugin in _plugin_marketplace:
if plugin.get("id") == plugin_id:
return plugin
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found in marketplace")
@_marketplace_router.post("/plugins", response_model=dict[str, str])
async def add_to_marketplace(plugin: dict[str, str]) -> dict[str, str]:
"""添加插件到市场(仅供测试/开发)"""
if "id" not in plugin or "name" not in plugin:
raise HTTPException(status_code=400, detail="Plugin must have id and name")
# 移除已存在的同 ID 插件
global _plugin_marketplace
_plugin_marketplace = [p for p in _plugin_marketplace if p.get("id") != plugin["id"]]
_plugin_marketplace.append(plugin)
return {"status": "added", "id": plugin["id"]}
@_marketplace_router.post("/plugins/{plugin_id}/download", response_model=dict[str, str])
async def download_plugin(plugin_id: str) -> dict[str, str]:
"""从市场下载并安装插件"""
# Find plugin in marketplace
plugin = None
for p in _plugin_marketplace:
if p.get("id") == plugin_id:
plugin = p
break
if not plugin:
raise HTTPException(
status_code=404, detail=f"Plugin '{plugin_id}' not found in marketplace"
)
download_url = plugin.get("download_url")
if not download_url:
raise HTTPException(status_code=400, detail="Plugin has no download URL")
try:
# Download the plugin archive
async with httpx.AsyncClient() as client:
response = await client.get(download_url, timeout=60.0)
response.raise_for_status()
archive_content = response.content
# Extract to temp directory and install
with tempfile.TemporaryDirectory() as temp_dir:
archive_path = os.path.join(temp_dir, "plugin.zip")
with open(archive_path, "wb") as f:
f.write(archive_content)
extract_dir = os.path.join(temp_dir, "extracted")
with zipfile.ZipFile(archive_path, "r") as zf:
zf.extractall(extract_dir)
# Install the plugin
manager = get_plugin_manager()
if manager.install(extract_dir):
return {"status": "installed", "plugin_id": plugin_id}
raise HTTPException(status_code=500, detail="Failed to install plugin")
except httpx.HTTPError as e:
raise HTTPException(status_code=502, detail=f"Download failed: {str(e)}")
except zipfile.BadZipFile:
raise HTTPException(status_code=502, detail="Invalid plugin archive")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Installation failed: {str(e)}")

View File

@@ -1,4 +1,7 @@
from pydantic import BaseModel
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class AgentCreate(BaseModel):
@@ -6,6 +9,7 @@ class AgentCreate(BaseModel):
role: str
description: str | None = None
system_prompt: str
spawn_permission: bool = False
class AgentOut(BaseModel):
@@ -55,3 +59,163 @@ class AgentConfigOut(BaseModel):
selected_skill_ids: list[str]
model_config = {"from_attributes": True}
class AgentVisibilityEventOut(BaseModel):
event_id: str
event_type: str
timestamp: datetime
conversation_id: str | None = None
agent_id: str | None = None
sub_commander_id: str | None = None
task_id: str | None = None
parent_task_id: str | None = None
child_task_id: str | None = None
thread_id: str | None = None
message_id: str | None = None
interrupt_id: str | None = None
recovery_id: str | None = None
payload: dict[str, Any] = Field(default_factory=dict)
severity: str = "info"
class AgentVisibilityEventsResponse(BaseModel):
conversation_id: str
total: int
limit: int
offset: int
items: list[AgentVisibilityEventOut]
class AgentVisibilityTaskSummaryOut(BaseModel):
task_id: str
role: str | None = None
owner_agent_id: str | None = None
status: str | None = None
summary: str | None = None
evidence_count: int = 0
class AgentVisibilityTopologyNodeOut(BaseModel):
agent_id: str
role: str | None = None
parent_agent_id: str | None = None
source: str
task_count: int = 0
completed_task_count: int = 0
class AgentVisibilityTopologyOut(BaseModel):
conversation_id: str
root_agent_id: str | None = None
current_agent: str | None = None
nodes: list[AgentVisibilityTopologyNodeOut]
edges: list[dict[str, str]]
tasks: list[AgentVisibilityTaskSummaryOut]
task_hierarchy: dict[str, list[str]] = Field(default_factory=dict)
class AgentVisibilityEvidenceOut(BaseModel):
conversation_id: str
task_id: str
task: dict[str, Any] | None = None
result: dict[str, Any] | None = None
tool_outcomes: list[dict[str, Any]] = Field(default_factory=list)
verifier: dict[str, Any]
class AgentVisibilityThreadMessageOut(BaseModel):
message_id: str
thread_id: str
from_agent_id: str
to_agent_id: str
task_id: str | None = None
reply_to_message_id: str | None = None
message_type: str
content_summary: str
created_at: datetime
payload: dict[str, Any] = Field(default_factory=dict)
class AgentVisibilityThreadOut(BaseModel):
conversation_id: str
thread_id: str
total: int
items: list[AgentVisibilityThreadMessageOut]
class AgentVisibilityVerifierOut(BaseModel):
conversation_id: str
status: str | None = None
summary: str | None = None
evidence: list[dict[str, Any]] = Field(default_factory=list)
class AgentVisibilityIsolationOut(BaseModel):
mode: str = "none"
isolation_id: str | None = None
workspace_path: str | None = None
parent_conversation_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
class AgentVisibilityCostOut(BaseModel):
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
estimated_cost: float | None = None
budget_warning: bool = False
currency: str = "USD"
class AgentVisibilityCostByAgentOut(BaseModel):
agent_id: str
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
estimated_cost: float | None = None
budget_warning: bool = False
class AgentVisibilityCostSummaryOut(BaseModel):
conversation_id: str
total: AgentVisibilityCostOut
thresholds: dict[str, float] = Field(default_factory=dict)
by_agent: list[AgentVisibilityCostByAgentOut] = Field(default_factory=list)
class AgentVisibilityToolGovernanceItemOut(BaseModel):
capability_id: str
tool_name: str
permission_class: str
side_effect_scope: str
supports_retry: bool = False
idempotent: bool = False
safe_for_parallel_use: bool = False
requires_confirmation: bool = False
usage_count: int = 0
last_result_preview: str | None = None
class AgentVisibilityToolGovernanceOut(BaseModel):
conversation_id: str
total_tools: int = 0
used_tools: int = 0
items: list[AgentVisibilityToolGovernanceItemOut] = Field(default_factory=list)
upgrade_candidates: list[str] = Field(default_factory=list)
class AgentVisibilityRuntimeSummaryOut(BaseModel):
conversation_id: str
execution_mode: str | None = None
current_phase: str | None = None
current_checkpoint: str | None = None
phase_history: list[dict[str, Any]] = Field(default_factory=list)
checkpoint_history: list[dict[str, Any]] = Field(default_factory=list)
verifier: AgentVisibilityVerifierOut
isolation: AgentVisibilityIsolationOut
cost: AgentVisibilityCostOut
topology_node_count: int = 0
active_task_count: int = 0
completed_task_count: int = 0
recent_events: list[AgentVisibilityEventOut] = Field(default_factory=list)

View File

@@ -21,6 +21,7 @@ from app.models.conversation import Conversation, Message
from app.models.user import User
from app.agents.graph import get_agent_graph
from app.agents.context import set_current_user, clear_current_user
from app.agents.skills.registry import get_skill_registry
from app.services import memory_service
from app.services.brain_service import BrainService
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
@@ -30,6 +31,56 @@ from app.agents.state import initial_state
logger = logging.getLogger(__name__)
MEMORY_SECTION_HEADERS = (
"【用户记忆】",
"【之前对话摘要】",
"【知识大脑】",
)
def _split_memory_context_sections(memory_context: str | None) -> dict[str, str]:
text = (memory_context or "").strip()
if not text:
return {}
sections: dict[str, str] = {}
current_header: str | None = None
current_lines: list[str] = []
for line in text.splitlines():
stripped = line.strip()
if stripped in MEMORY_SECTION_HEADERS:
if current_header and current_lines:
sections[current_header] = "\n".join(current_lines).strip()
current_header = stripped
current_lines = [stripped]
continue
if current_header:
current_lines.append(line)
if current_header and current_lines:
sections[current_header] = "\n".join(current_lines).strip()
return sections
def _derive_role_memory_contexts(memory_context: str | None) -> dict[str, str | None]:
sections = _split_memory_context_sections(memory_context)
user_memory = sections.get("【用户记忆】")
summaries = sections.get("【之前对话摘要】")
knowledge = sections.get("【知识大脑】")
def _join_parts(*parts: str | None) -> str | None:
values = [part for part in parts if part]
return "\n\n".join(values) if values else None
return {
"schedule_context_summary": _join_parts(user_memory, summaries),
"knowledge_context": knowledge,
"analysis_report": _join_parts(summaries, knowledge),
}
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
capabilities = resolve_provider_capabilities(user_llm_config)
error_text = str(error).lower()
@@ -45,9 +96,8 @@ def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None
]
if isinstance(error, BadRequestError):
return (
getattr(capabilities, "provider", None) not in {"openai", "claude"}
and any(marker in error_text for marker in markers)
return getattr(capabilities, "provider", None) not in {"openai", "claude"} and any(
marker in error_text for marker in markers
)
return any(marker in error_text for marker in markers)
@@ -84,14 +134,165 @@ _CONTINUITY_SNAPSHOT_FIELDS = (
"current_agent",
"next_step",
"agent_trace",
"agent_id",
"parent_agent_id",
"root_agent_id",
"collaboration_depth",
"thread_id",
"last_message_id",
"message_sequence",
"spawned_agent_ids",
"current_sub_commander",
"active_sub_commanders",
"sub_commander_trace",
"event_trace",
"message_trace",
"active_tasks",
"task_results",
"task_hierarchy",
"verification_status",
"verification_summary",
"verification_evidence",
"isolation_mode",
"isolation_id",
"isolation_workspace_path",
"isolation_parent_conversation_id",
"isolation_metadata",
"input_tokens",
"output_tokens",
"estimated_cost",
"budget_warning",
"cost_by_agent",
"cost_thresholds",
"budget_state",
"collaboration_budget_history",
"current_phase",
"phase_history",
"current_checkpoint",
"checkpoint_history",
)
def _normalize_legacy_turn_context(turn_context: Any, current_agent: Any) -> dict[str, Any] | None:
if not isinstance(turn_context, dict):
return None
normalized = dict(turn_context)
active_agent = normalized.pop("active_agent", None)
active_sub_flow = normalized.pop("active_sub_flow", None)
if isinstance(active_agent, str) and active_agent and "active_agent" not in normalized:
normalized["active_agent"] = active_agent
if (
isinstance(active_sub_flow, str)
and active_sub_flow
and "active_sub_commander" not in normalized
):
normalized["active_sub_commander"] = active_sub_flow
if not normalized.get("active_agent") and isinstance(current_agent, str) and current_agent:
normalized["active_agent"] = current_agent
return normalized or None
def _normalize_legacy_pending_action(pending_action: Any) -> dict[str, Any] | None:
if not isinstance(pending_action, dict):
return None
normalized = dict(pending_action)
legacy_action_type = normalized.pop("action_type", None)
if legacy_action_type and "type" not in normalized:
normalized["type"] = legacy_action_type
legacy_agent = normalized.pop("agent", None)
legacy_sub_flow = normalized.pop("sub_flow", None)
if legacy_agent and "owner_agent" not in normalized:
normalized["owner_agent"] = legacy_agent
if legacy_sub_flow and "owner_sub_commander" not in normalized:
normalized["owner_sub_commander"] = legacy_sub_flow
legacy_status = normalized.get("status")
if legacy_status == "awaiting_confirmation":
normalized["status"] = "pending"
elif legacy_status == "awaiting_clarification":
normalized["status"] = "blocked_on_clarification"
return normalized or None
def _normalize_legacy_clarification_context(
clarification_context: Any,
pending_action: dict[str, Any] | None,
current_agent: Any,
) -> dict[str, Any] | None:
if not isinstance(clarification_context, dict):
return None
normalized = dict(clarification_context)
active_agent = normalized.pop("active_agent", None)
sub_flow = normalized.pop("sub_flow", None)
awaiting_user_input = normalized.pop("awaiting_user_input", None)
if isinstance(active_agent, str) and active_agent and "owning_agent" not in normalized:
normalized["owning_agent"] = active_agent
if isinstance(sub_flow, str) and sub_flow and "owning_sub_commander" not in normalized:
normalized["owning_sub_commander"] = sub_flow
if "target_action" not in normalized:
target_action = None
if pending_action:
pending_type = pending_action.get("type")
if isinstance(pending_type, str) and pending_type and pending_type != "clarification":
target_action = pending_type
if target_action is None and isinstance(sub_flow, str) and sub_flow.startswith("create_"):
target_action = sub_flow
if target_action:
normalized["target_action"] = target_action
if not normalized.get("owning_agent") and isinstance(current_agent, str) and current_agent:
normalized["owning_agent"] = current_agent
if awaiting_user_input is True and "status" not in normalized:
normalized["status"] = "pending"
return normalized or None
def _normalize_legacy_continuity_state(
continuity_state: Any,
clarification_context: dict[str, Any] | None,
) -> dict[str, Any] | None:
if not isinstance(continuity_state, dict):
return None
normalized = dict(continuity_state)
normalized.pop("active_agent", None)
normalized.pop("active_sub_flow", None)
legacy_status = normalized.get("status")
if legacy_status == "awaiting_clarification":
normalized["status"] = "fresh"
if clarification_context and "mode" not in normalized:
normalized["mode"] = "resume_after_clarification"
return normalized or None
def _normalize_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any]:
normalized = dict(state)
current_agent = normalized.get("current_agent")
pending_action = _normalize_legacy_pending_action(normalized.get("pending_action"))
clarification_context = _normalize_legacy_clarification_context(
normalized.get("clarification_context"),
pending_action,
current_agent,
)
continuity_state = _normalize_legacy_continuity_state(
normalized.get("continuity_state"),
clarification_context,
)
turn_context = _normalize_legacy_turn_context(normalized.get("turn_context"), current_agent)
if pending_action is not None:
normalized["pending_action"] = pending_action
if clarification_context is not None:
normalized["clarification_context"] = clarification_context
if continuity_state is not None:
normalized["continuity_state"] = continuity_state
if turn_context is not None:
normalized["turn_context"] = turn_context
return normalized
def _build_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any] | None:
normalized_state = _normalize_continuity_snapshot(state)
snapshot = {
field: state.get(field)
field: normalized_state.get(field)
for field in _CONTINUITY_SNAPSHOT_FIELDS
if state.get(field) is not None
if normalized_state.get(field) is not None
}
if not snapshot:
return None
@@ -116,7 +317,7 @@ def _extract_continuity_snapshot(payload: Any) -> dict[str, Any] | None:
return None
state = payload.get("state")
if isinstance(state, dict):
return state
return _normalize_continuity_snapshot(state)
return None
@@ -160,11 +361,32 @@ class AgentService:
"【当前时间】\n"
f"- current_time_utc: {reference['current_time_iso']}\n"
f"- current_date_utc: {reference['current_date_iso']}\n"
"说明:解析今天/明天/后天/本周/下周等相对时间时,请以 current_time_utc 为准。"
"说明:解析'今天/明天/后天/本周/下周'等相对时间时,请以 current_time_utc 为准。"
)
return context, reference
async def _get_user_llm_config(self, user_id: str, model_name: str | None = None) -> dict | None:
def build_skill_context(self, skill_names: list[str]) -> dict:
"""构建 Skills 上下文
Args:
skill_names: Skill 名称列表
Returns:
包含 skills 上下文的字典
"""
registry = get_skill_registry()
merged_context = registry.get_skill_context(skill_names)
return {
"skills_context": merged_context,
"skills_metadata": {
"skills": skill_names,
"count": len(skill_names),
},
}
async def _get_user_llm_config(
self, user_id: str, model_name: str | None = None
) -> dict | None:
"""获取用户的 LLM 模型配置"""
user = await self.db.get(User, user_id)
if not user or not user.llm_config:
@@ -187,7 +409,7 @@ class AgentService:
return None
async def _load_continuity_snapshot(self, conversation: Conversation) -> dict[str, Any] | None:
snapshot = _extract_continuity_snapshot(conversation.agent_state)
snapshot = _extract_continuity_snapshot(getattr(conversation, "agent_state", None))
if snapshot:
return snapshot
@@ -214,13 +436,15 @@ class AgentService:
user_llm_config: dict | None,
) -> dict[str, Any]:
state = initial_state(user_id, conversation.id)
state.update({
"messages": [HumanMessage(content=full_message)],
"memory_context": memory_context,
"current_datetime_context": current_datetime_context,
"current_datetime_reference": current_datetime_reference,
"user_llm_config": user_llm_config,
})
state.update(
{
"messages": [HumanMessage(content=full_message)],
"memory_context": memory_context,
"current_datetime_context": current_datetime_context,
"current_datetime_reference": current_datetime_reference,
"user_llm_config": user_llm_config,
}
)
previous_snapshot = await self._load_continuity_snapshot(conversation)
if previous_snapshot:
state.update(previous_snapshot)
@@ -282,6 +506,7 @@ class AgentService:
file_context = ""
if file_ids:
from app.services.document_service import DocumentService
doc_svc = DocumentService(self.db)
for file_id in file_ids:
content = await doc_svc.get_document_content(user_id, file_id)
@@ -347,7 +572,9 @@ class AgentService:
set_current_user(user_id)
try:
graph = get_agent_graph()
current_datetime_context, current_datetime_reference = self._build_current_datetime_context()
current_datetime_context, current_datetime_reference = (
self._build_current_datetime_context()
)
state = await self._build_agent_state(
user_id=user_id,
@@ -358,8 +585,11 @@ class AgentService:
current_datetime_reference=current_datetime_reference,
user_llm_config=user_llm_config,
)
state.update(_derive_role_memory_contexts(memory_ctx))
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
yield self._build_progress_event(
"thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题"
)
try:
async for event in graph.astream_events(state, version="v2"):
@@ -368,7 +598,13 @@ class AgentService:
metadata = event.get("metadata", {})
data = event.get("data", {})
if kind == "on_chain_start" and event_name in {"master", "schedule_planner", "executor", "librarian", "analyst"}:
if kind == "on_chain_start" and event_name in {
"master",
"schedule_planner",
"executor",
"librarian",
"analyst",
}:
stage_map = {
"master": ("thinking", "Jarvis 正在理解请求"),
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
@@ -376,9 +612,13 @@ class AgentService:
"librarian": ("tool", "Jarvis 正在检索知识"),
"analyst": ("thinking", "Jarvis 正在分析信息"),
}
stage, label = stage_map.get(event_name, ("thinking", "Jarvis 正在思考"))
yield self._build_progress_event(stage, label, agent=event_name, step=label)
stage, label = stage_map.get(
event_name, ("thinking", "Jarvis 正在思考")
)
yield self._build_progress_event(
stage, label, agent=event_name, step=label
)
elif kind == "on_tool_start":
yield self._build_progress_event(
"tool",
@@ -387,7 +627,7 @@ class AgentService:
tool_name=event_name,
step=f"正在执行 {event_name}",
)
elif kind == "on_tool_end":
tool_result = data.get("output")
step = f"已完成 {event_name}"
@@ -400,14 +640,16 @@ class AgentService:
tool_name=event_name,
step=step,
)
elif kind == "on_chat_model_stream":
chunk = data.get("chunk")
content = _coerce_event_text(getattr(chunk, "content", "") if chunk else "")
content = _coerce_event_text(
getattr(chunk, "content", "") if chunk else ""
)
if content:
collected += content
yield {"type": "chunk", "content": content}
elif kind == "on_chain_end":
output = data.get("output")
final_resp = None
@@ -422,7 +664,9 @@ class AgentService:
elif kind == "on_chat_model_end":
output = data.get("output")
final_content = _coerce_event_text(getattr(output, "content", "") if output else "")
final_content = _coerce_event_text(
getattr(output, "content", "") if output else ""
)
if final_content:
final_text = final_content
if final_text != collected:
@@ -431,12 +675,16 @@ class AgentService:
except Exception as e:
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
yield self._build_progress_event(
"responding", "Jarvis 正在生成回复", agent="master", step="fallback"
)
try:
result_state = await graph.ainvoke(state)
if isinstance(result_state, dict):
state.update(result_state)
fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
fallback_content = result_state.get("final_response") or str(
result_state.get("messages", [AIMessage(content="")])[-1].content
)
collected = str(fallback_content)
yield {"type": "chunk", "content": collected}
except Exception:
@@ -460,11 +708,24 @@ class AgentService:
if collected:
assistant_msg.content = collected
continuity_snapshot = _build_continuity_snapshot(state or {})
assistant_msg.attachments = ([{
"kind": "agent_continuity_state",
**continuity_snapshot,
}] if continuity_snapshot else None)
conv.agent_state = continuity_snapshot
assistant_msg.attachments = (
[
{
"kind": "agent_continuity_state",
**continuity_snapshot,
}
]
if continuity_snapshot
else None
)
conv.agent_state = (
{
"kind": "agent_continuity_state",
**continuity_snapshot,
}
if continuity_snapshot
else None
)
await BrainService(self.db).create_event(
user_id,
**_build_assistant_event_payload(collected),
@@ -542,12 +803,16 @@ class AgentService:
importance_signal=1.0,
)
memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
set_current_user(user_id)
try:
graph = get_agent_graph()
current_datetime_context, current_datetime_reference = self._build_current_datetime_context()
current_datetime_context, current_datetime_reference = (
self._build_current_datetime_context()
)
state = await self._build_agent_state(
user_id=user_id,
conversation=conv,
@@ -557,9 +822,11 @@ class AgentService:
current_datetime_reference=current_datetime_reference,
user_llm_config=user_llm_config,
)
state.update(_derive_role_memory_contexts(memory_ctx))
result_state = await graph.ainvoke(state)
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
response_content = result_state.get("final_response") or str(
result_state.get("messages", [AIMessage(content="")])[-1].content
)
except Exception as e:
logger.exception("agent_chat_simple_failed")
response_content = "抱歉,发生错误。"
@@ -580,12 +847,27 @@ class AgentService:
)
assistant_msg.content = response_content
continuity_snapshot = _build_continuity_snapshot(result_state) if 'result_state' in locals() else None
assistant_msg.attachments = ([{
"kind": "agent_continuity_state",
**continuity_snapshot,
}] if continuity_snapshot else None)
conv.agent_state = continuity_snapshot
continuity_snapshot = (
_build_continuity_snapshot(result_state) if "result_state" in locals() else None
)
assistant_msg.attachments = (
[
{
"kind": "agent_continuity_state",
**continuity_snapshot,
}
]
if continuity_snapshot
else None
)
conv.agent_state = (
{
"kind": "agent_continuity_state",
**continuity_snapshot,
}
if continuity_snapshot
else None
)
await self.db.commit()
await self.db.refresh(assistant_msg)

View File

@@ -4,12 +4,15 @@ Jarvis 记忆系统 (基于 Mem0)
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
"""
import logging
import os
from datetime import datetime
import re
from datetime import UTC, datetime
from typing import Optional, Any
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.conversation import Conversation, Message
from app.models.memory import UserMemory
from app.models.user import User
from app.services.brain_service import BrainService
from app.config import settings as _settings
@@ -23,6 +26,9 @@ except ImportError:
Memory = None
logger = logging.getLogger(__name__)
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
"""从用户配置中获取 embedding 模型配置"""
result = await db.execute(select(User).where(User.id == user_id))
@@ -296,6 +302,23 @@ async def extract_user_memories(
return []
def _extract_memory_query_tokens(query: str) -> list[str]:
normalized_query = (query or "").lower()
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
stripped_chunk = chunk.strip()
if len(stripped_chunk) >= 4:
tokens.append(stripped_chunk)
if len(stripped_chunk) > 6:
tokens.extend(
stripped_chunk[index:index + 4]
for index in range(len(stripped_chunk) - 3)
)
return list(dict.fromkeys(tokens))
async def recall_user_memories(
db: AsyncSession,
user_id: str,
@@ -304,7 +327,7 @@ async def recall_user_memories(
) -> list[dict]:
"""
根据当前输入召回相关的用户记忆。
使用 Mem0 的语义搜索。
使用 Mem0 的语义搜索;如果 Mem0 不可用或失败,则回退到本地 UserMemory
"""
try:
mem0 = await get_mem0(db, user_id)
@@ -313,10 +336,56 @@ async def recall_user_memories(
filters={"user_id": user_id},
limit=top_k,
)
return results.get("results", [])
mem0_results = results.get("results", [])
if mem0_results:
return mem0_results
except Exception as e:
print(f"Mem0 search error: {e}")
return []
query_tokens = _extract_memory_query_tokens(query)
statement = select(UserMemory).where(UserMemory.user_id == user_id)
result = await db.execute(statement.order_by(UserMemory.importance.desc(), UserMemory.created_at.desc()))
fallback_memories = list(result.scalars().all())
if _contains_hint(_normalize_query(query), MEMORY_QUERY_HINTS) or _matches_memory_query_pattern(_normalize_query(query)):
return fallback_memories[:top_k]
if query_tokens:
matched_memories = [
memory for memory in fallback_memories
if any(token in (memory.content or '').lower() for token in query_tokens)
]
return matched_memories[:top_k]
return []
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
recalled_at = datetime.now(UTC)
updated = False
for memory in memories:
memory.is_recalled = True
memory.recall_count = (memory.recall_count or 0) + 1
memory.last_recalled_at = recalled_at
updated = True
if updated:
await db.commit()
async def _run_tolerated_section(
db: AsyncSession,
section_name: str,
builder,
) -> str:
try:
return await builder()
except Exception:
logger.warning(
"[MemoryService] %s失败,继续构建剩余上下文",
section_name,
exc_info=True,
)
return ""
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
@@ -339,6 +408,131 @@ async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
# ———— 记忆组装: 供 Agent 使用的上下文 ————
MEMORY_QUERY_HINTS = (
"记住",
"记下",
"记一下",
"记着",
"提醒",
"偏好",
"习惯",
)
MEMORY_QUERY_PATTERNS = (
re.compile(r"\bremember\s+(?:that\s+)?i\b"),
)
GROUNDING_QUERY_HINTS = (
"根据文档",
"严格根据",
"只根据",
"文档内容",
"grounded",
"strictly based on",
"based on the document",
"based on the docs",
"document only",
"docs only",
"only use the document",
"only use the docs",
)
AVOID_USER_MEMORY_HINTS = (
"不要结合我的个人偏好",
"不要结合个人偏好",
"不要结合偏好",
"不要结合我的记忆",
"不要结合记忆",
)
def _normalize_query(text: str) -> str:
return text.strip().lower()
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
return any(hint in text for hint in hints)
def _matches_memory_query_pattern(text: str) -> bool:
return any(pattern.search(text) for pattern in MEMORY_QUERY_PATTERNS)
def _should_include_user_memories(query: str) -> bool:
normalized_query = _normalize_query(query)
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
return False
if _contains_hint(normalized_query, AVOID_USER_MEMORY_HINTS):
return False
return True
def _should_include_summaries(query: str) -> bool:
normalized_query = _normalize_query(query)
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
return False
if _contains_hint(normalized_query, MEMORY_QUERY_HINTS):
return False
if _matches_memory_query_pattern(normalized_query):
return False
return True
async def _build_user_memory_section(
db: AsyncSession,
user_id: str,
current_query: str,
) -> str:
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if not memories:
return ""
lines = []
recalled_user_memories: list[UserMemory] = []
for memory in memories:
if isinstance(memory, UserMemory):
memory_text = memory.content
memory_type = memory.memory_type
recalled_user_memories.append(memory)
else:
memory_text = memory.get("memory", memory.get("text", ""))
memory_type = memory.get("memory_type")
if not memory_text:
continue
if memory_type:
lines.append(f" [{memory_type}] {memory_text}")
else:
lines.append(f" - {memory_text}")
if not lines:
return ""
if recalled_user_memories:
await _mark_memories_recalled(db, recalled_user_memories)
return "【用户记忆】\n" + "\n".join(lines)
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
summaries = await get_summaries(db, conversation_id)
if not summaries:
return ""
recent = summaries[-2:]
lines = [f"[对话摘要{i + 1}] {summary.summary_text}" for i, summary in enumerate(recent)]
return "【之前对话摘要】\n" + "\n".join(lines)
async def _build_brain_section(
db: AsyncSession,
user_id: str,
current_query: str,
) -> str:
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
if not brain_memories:
return ""
lines = [f"- {memory.title}: {memory.content}" for memory in brain_memories]
return "【知识大脑】\n" + "\n".join(lines)
async def build_memory_context(
db: AsyncSession,
@@ -350,30 +544,33 @@ async def build_memory_context(
构建完整的记忆上下文字符串,
供注入到 Agent system prompt 中使用。
"""
parts = []
parts: list[str] = []
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if memories:
lines = []
for m in memories:
memory_text = m.get("memory", m.get("text", ""))
if memory_text:
lines.append(f" - {memory_text}")
if lines:
parts.append("【用户记忆】\n" + "\n".join(lines))
if _should_include_user_memories(current_query):
user_memory_section = await _run_tolerated_section(
db,
"用户记忆召回",
lambda: _build_user_memory_section(db, user_id, current_query),
)
if user_memory_section:
parts.append(user_memory_section)
summaries = await get_summaries(db, conversation_id)
if summaries:
recent = summaries[-2:]
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
parts.append("【之前对话摘要】\n" + "\n".join(lines))
if _should_include_summaries(current_query):
summary_section = await _run_tolerated_section(
db,
"对话摘要加载",
lambda: _build_summary_section(db, conversation_id),
)
if summary_section:
parts.append(summary_section)
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
if brain_memories:
lines = []
for memory in brain_memories:
lines.append(f"- {memory.title}: {memory.content}")
parts.append("【知识大脑】\n" + "\n".join(lines))
brain_section = await _run_tolerated_section(
db,
"知识大脑召回",
lambda: _build_brain_section(db, user_id, current_query),
)
if brain_section:
parts.append(brain_section)
if not parts:
return ""

View File

@@ -0,0 +1,167 @@
from app.agents.schemas.event import AgentEvent
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult
def test_agent_task_accepts_day1_fields():
task = AgentTask(
task_id="task-1",
title="Verify foundation",
status="in_progress",
owner_agent_id="executor",
role="verifier",
goal="check output",
expected_evidence=[{"type": "assertion"}],
evidence=[{"type": "log"}],
result_summary="running",
)
assert task.task_id == "task-1"
assert task.owner_agent_id == "executor"
assert task.status == "in_progress"
assert task.expected_evidence == [{"type": "assertion"}]
assert task.evidence == [{"type": "log"}]
assert task.result_summary == "running"
def test_agent_task_accepts_day3_runtime_fields():
task = AgentTask(
task_id="task-2",
title="Recover interrupted collaboration",
owner_agent_id="executor",
parent_task_id="task-1",
child_task_ids=["task-2a"],
thread_id="thread-1",
message_id="msg-1",
message_index=2,
interrupt_records=[
InterruptRecord(
interrupt_id="interrupt-1",
reason="manual stop",
requested_by="coordinator",
)
],
recovery_records=[
RecoveryRecord(
recovery_id="recovery-1",
source_interrupt_id="interrupt-1",
resumed_from_task_id="task-2",
resumed_from_thread_id="thread-1",
strategy="resume_from_checkpoint",
)
],
collaboration_budget=CollaborationBudget(
mode="collaboration",
max_parallel_tasks=2,
remaining_parallel_tasks=1,
max_tool_calls=4,
remaining_tool_calls=3,
max_iterations=5,
remaining_iterations=4,
escalation_threshold=1,
metadata={"max_spawn_depth": 2},
),
)
assert task.parent_task_id == "task-1"
assert task.child_task_ids == ["task-2a"]
assert task.thread_id == "thread-1"
assert task.message_id == "msg-1"
assert task.message_index == 2
assert task.interrupt_records[0].interrupt_id == "interrupt-1"
assert task.recovery_records[0].recovery_id == "recovery-1"
assert task.collaboration_budget.mode == "collaboration"
assert task.collaboration_budget.metadata == {"max_spawn_depth": 2}
def test_agent_event_accepts_day1_fields():
event = AgentEvent(
event_id="evt-1",
event_type="agent.verify.completed",
conversation_id="conv-1",
agent_id="executor",
sub_commander_id="executor_tasks",
task_id="task-1",
payload={"status": "passed"},
severity="info",
)
assert event.event_id == "evt-1"
assert event.event_type == "agent.verify.completed"
assert event.conversation_id == "conv-1"
assert event.payload == {"status": "passed"}
assert event.severity == "info"
def test_agent_event_accepts_day3_trace_fields():
event = AgentEvent(
event_id="evt-2",
event_type="agent.collaboration.budget.updated",
conversation_id="conv-1",
agent_id="coordinator",
task_id="task-2",
parent_task_id="task-1",
child_task_id="task-2a",
thread_id="thread-1",
message_id="msg-3",
interrupt_id="interrupt-1",
recovery_id="recovery-1",
payload={"remaining_parallel_tasks": 1},
severity="warning",
)
assert event.parent_task_id == "task-1"
assert event.child_task_id == "task-2a"
assert event.thread_id == "thread-1"
assert event.message_id == "msg-3"
assert event.interrupt_id == "interrupt-1"
assert event.recovery_id == "recovery-1"
assert event.severity == "warning"
def test_task_result_supports_collaboration_result_fields():
result = TaskResult(
task_id="task-1",
status="completed",
summary="retrieval finished",
evidence=[{"type": "source"}],
owner_agent_id="librarian",
next_action="handoff_to_analyst",
)
assert result.status == "completed"
assert result.owner_agent_id == "librarian"
assert result.next_action == "handoff_to_analyst"
def test_task_result_supports_day3_thread_budget_and_recovery_fields():
result = TaskResult(
task_id="task-2",
status="blocked",
owner_agent_id="executor",
parent_task_id="task-1",
child_task_ids=["task-2a"],
thread_id="thread-1",
message_id="msg-4",
message_index=4,
interrupt_records=[{"interrupt_id": "interrupt-1", "reason": "budget exceeded"}],
recovery_records=[{"recovery_id": "recovery-1", "strategy": "resume_after_budget_reset"}],
budget_snapshot=CollaborationBudget(
mode="collaboration",
max_parallel_tasks=2,
remaining_parallel_tasks=0,
max_tool_calls=4,
remaining_tool_calls=0,
),
next_action="resume_after_budget_reset",
)
assert result.parent_task_id == "task-1"
assert result.child_task_ids == ["task-2a"]
assert result.thread_id == "thread-1"
assert result.message_id == "msg-4"
assert result.message_index == 4
assert result.interrupt_records[0].interrupt_id == "interrupt-1"
assert result.recovery_records[0].recovery_id == "recovery-1"
assert result.budget_snapshot.mode == "collaboration"
assert result.budget_snapshot.remaining_parallel_tasks == 0
assert result.next_action == "resume_after_budget_reset"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,317 @@
import sys
from types import SimpleNamespace
from unittest.mock import Mock
from langchain_core.messages import AIMessage, HumanMessage
sys.modules.setdefault("trafilatura", Mock())
from app.agents.graph import _build_system_messages, _run_sub_commander
from app.agents.state import AgentRole
def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
return {
"messages": [HumanMessage(content=message)],
"user_id": "u1",
"conversation_id": "c1",
"current_agent": AgentRole.MASTER,
"active_agents": [AgentRole.MASTER],
"current_sub_commander": None,
"active_sub_commanders": [],
"sub_commander_trace": [],
"pending_tasks": [],
"completed_tasks": [],
"tool_calls": [],
"last_tool_result": None,
"action_results": [],
"created_entities": [],
"tool_strategy_used": None,
"provider_capabilities": None,
"fallback_parse_error": None,
"knowledge_context": None,
"graph_context": None,
"schedule_context_summary": None,
"plan": None,
"plan_steps": [],
"analysis_report": None,
"final_response": None,
"should_respond": True,
"memory_context": "memory context",
"current_datetime_context": "CURRENT_TIME: 2026-03-28T12:00:00+08:00",
"current_datetime_reference": {
"current_time_iso": "2026-03-28T12:00:00+08:00",
"current_date_iso": "2026-03-28",
"timezone": "UTC",
},
"user_llm_config": user_llm_config,
}
class FakeTool:
def __init__(self, name: str, result: str):
self.name = name
self.result = result
self.invocations: list[dict] = []
def invoke(self, args: dict):
self.invocations.append(args)
return self.result
class SingleSystemMessageLLM:
def __init__(self):
self.calls = 0
self.system_message_counts: list[int] = []
self._jarvis_provider_capabilities = SimpleNamespace(
provider="minimax",
supports_native_tools=False,
preferred_tool_strategy="json_fallback",
)
async def ainvoke(self, messages):
self.calls += 1
self.system_message_counts.append(
sum(1 for message in messages if getattr(message, "type", None) == "system")
)
if self.system_message_counts[-1] != 1:
raise AssertionError(
f"expected exactly one system message, got {self.system_message_counts[-1]}"
)
if self.calls == 1:
return AIMessage(
content=(
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder",'
'"arguments":{"title":"blanket","reminder_at":"\\u660e\\u5929 09:00"}}]}'
)
)
return AIMessage(content="created reminder for blanket")
def test_build_system_messages_includes_structured_continuity_summary():
state = _base_state("创建")
state["pending_action"] = {
"type": "schedule_creation",
"summary": "为周报安排明天下午提醒",
"status": "pending",
}
state["routing_decision"] = {
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
"reason": "continue_pending_action",
}
state["continuity_state"] = {"status": "fresh"}
messages = _build_system_messages(
state,
"manager prompt",
AgentRole.SCHEDULE_PLANNER,
"schedule_planning",
)
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
assert "pending_action" in system_text
assert "schedule_creation" in system_text
assert "continue_pending_action" in system_text
assert "为周报安排明天下午提醒" in system_text
def test_build_system_messages_skips_structured_continuity_when_pending_action_is_not_pending():
state = _base_state("创建")
state["pending_action"] = {
"type": "schedule_creation",
"summary": "为周报安排明天下午提醒",
"status": "completed",
}
state["routing_decision"] = {
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
"reason": "continue_pending_action",
}
state["continuity_state"] = {"status": "fresh"}
messages = _build_system_messages(
state,
"manager prompt",
AgentRole.SCHEDULE_PLANNER,
"schedule_planning",
)
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
assert "structured_continuity" not in system_text
assert "continue_pending_action" not in system_text
def test_build_system_messages_skips_structured_continuity_when_routing_reason_is_not_continuation():
state = _base_state("创建")
state["pending_action"] = {
"type": "schedule_creation",
"summary": "为周报安排明天下午提醒",
"status": "pending",
}
state["routing_decision"] = {
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
"reason": "initial_schedule_detection",
}
state["continuity_state"] = {"status": "fresh"}
messages = _build_system_messages(
state,
"manager prompt",
AgentRole.SCHEDULE_PLANNER,
"schedule_planning",
)
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
assert "structured_continuity" not in system_text
assert "continue_pending_action" not in system_text
def test_build_system_messages_skips_structured_continuity_when_routing_decision_missing():
state = _base_state("创建")
state["pending_action"] = {
"type": "schedule_creation",
"summary": "为周报安排明天下午提醒",
}
state["routing_decision"] = None
messages = _build_system_messages(
state,
"manager prompt",
AgentRole.SCHEDULE_PLANNER,
"schedule_planning",
)
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
assert "pending_action" not in system_text
assert "schedule_creation" not in system_text
assert "为周报安排明天下午提醒" not in system_text
def test_build_system_messages_skips_stale_structured_continuity_for_unrelated_new_request():
state = _base_state(
"帮我搜索 Rust 异步 trait 最佳实践",
{
"provider": "openai",
"model": "MiniMax-M2.7-highspeed",
"base_url": "https://api.minimaxi.com/v1",
},
)
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
state["pending_action"] = {
"type": "schedule_creation",
"summary": "为周报安排明天下午提醒",
"status": "pending",
}
state["routing_decision"] = {
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
"reason": "continue_pending_action",
}
state["continuity_state"] = {
"status": "stale",
"override_reason": "new_explicit_request",
}
messages = _build_system_messages(
state,
"manager prompt",
AgentRole.SCHEDULE_PLANNER,
"schedule_planning",
)
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
assert "structured_continuity" not in system_text
assert "pending_action" not in system_text
assert "continue_pending_action" not in system_text
def test_build_system_messages_uses_role_scoped_context_instead_of_raw_memory_blob():
state = _base_state("帮我搜索 Rust 异步 trait 最佳实践")
state["memory_context"] = "【用户记忆】\n- 用户喜欢燕麦拿铁。\n\n【之前对话摘要】\n[对话摘要1] 之前聊过提醒。\n\n【知识大脑】\n- Rust Async: trait object 需要 pin。"
state["schedule_context_summary"] = "【用户记忆】\n- 用户喜欢燕麦拿铁。\n\n【之前对话摘要】\n[对话摘要1] 之前聊过提醒。"
state["knowledge_context"] = "【知识大脑】\n- Rust Async: trait object 需要 pin。"
state["analysis_report"] = "【之前对话摘要】\n[对话摘要1] 之前聊过提醒。\n\n【知识大脑】\n- Rust Async: trait object 需要 pin。"
messages = _build_system_messages(
state,
"manager prompt",
AgentRole.LIBRARIAN,
"librarian_retrieval",
)
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
assert "角色上下文" in system_text
assert "【知识大脑】" in system_text
assert "Rust Async" in system_text
assert "用户喜欢燕麦拿铁" not in system_text
assert "记忆上下文" not in system_text
def test_build_system_messages_keeps_fresh_structured_continuity_for_matching_followup():
state = _base_state(
"创建",
{
"provider": "openai",
"model": "MiniMax-M2.7-highspeed",
"base_url": "https://api.minimaxi.com/v1",
},
)
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
state["pending_action"] = {
"type": "schedule_creation",
"summary": "为周报安排明天下午提醒",
"status": "pending",
}
state["routing_decision"] = {
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
"reason": "continue_pending_action",
}
state["continuity_state"] = {
"status": "fresh",
}
messages = _build_system_messages(
state,
"manager prompt",
AgentRole.SCHEDULE_PLANNER,
"schedule_planning",
)
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
assert "pending_action" in system_text
assert "continue_pending_action" in system_text
async def test_run_sub_commander_coalesces_system_messages_for_openai_compatible_provider(
monkeypatch,
):
fake_llm = SingleSystemMessageLLM()
fake_tool = FakeTool("create_reminder", "created reminder: blanket @ tomorrow 09:00")
monkeypatch.setattr("app.agents.graph._get_llm_for_state", lambda state: fake_llm)
monkeypatch.setitem(
__import__("app.agents.graph", fromlist=["SUB_COMMANDER_TOOLSETS"]).SUB_COMMANDER_TOOLSETS,
"schedule_planning",
[fake_tool],
)
state = _base_state(
"给我设置明天的提醒,提醒我收被子",
{
"provider": "openai",
"model": "MiniMax-M2.7-highspeed",
"base_url": "https://api.minimaxi.com/v1",
},
)
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
result = await _run_sub_commander(
state,
AgentRole.SCHEDULE_PLANNER,
"manager prompt",
"给我设置明天的提醒,提醒我收被子",
use_tools=True,
)
assert fake_llm.system_message_counts == [1, 1]
assert result["tool_strategy_used"] == "json_fallback"
assert fake_tool.invocations == [{"title": "blanket", "reminder_at": "2026-03-29T09:00:00"}]
assert result["final_response"] == "created reminder for blanket"

View File

@@ -1,4 +1,4 @@
from app.agents.prompts import MASTER_SYSTEM_PROMPT
from app.agents.prompts import COORDINATOR_SYSTEM_PROMPT, MASTER_SYSTEM_PROMPT
def test_master_prompt_forbids_subagent_rollcall_in_simple_greetings():
@@ -10,3 +10,10 @@ def test_master_prompt_does_not_include_full_canned_answers_for_greetings_or_ide
assert 'Jarvis您好。我在。' not in MASTER_SYSTEM_PROMPT
assert 'Jarvis我是 Jarvis。' not in MASTER_SYSTEM_PROMPT
assert 'Jarvis主要做三件事。' not in MASTER_SYSTEM_PROMPT
def test_coordinator_prompt_limits_collaboration_scope():
assert "2~4 个子任务" in COORDINATOR_SYSTEM_PROMPT
assert "禁止无限递归拆分" in COORDINATOR_SYSTEM_PROMPT
assert "schedule_planner" in COORDINATOR_SYSTEM_PROMPT
assert "librarian" in COORDINATOR_SYSTEM_PROMPT

View File

@@ -5,11 +5,13 @@ from app.agents.prompts import (
SUB_COMMANDER_PROMPTS_BY_KEY,
TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY,
)
from app.agents.registry import build_registry_indexes, load_builtin_registry_bundle
from app.agents.registry import build_registry_indexes, load_builtin_registry_bundle, load_builtin_registry_indexes
from app.agents.registry.indexes import summarize_registry_indexes
from app.agents.registry.models import (
AgentManifest,
CapabilityManifest,
PermissionClass,
SideEffectScope,
SpecialistTemplateManifest,
SubCommanderManifest,
)
@@ -251,17 +253,34 @@ def test_builtin_capabilities_reference_actual_runtime_tool_names() -> None:
assert manifest_tool_names == expected_tool_names
def test_builtin_sub_commander_capabilities_match_runtime_toolsets() -> None:
capabilities_by_tool_name = {
manifest.tool_name: manifest.capability_id for manifest in BUILTIN_CAPABILITY_MANIFESTS
}
def test_builtin_capability_metadata_distinguishes_read_and_write_surfaces() -> None:
capability_by_id = {manifest.capability_id: manifest for manifest in BUILTIN_CAPABILITY_MANIFESTS}
for sub_commander in BUILTIN_SUB_COMMANDER_MANIFESTS:
expected_capability_ids = {
capabilities_by_tool_name[tool.name]
for tool in SUB_COMMANDER_TOOLSETS[sub_commander.sub_commander_id]
}
assert set(sub_commander.capability_ids) == expected_capability_ids
assert capability_by_id["get_tasks"].permission_class == PermissionClass.READ
assert capability_by_id["get_tasks"].side_effect_scope == SideEffectScope.NONE
assert capability_by_id["get_tasks"].supports_retry is True
assert capability_by_id["get_tasks"].idempotent is True
assert capability_by_id["get_tasks"].safe_for_parallel_use is True
assert capability_by_id["get_tasks"].requires_confirmation is False
assert capability_by_id["create_reminder"].permission_class == PermissionClass.WRITE
assert capability_by_id["create_reminder"].side_effect_scope == SideEffectScope.LOCAL_STATE
assert capability_by_id["create_reminder"].supports_retry is False
assert capability_by_id["create_reminder"].idempotent is False
assert capability_by_id["create_reminder"].safe_for_parallel_use is False
assert capability_by_id["create_reminder"].requires_confirmation is True
assert capability_by_id["web_search"].permission_class == PermissionClass.EXTERNAL
assert capability_by_id["web_search"].side_effect_scope == SideEffectScope.NETWORK
def test_load_builtin_registry_indexes_is_cached_and_matches_bundle_indexes() -> None:
cached = load_builtin_registry_indexes()
rebuilt = build_registry_indexes(load_builtin_registry_bundle())
assert cached is load_builtin_registry_indexes()
assert cached.capability_id_by_tool_name == rebuilt.capability_id_by_tool_name
assert cached.capability_by_id["create_reminder"].requires_confirmation is True
def test_builtin_manifests_form_a_valid_registry_bundle() -> None:
@@ -288,6 +307,7 @@ def test_build_registry_indexes_exposes_manifest_lookups_by_id() -> None:
indexes = build_registry_indexes(bundle)
assert indexes.agent_by_id
assert indexes.agent_by_role_value
assert indexes.sub_commander_by_id
assert indexes.capability_by_id
assert isinstance(indexes.specialist_template_by_id, Mapping)
@@ -343,6 +363,14 @@ def test_build_registry_indexes_exposes_prompt_keys_skill_context_keys_and_capab
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
for sub_commander in bundle.sub_commanders
}
assert indexes.agent_by_role_value == {
agent.role_value: agent for agent in bundle.agents
}
assert indexes.spawnable_role_values_by_agent_id == {
agent.agent_id: tuple(agent.allowed_spawn_role_values)
for agent in bundle.agents
if agent.can_spawn_children and agent.allowed_spawn_role_values
}
def test_validate_registry_bundle_accepts_loaded_builtin_registry_bundle() -> None:

View File

@@ -0,0 +1,135 @@
from app.agents.schemas import AgentEvent, AgentTask, TaskResult
from app.agents.schemas.task import CollaborationBudget, InterruptRecord, RecoveryRecord
from app.agents.state import initial_state
from app.agents.verifier import apply_verification_verdict, normalize_task_result, verify_task_result
def test_agent_task_supports_day3_interrupt_recovery_and_budget_fields():
interrupt = InterruptRecord(interrupt_id="interrupt-1", reason="user_cancel")
recovery = RecoveryRecord(recovery_id="recovery-1", source_interrupt_id="interrupt-1", resumed_from_task_id="task-1")
budget = CollaborationBudget(
mode="collaboration",
max_parallel_tasks=3,
remaining_parallel_tasks=2,
max_tool_calls=6,
remaining_tool_calls=4,
metadata={"phase": "day3"},
)
task = AgentTask(
task_id="task-1",
title="Recover interrupted collaboration task",
owner_agent_id="analyst",
role="analyst",
parent_task_id="parent-1",
child_task_ids=["child-1"],
thread_id="thread-1",
message_id="message-1",
message_index=3,
interrupt_records=[interrupt],
recovery_records=[recovery],
collaboration_budget=budget,
)
payload = task.model_dump(mode="json")
assert payload["parent_task_id"] == "parent-1"
assert payload["child_task_ids"] == ["child-1"]
assert payload["thread_id"] == "thread-1"
assert payload["message_id"] == "message-1"
assert payload["message_index"] == 3
assert payload["interrupt_records"][0]["interrupt_id"] == "interrupt-1"
assert payload["recovery_records"][0]["recovery_id"] == "recovery-1"
assert payload["collaboration_budget"]["mode"] == "collaboration"
assert payload["collaboration_budget"]["remaining_tool_calls"] == 4
def test_agent_event_supports_day3_thread_interrupt_and_recovery_metadata():
event = AgentEvent(
event_id="evt-1",
event_type="agent.task.recovered",
conversation_id="conv-1",
agent_id="executor",
task_id="task-1",
parent_task_id="parent-1",
child_task_id="child-1",
thread_id="thread-1",
message_id="message-1",
interrupt_id="interrupt-1",
recovery_id="recovery-1",
severity="warning",
payload={"status": "resumed"},
)
payload = event.model_dump(mode="json")
assert payload["event_type"] == "agent.task.recovered"
assert payload["parent_task_id"] == "parent-1"
assert payload["child_task_id"] == "child-1"
assert payload["thread_id"] == "thread-1"
assert payload["message_id"] == "message-1"
assert payload["interrupt_id"] == "interrupt-1"
assert payload["recovery_id"] == "recovery-1"
assert payload["severity"] == "warning"
def test_normalize_task_result_preserves_day3_metadata_fields():
result = normalize_task_result(
{
"task_id": "task-1",
"status": "completed",
"summary": "Recovered successfully.",
"owner_agent_id": "executor",
"parent_task_id": "parent-1",
"child_task_ids": ["child-1"],
"thread_id": "thread-1",
"message_id": "message-1",
"message_index": 2,
"interrupt_records": [{"interrupt_id": "interrupt-1", "reason": "user_pause"}],
"recovery_records": [{"recovery_id": "recovery-1", "source_interrupt_id": "interrupt-1"}],
"budget_snapshot": {"mode": "collaboration", "max_parallel_tasks": 4},
"next_action": "notify_user",
"output_data": {"ok": True},
}
)
assert result.parent_task_id == "parent-1"
assert result.child_task_ids == ["child-1"]
assert result.thread_id == "thread-1"
assert result.message_id == "message-1"
assert result.message_index == 2
assert result.interrupt_records[0].interrupt_id == "interrupt-1"
assert result.recovery_records[0].recovery_id == "recovery-1"
assert result.budget_snapshot.mode == "collaboration"
assert result.budget_snapshot.max_parallel_tasks == 4
assert result.next_action == "notify_user"
assert result.output_data == {"ok": True}
def test_apply_verification_verdict_updates_state_with_recovery_evidence():
state = initial_state("u1", "c1")
verdict = verify_task_result(
status="passed",
summary="Interrupt and recovery chain verified.",
evidence=[
{
"task_id": "task-1",
"thread_id": "thread-1",
"interrupt_id": "interrupt-1",
"recovery_id": "recovery-1",
}
],
)
updated_state = apply_verification_verdict(state, verdict)
assert updated_state["verification_status"] == "passed"
assert updated_state["verification_summary"] == "Interrupt and recovery chain verified."
assert updated_state["verification_evidence"] == [
{
"task_id": "task-1",
"thread_id": "thread-1",
"interrupt_id": "interrupt-1",
"recovery_id": "recovery-1",
}
]

View File

@@ -47,3 +47,27 @@ def test_web_search_tool_returns_stable_message_when_unavailable(monkeypatch):
result = web_search.func('Jarvis')
assert result == '网页搜索不可用: 网页搜索未启用或未配置'
@pytest.mark.asyncio
async def test_web_search_tool_runs_from_active_event_loop(monkeypatch):
class FakeService:
async def search(self, query: str, limit: int | None = None):
assert query == 'Jarvis 最新更新'
assert limit == 1
return [
FakeResult(
title='Jarvis release notes',
url='https://example.com/jarvis-release',
snippet='Latest Jarvis changes.',
source='duckduckgo',
published_at='2026-03-29',
)
]
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
result = web_search.func('Jarvis 最新更新', top_k=1)
assert '[1] Jarvis release notes' in result
assert '链接: https://example.com/jarvis-release' in result

View File

@@ -2,6 +2,7 @@ import pytest
from app.agents.tools import forum as forum_tools
from app.agents.tools import schedule as schedule_tools
from app.agents.tools import search as search_tools
from app.agents.tools import task as task_tools
@@ -12,6 +13,7 @@ from app.agents.tools import task as task_tools
(task_tools, "task"),
(schedule_tools, "schedule"),
(forum_tools, "forum"),
(search_tools, "search"),
],
)
async def test_run_async_bridge_works_inside_running_event_loop(module, label):

Some files were not shown because too many files have changed in this diff Show More