150 lines
3.9 KiB
Python
150 lines
3.9 KiB
Python
|
|
"""
|
||
|
|
API route definitions.
|
||
|
|
"""
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException
|
||
|
|
from pydantic import BaseModel
|
||
|
|
|
||
|
|
from app.agent.core.agent import AgentManager
|
||
|
|
from app.security.approval import ApprovalService
|
||
|
|
|
||
|
|
|
||
|
|
# Router that collects every endpoint declared in this module.
router = APIRouter()

# Module-level dependency holders (these should really be injected).
# NOTE(review): `_agent_manager` is not read by the code visible in this
# module; confirm whether anything else relies on it before removing it.
_agent_manager: Optional[AgentManager] = None
_approval_service: Optional[ApprovalService] = None
|
||
|
|
|
||
|
|
|
||
|
|
def get_agent_manager() -> AgentManager:
    """Return the agent manager published by ``app.main``.

    The import happens at call time rather than at module import time
    (presumably to avoid a circular import with ``app.main`` -- confirm).
    Ideally the manager would be read from ``app.state`` instead.

    Returns:
        The ``agent_manager`` object exposed by ``app.main``.

    Raises:
        HTTPException: 503 when ``app.main.agent_manager`` is ``None``.
    """
    from app.main import agent_manager as manager

    if manager is not None:
        return manager
    raise HTTPException(status_code=503, detail="Agent service not initialized")
|
||
|
|
|
||
|
|
|
||
|
|
def get_approval_service() -> ApprovalService:
    """Return the shared approval service, creating it on first use.

    The instance is cached in the module-level ``_approval_service`` so that
    later calls return the same object.
    """
    global _approval_service

    service = _approval_service
    if service is None:
        # First call: build the service and remember it for later calls.
        service = _approval_service = ApprovalService()
    return service
|
||
|
|
|
||
|
|
|
||
|
|
# ==================== 请求/响应模型 ====================
|
||
|
|
|
||
|
|
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # Identifier of the agent that should handle the message.
    agent_id: str
    # Message text passed to the agent.
    message: str
    # Empty by default; the chat endpoint generates an id when it is empty.
    session_id: str = ""
    # Free-form context forwarded to the agent as ``context``.
    context: dict = {}
|
||
|
|
|
||
|
|
|
||
|
|
class ChatResponse(BaseModel):
    """Response body returned by the ``/chat`` endpoint."""

    # Reply text taken from the agent execution result.
    reply: str
    # Session id used for this exchange.
    session_id: str
    # Entries from the result's "tools_used" key; empty when absent.
    tools_used: list[str] = []
    # Entries from the result's "metadata" key; empty when absent.
    metadata: dict = {}
|
||
|
|
|
||
|
|
|
||
|
|
class ApprovalRequest(BaseModel):
    """Approval request record.

    NOTE(review): no endpoint visible in this module references this model;
    confirm where it is consumed.
    """

    request_id: str
    tool_name: str
    params: dict
    reason: str
    approved: bool
|
||
|
|
|
||
|
|
|
||
|
|
# ==================== API 端点 ====================
|
||
|
|
|
||
|
|
@router.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    agent_manager: AgentManager = Depends(get_agent_manager)
):
    """Handle a chat request by running the selected agent.

    Args:
        request: Chat payload. An empty ``session_id`` is replaced by a
            freshly generated UUID4 string.
        agent_manager: Manager used to execute the agent (injected).

    Returns:
        A ``ChatResponse`` built from the ``reply``, ``tools_used`` and
        ``metadata`` keys of the execution result, plus the effective
        session id.

    Raises:
        HTTPException: 404 when a ``ValueError`` is raised while executing
            or building the response; 500 for any other unexpected
            exception. An ``HTTPException`` raised during execution is
            propagated unchanged instead of being rewritten as a 500.
    """
    import uuid

    # Resolve the session id locally rather than mutating the request model.
    session_id = request.session_id or str(uuid.uuid4())

    try:
        # Run the agent.
        result = await agent_manager.execute(
            agent_id=request.agent_id,
            message=request.message,
            session_id=session_id,
            context=request.context,
        )

        return ChatResponse(
            reply=result.get("reply", ""),
            session_id=session_id,
            tools_used=result.get("tools_used", []),
            metadata=result.get("metadata", {}),
        )
    except HTTPException:
        # Keep deliberate HTTP errors (status code and detail) intact.
        raise
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # NOTE(review): the exception text is returned to the client; confirm
        # that leaking internal error details here is acceptable.
        raise HTTPException(status_code=500, detail=f"Agent execution failed: {str(e)}") from e
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/tool/request")
async def request_tool_execution(
    request: dict,
    approval_service: ApprovalService = Depends(get_approval_service)
):
    """Create an approval request for a tool execution.

    Args:
        request: JSON object body. Reads the keys ``tool_name`` (required),
            ``params``, ``user_id``, ``agent_id`` and ``reason``.
        approval_service: Service that records the approval request
            (injected).

    Returns:
        A dict with the ``request_id`` returned by the approval service and
        a ``status`` of ``"pending"``.

    Raises:
        HTTPException: 422 when ``tool_name`` is missing or empty.
    """
    tool_name = request.get("tool_name")
    if not tool_name:
        # Without a tool name there is nothing meaningful to approve.
        raise HTTPException(status_code=422, detail="tool_name is required")

    # `or` also covers an explicit JSON null, which `.get(key, default)`
    # would pass through as None.
    params = request.get("params") or {}
    # NOTE(review): user_id is taken from the client-supplied body; confirm
    # whether it should come from an authenticated identity instead.
    user_id = request.get("user_id") or "unknown"
    agent_id = request.get("agent_id") or ""
    reason = request.get("reason") or ""

    # Create the approval request.
    request_id = await approval_service.request_approval(
        tool_name=tool_name,
        params=params,
        user_id=user_id,
        agent_id=agent_id,
        reason=reason,
    )

    return {
        "request_id": request_id,
        "status": "pending",
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/tools")
async def list_tools(agent_manager: AgentManager = Depends(get_agent_manager)):
    """Return every tool reported by the agent manager.

    Each tool is serialized by calling its ``dict()`` method and the results
    are returned under the ``"tools"`` key.
    """
    serialized = []
    for tool in agent_manager.list_tools():
        serialized.append(tool.dict())
    return {"tools": serialized}
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/agents")
async def list_agents(agent_manager: AgentManager = Depends(get_agent_manager)):
    """Return what ``agent_manager.list_agents()`` reports, under ``"agents"``."""
    return {"agents": agent_manager.list_agents()}
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/agent/{agent_id}")
async def get_agent(
    agent_id: str,
    agent_manager: AgentManager = Depends(get_agent_manager)
):
    """Return the information the agent manager holds for one agent.

    Args:
        agent_id: Identifier taken from the URL path.
        agent_manager: Manager queried for the agent info (injected).

    Raises:
        HTTPException: 404 when the manager returns a falsy value.
    """
    info = agent_manager.get_agent_info(agent_id)
    if info:
        return info
    raise HTTPException(status_code=404, detail="Agent not found")
|