211 lines
6.3 KiB
Python
211 lines
6.3 KiB
Python
|
|
"""流式工具执行器 - Phase 6.3
|
|||
|
|
|
|||
|
|
支持流式输出的工具执行器。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import asyncio
import functools
import inspect
import json
from contextlib import aclosing, closing
from typing import Any, AsyncGenerator

from app.agents.tools.hooks.executor import get_hook_executor
from app.agents.tools.hooks.types import ExecutionContext
from app.agents.tools.registry import get_tool_registry
|
|||
|
|
|
|||
|
|
|
|||
|
|
class StreamingToolExecutor:
|
|||
|
|
"""流式工具执行器
|
|||
|
|
|
|||
|
|
支持:
|
|||
|
|
- 普通工具的同步/异步执行
|
|||
|
|
- 流式工具的流式输出
|
|||
|
|
- Hook 拦截(pre/post/error)
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
self._registry = get_tool_registry()
|
|||
|
|
self._hook_executor = get_hook_executor()
|
|||
|
|
|
|||
|
|
async def execute(
|
|||
|
|
self,
|
|||
|
|
tool_name: str,
|
|||
|
|
tool_input: dict[str, Any],
|
|||
|
|
user_id: str | None = None,
|
|||
|
|
) -> Any:
|
|||
|
|
"""执行工具(非流式)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
tool_name: 工具名称
|
|||
|
|
tool_input: 工具输入参数
|
|||
|
|
user_id: 用户 ID(可选)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
工具执行结果
|
|||
|
|
"""
|
|||
|
|
# 创建执行上下文
|
|||
|
|
context = ExecutionContext(
|
|||
|
|
tool_name=tool_name,
|
|||
|
|
tool_input=tool_input,
|
|||
|
|
user_id=user_id,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 获取工具和执行器
|
|||
|
|
manifest = self._registry.get(tool_name)
|
|||
|
|
if manifest is None:
|
|||
|
|
raise ValueError(f"Tool not found: {tool_name}")
|
|||
|
|
|
|||
|
|
executor = self._registry.get_executor(tool_name)
|
|||
|
|
if executor is None:
|
|||
|
|
raise ValueError(f"Executor not found for tool: {tool_name}")
|
|||
|
|
|
|||
|
|
# 检查是否跳过
|
|||
|
|
if await self._hook_executor.execute_skip_check(context):
|
|||
|
|
return {"skipped": True, "tool": tool_name}
|
|||
|
|
|
|||
|
|
# 执行 pre-hooks
|
|||
|
|
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
|||
|
|
if not continue_execution:
|
|||
|
|
return {"pre_hook_aborted": True, "tool": tool_name}
|
|||
|
|
|
|||
|
|
# 执行工具
|
|||
|
|
try:
|
|||
|
|
context.start_time = asyncio.get_event_loop().time()
|
|||
|
|
|
|||
|
|
# 判断是同步还是异步执行
|
|||
|
|
if asyncio.iscoroutinefunction(executor):
|
|||
|
|
result = await executor(**modified_input)
|
|||
|
|
else:
|
|||
|
|
# 同步函数在线程池中执行
|
|||
|
|
loop = asyncio.get_event_loop()
|
|||
|
|
result = await loop.run_in_executor(None, lambda: executor(**modified_input))
|
|||
|
|
|
|||
|
|
context.result = result
|
|||
|
|
context.end_time = asyncio.get_event_loop().time()
|
|||
|
|
|
|||
|
|
# 执行 post-hooks
|
|||
|
|
result = await self._hook_executor.execute_post_hooks(context, result)
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
context.error = e
|
|||
|
|
context.end_time = asyncio.get_event_loop().time()
|
|||
|
|
|
|||
|
|
# 执行 error-hooks
|
|||
|
|
error_result = await self._hook_executor.execute_error_hooks(context, e)
|
|||
|
|
if error_result is not None:
|
|||
|
|
return error_result
|
|||
|
|
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
async def execute_streaming(
|
|||
|
|
self,
|
|||
|
|
tool_name: str,
|
|||
|
|
tool_input: dict[str, Any],
|
|||
|
|
user_id: str | None = None,
|
|||
|
|
) -> AsyncGenerator[dict[str, Any], None]:
|
|||
|
|
"""执行流式工具
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
tool_name: 工具名称
|
|||
|
|
tool_input: 工具输入参数
|
|||
|
|
user_id: 用户 ID(可选)
|
|||
|
|
|
|||
|
|
Yields:
|
|||
|
|
流式输出片段
|
|||
|
|
"""
|
|||
|
|
# 创建执行上下文
|
|||
|
|
context = ExecutionContext(
|
|||
|
|
tool_name=tool_name,
|
|||
|
|
tool_input=tool_input,
|
|||
|
|
user_id=user_id,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 获取工具和执行器
|
|||
|
|
manifest = self._registry.get(tool_name)
|
|||
|
|
if manifest is None:
|
|||
|
|
raise ValueError(f"Tool not found: {tool_name}")
|
|||
|
|
|
|||
|
|
if not manifest.is_streaming:
|
|||
|
|
raise ValueError(f"Tool is not streaming: {tool_name}")
|
|||
|
|
|
|||
|
|
executor = self._registry.get_executor(tool_name)
|
|||
|
|
if executor is None:
|
|||
|
|
raise ValueError(f"Executor not found for tool: {tool_name}")
|
|||
|
|
|
|||
|
|
# 检查是否跳过
|
|||
|
|
if await self._hook_executor.execute_skip_check(context):
|
|||
|
|
yield {"type": "skipped", "tool": tool_name}
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 执行 pre-hooks
|
|||
|
|
continue_execution, modified_input = await self._hook_executor.execute_pre_hooks(context)
|
|||
|
|
if not continue_execution:
|
|||
|
|
yield {"type": "pre_hook_aborted", "tool": tool_name}
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 执行流式工具
|
|||
|
|
try:
|
|||
|
|
context.start_time = asyncio.get_event_loop().time()
|
|||
|
|
|
|||
|
|
# 调用执行器(应该返回 AsyncGenerator)
|
|||
|
|
result = executor(**modified_input)
|
|||
|
|
|
|||
|
|
# 如果是 async generator
|
|||
|
|
if asyncio.isasyncgen(result):
|
|||
|
|
async for chunk in result:
|
|||
|
|
yield {"type": "chunk", "data": chunk}
|
|||
|
|
else:
|
|||
|
|
# 普通协程
|
|||
|
|
data = await result
|
|||
|
|
yield {"type": "chunk", "data": data}
|
|||
|
|
|
|||
|
|
context.end_time = asyncio.get_event_loop().time()
|
|||
|
|
yield {"type": "done", "tool": tool_name}
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
context.error = e
|
|||
|
|
context.end_time = asyncio.get_event_loop().time()
|
|||
|
|
|
|||
|
|
yield {"type": "error", "error": str(e), "tool": tool_name}
|
|||
|
|
|
|||
|
|
# 执行 error-hooks
|
|||
|
|
await self._hook_executor.execute_error_hooks(context, e)
|
|||
|
|
|
|||
|
|
async def execute_batch(
|
|||
|
|
self,
|
|||
|
|
tool_calls: list[dict[str, Any]],
|
|||
|
|
user_id: str | None = None,
|
|||
|
|
) -> list[Any]:
|
|||
|
|
"""批量执行工具
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
tool_calls: 工具调用列表,每个元素包含 tool_name 和 tool_input
|
|||
|
|
user_id: 用户 ID(可选)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
执行结果列表
|
|||
|
|
"""
|
|||
|
|
tasks = []
|
|||
|
|
for call in tool_calls:
|
|||
|
|
tool_name = call.get("tool_name") or call.get("name")
|
|||
|
|
tool_input = call.get("tool_input") or call.get("input") or {}
|
|||
|
|
tasks.append(self.execute(tool_name, tool_input, user_id))
|
|||
|
|
|
|||
|
|
return await asyncio.gather(*tasks, return_exceptions=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Module-level singleton, created lazily on first access.
_executor: StreamingToolExecutor | None = None


def get_streaming_executor() -> StreamingToolExecutor:
    """Return the shared StreamingToolExecutor, creating it on first call.

    Returns:
        The module-wide StreamingToolExecutor instance.
    """
    global _executor
    if _executor is not None:
        return _executor
    _executor = StreamingToolExecutor()
    return _executor
|