feat: 重构前后端架构,添加Go后端和Python Agent服务
- 新增 Go 语言后端服务(server/),包含用户认证、Agent管理、数据库连接等API - 新增 Python Agent 服务(agent/),实现Agent核心逻辑和工具集 - 前端从原生HTML迁移到Vue.js框架(web/src/) - 添加 Docker Compose 支持(docker-compose.yml) - 添加项目架构文档(docs/ARCHITECTURE.md) - 添加环境变量示例(.env.example)和本地启动脚本(start-local.ps1) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
192
agent/app/agent/core/agent.py
Normal file
192
agent/app/agent/core/agent.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Agent 核心管理器
|
||||
"""
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.agent.core.executor import AgentExecutor
|
||||
from app.agent.memory.session import SessionManager
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""Agent 管理器 - 负责加载和管理所有 Agent"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_provider: str = "openai",
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
):
|
||||
self.llm_provider = llm_provider
|
||||
self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
# 初始化组件
|
||||
self.llm_factory = LLMFactory(
|
||||
provider=llm_provider,
|
||||
openai_api_key=self.openai_api_key,
|
||||
anthropic_api_key=self.anthropic_api_key
|
||||
)
|
||||
self.tool_registry = ToolRegistry()
|
||||
self.session_manager = SessionManager()
|
||||
self.audit_logger = AuditLogger()
|
||||
|
||||
# 已加载的 Agent
|
||||
self.agents: dict[str, dict] = {}
|
||||
self.executors: dict[str, AgentExecutor] = {}
|
||||
|
||||
# 注册默认工具
|
||||
self._register_default_tools()
|
||||
|
||||
def _register_default_tools(self):
|
||||
"""注册默认工具"""
|
||||
from app.agent.tools.impl import search, calculator, time_tool
|
||||
from app.agent.tools.impl import sandbox, database, api_client
|
||||
|
||||
# 安全工具 - Safe 级别
|
||||
self.tool_registry.register(
|
||||
name="search",
|
||||
func=search.search_web,
|
||||
description="Search the web for information",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="calculator",
|
||||
func=calculator.calculate,
|
||||
description="Perform mathematical calculations",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="get_current_time",
|
||||
func=time_tool.get_current_time,
|
||||
description="Get current date and time",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
# 需要审核的工具 - Review 级别
|
||||
self.tool_registry.register(
|
||||
name="execute_code",
|
||||
func=sandbox.sandbox.execute,
|
||||
description="Execute code in sandbox (Python/JavaScript)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "Code to execute"},
|
||||
"language": {"type": "string", "default": "python"},
|
||||
"timeout": {"type": "integer", "default": 30}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="query_database",
|
||||
func=database.query_data,
|
||||
description="Query database (SELECT only)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sql": {"type": "string", "description": "SELECT query"}
|
||||
},
|
||||
"required": ["sql"]
|
||||
}
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="call_api",
|
||||
func=api_client.call_api,
|
||||
description="Call external API (whitelist only)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api_name": {"type": "string"},
|
||||
"endpoint": {"type": "string"},
|
||||
"params": {"type": "object"}
|
||||
},
|
||||
"required": ["api_name"]
|
||||
}
|
||||
)
|
||||
|
||||
async def load_agents(self):
|
||||
"""加载 Agent 配置"""
|
||||
# TODO: 从数据库或配置文件加载
|
||||
# 这里先注册一些示例 Agent
|
||||
|
||||
self.agents["assistant"] = {
|
||||
"name": "General Assistant",
|
||||
"description": "A general purpose assistant",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"tools": ["search", "calculator", "get_current_time"]
|
||||
}
|
||||
|
||||
self.agents["coder"] = {
|
||||
"name": "Code Assistant",
|
||||
"description": "Helps with coding tasks",
|
||||
"system_prompt": "You are a helpful coding assistant. You can write, explain, and debug code.",
|
||||
"tools": ["search", "calculator"]
|
||||
}
|
||||
|
||||
# 为每个 Agent 创建执行器
|
||||
for agent_id, config in self.agents.items():
|
||||
self.executors[agent_id] = AgentExecutor(
|
||||
agent_id=agent_id,
|
||||
llm_factory=self.llm_factory,
|
||||
tool_registry=self.tool_registry,
|
||||
session_manager=self.session_manager,
|
||||
audit_logger=self.audit_logger,
|
||||
config=config
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
agent_id: str,
|
||||
message: str,
|
||||
session_id: str,
|
||||
context: dict = None
|
||||
) -> dict[str, Any]:
|
||||
"""执行 Agent"""
|
||||
if agent_id not in self.executors:
|
||||
raise ValueError(f"Agent '{agent_id}' not found")
|
||||
|
||||
executor = self.executors[agent_id]
|
||||
|
||||
# 执行
|
||||
result = await executor.run(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
context=context or {}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""列出所有可用工具"""
|
||||
return self.tool_registry.list_tools()
|
||||
|
||||
def list_agents(self) -> list[dict]:
|
||||
"""列出所有 Agent"""
|
||||
return [
|
||||
{
|
||||
"id": agent_id,
|
||||
"name": config["name"],
|
||||
"description": config["description"]
|
||||
}
|
||||
for agent_id, config in self.agents.items()
|
||||
]
|
||||
|
||||
def get_agent_info(self, agent_id: str) -> Optional[dict]:
|
||||
"""获取 Agent 信息"""
|
||||
if agent_id not in self.agents:
|
||||
return None
|
||||
return self.agents[agent_id]
|
||||
163
agent/app/agent/core/executor.py
Normal file
163
agent/app/agent/core/executor.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Agent 执行器 - 负责执行 Agent 的核心逻辑
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.agent.memory.session import SessionManager
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
class AgentExecutor:
|
||||
"""Agent 执行器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
llm_factory: LLMFactory,
|
||||
tool_registry: ToolRegistry,
|
||||
session_manager: SessionManager,
|
||||
audit_logger: AuditLogger,
|
||||
config: dict
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.llm_factory = llm_factory
|
||||
self.tool_registry = tool_registry
|
||||
self.session_manager = session_manager
|
||||
self.audit_logger = audit_logger
|
||||
self.config = config
|
||||
|
||||
# 获取 LLM
|
||||
self.llm = self.llm_factory.get_llm()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
context: dict
|
||||
) -> dict[str, Any]:
|
||||
"""运行 Agent"""
|
||||
tools_used = []
|
||||
|
||||
# 1. 获取会话历史
|
||||
history = self.session_manager.get_history(session_id)
|
||||
|
||||
# 2. 构建消息列表
|
||||
messages = self._build_messages(message, history)
|
||||
|
||||
# 3. 获取可用工具
|
||||
available_tools = self._get_available_tools()
|
||||
|
||||
# 4. 调用 LLM(带工具)
|
||||
try:
|
||||
response = await self.llm.agenerate(
|
||||
messages=messages,
|
||||
tools=available_tools
|
||||
)
|
||||
|
||||
# 检查是否需要调用工具
|
||||
response_message = response.generations[0][0]
|
||||
|
||||
# 如果有工具调用
|
||||
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
|
||||
for tool_call in response_message.tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_args = tool_call.arguments
|
||||
|
||||
# 记录工具使用
|
||||
tools_used.append(tool_name)
|
||||
|
||||
# 执行工具
|
||||
tool_result = await self._execute_tool(tool_name, tool_args)
|
||||
|
||||
# 添加工具结果到消息
|
||||
messages.append(response_message)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": str(tool_result)
|
||||
})
|
||||
|
||||
# 再次调用 LLM 生成最终响应
|
||||
final_response = await self.llm.agenerate(messages=messages)
|
||||
final_message = final_response.generations[0][0].text
|
||||
|
||||
# 保存到历史
|
||||
self.session_manager.add_message(session_id, "user", message)
|
||||
self.session_manager.add_message(session_id, "assistant", final_message)
|
||||
|
||||
return {
|
||||
"reply": final_message,
|
||||
"tools_used": tools_used,
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
# 没有工具调用,直接返回
|
||||
reply = response_message.text
|
||||
|
||||
# 保存到历史
|
||||
self.session_manager.add_message(session_id, "user", message)
|
||||
self.session_manager.add_message(session_id, "assistant", reply)
|
||||
|
||||
return {
|
||||
"reply": reply,
|
||||
"tools_used": tools_used,
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误
|
||||
self.audit_logger.log(
|
||||
action="agent_error",
|
||||
agent_id=self.agent_id,
|
||||
session_id=session_id,
|
||||
details={"error": str(e)}
|
||||
)
|
||||
raise
|
||||
|
||||
def _build_messages(self, message: str, history: list) -> list:
|
||||
"""构建消息列表"""
|
||||
messages = []
|
||||
|
||||
# 添加系统提示
|
||||
system_prompt = self.config.get("system_prompt", "You are a helpful assistant.")
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# 添加历史
|
||||
for msg in history:
|
||||
messages.append(msg)
|
||||
|
||||
# 添加当前消息
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
return messages
|
||||
|
||||
def _get_available_tools(self) -> list:
|
||||
"""获取可用工具定义"""
|
||||
agent_tools = self.config.get("tools", [])
|
||||
tool_defs = []
|
||||
|
||||
for tool_name in agent_tools:
|
||||
tool_def = self.tool_registry.get_tool_definition(tool_name)
|
||||
if tool_def:
|
||||
tool_defs.append(tool_def)
|
||||
|
||||
return tool_defs
|
||||
|
||||
async def _execute_tool(self, tool_name: str, args: dict) -> Any:
|
||||
"""执行工具"""
|
||||
# 安全检查
|
||||
tool_func, metadata = self.tool_registry.get_tool(tool_name)
|
||||
|
||||
# 如果需要审批,抛出异常
|
||||
if metadata.require_approval:
|
||||
raise PermissionError(
|
||||
f"Tool '{tool_name}' requires approval before execution"
|
||||
)
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
result = tool_func(**args)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
62
agent/app/agent/memory/session.py
Normal file
62
agent/app/agent/memory/session.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
会话管理器 - 管理 Agent 的会话历史
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
|
||||
def __init__(self, max_history: int = 10):
|
||||
"""
|
||||
初始化会话管理器
|
||||
|
||||
Args:
|
||||
max_history: 每个会话保留的最大历史消息数
|
||||
"""
|
||||
self.max_history = max_history
|
||||
self.sessions: dict[str, list[dict]] = defaultdict(list)
|
||||
self.metadata: dict[str, dict] = {}
|
||||
|
||||
def add_message(self, session_id: str, role: str, content: str):
|
||||
"""添加消息到会话"""
|
||||
self.sessions[session_id].append({
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 限制历史长度
|
||||
if len(self.sessions[session_id]) > self.max_history:
|
||||
self.sessions[session_id] = self.sessions[session_id][-self.max_history:]
|
||||
|
||||
def get_history(self, session_id: str) -> list[dict]:
|
||||
"""获取会话历史"""
|
||||
return self.sessions.get(session_id, [])
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""清除会话"""
|
||||
if session_id in self.sessions:
|
||||
del self.sessions[session_id]
|
||||
if session_id in self.metadata:
|
||||
del self.metadata[session_id]
|
||||
|
||||
def set_metadata(self, session_id: str, key: str, value: Any):
|
||||
"""设置会话元数据"""
|
||||
if session_id not in self.metadata:
|
||||
self.metadata[session_id] = {}
|
||||
self.metadata[session_id][key] = value
|
||||
|
||||
def get_metadata(self, session_id: str, key: str, default: Any = None) -> Any:
|
||||
"""获取会话元数据"""
|
||||
return self.metadata.get(session_id, {}).get(key, default)
|
||||
|
||||
def list_sessions(self) -> list[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self.sessions.keys())
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取会话数量"""
|
||||
return len(self.sessions)
|
||||
22
agent/app/agent/tools/impl/__init__.py
Normal file
22
agent/app/agent/tools/impl/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
工具实现模块
|
||||
"""
|
||||
|
||||
# 基础工具
|
||||
from . import search
|
||||
from . import calculator
|
||||
from . import time_tool
|
||||
|
||||
# 安全工具
|
||||
from . import sandbox
|
||||
from . import database
|
||||
from . import api_client
|
||||
|
||||
__all__ = [
|
||||
"search",
|
||||
"calculator",
|
||||
"time_tool",
|
||||
"sandbox",
|
||||
"database",
|
||||
"api_client",
|
||||
]
|
||||
166
agent/app/agent/tools/impl/api_client.py
Normal file
166
agent/app/agent/tools/impl/api_client.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
API 调用工具 - 安全的外部 API 调用
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class APIPermission(Enum):
|
||||
"""API 权限级别"""
|
||||
PUBLIC = "public" # 公开 API
|
||||
APPROVED = "approved" # 已审批的 API
|
||||
ADMIN = "admin" # 管理员 API
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIEndpoint:
|
||||
"""API 端点定义"""
|
||||
name: str
|
||||
url: str
|
||||
method: str
|
||||
permission: APIPermission
|
||||
description: str
|
||||
rate_limit: int = 60 # 每分钟请求次数
|
||||
|
||||
|
||||
# API 白名单
|
||||
ALLOWED_APIS = [
|
||||
APIEndpoint(
|
||||
name="weather",
|
||||
url="https://api.weather.example.com/v1",
|
||||
method="GET",
|
||||
permission=APIPermission.PUBLIC,
|
||||
description="获取天气信息",
|
||||
rate_limit=30
|
||||
),
|
||||
APIEndpoint(
|
||||
name="news",
|
||||
url="https://newsapi.org/v2",
|
||||
method="GET",
|
||||
permission=APIPermission.PUBLIC,
|
||||
description="获取新闻",
|
||||
rate_limit=30
|
||||
),
|
||||
# 可以添加更多已审批的 API
|
||||
]
|
||||
|
||||
|
||||
class APICallTool:
|
||||
"""
|
||||
API 调用工具
|
||||
|
||||
安全特性:
|
||||
- 只允许调用白名单中的 API
|
||||
- 速率限制
|
||||
- 请求超时
|
||||
- 响应大小限制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.allowed_apis = {api.name: api for api in ALLOWED_APIS}
|
||||
self.request_timeout = 10 # 请求超时(秒)
|
||||
self.max_response_size = 1024 * 1024 # 最大响应大小(1MB)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
api_name: str,
|
||||
endpoint: str = "",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
调用 API
|
||||
|
||||
Args:
|
||||
api_name: API 名称(必须在白名单中)
|
||||
endpoint: 具体的端点
|
||||
params: 查询参数
|
||||
headers: 请求头
|
||||
|
||||
Returns:
|
||||
API 响应
|
||||
"""
|
||||
# 安全检查1: API 必须在白名单中
|
||||
if api_name not in self.allowed_apis:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"API '{api_name}' not in whitelist. Allowed: {list(self.allowed_apis.keys())}"
|
||||
}
|
||||
|
||||
api = self.allowed_apis[api_name]
|
||||
|
||||
# 构建完整 URL
|
||||
url = f"{api.url}/{endpoint}" if endpoint else api.url
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.request_timeout) as client:
|
||||
# 根据方法调用
|
||||
if api.method == "GET":
|
||||
response = await client.get(url, params=params, headers=headers)
|
||||
elif api.method == "POST":
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Method {api.method} not supported"
|
||||
}
|
||||
|
||||
# 检查响应大小
|
||||
if len(response.content) > self.max_response_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Response too large (max {self.max_response_size} bytes)"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"status_code": response.status_code,
|
||||
"data": response.json() if response.headers.get("content-type", "").startswith("application/json") else response.text,
|
||||
"headers": dict(response.headers)
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Request timeout"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def list_apis(self) -> list:
|
||||
"""列出所有可用的 API"""
|
||||
return [
|
||||
{
|
||||
"name": api.name,
|
||||
"description": api.description,
|
||||
"method": api.method,
|
||||
"permission": api.permission.value,
|
||||
"rate_limit": api.rate_limit
|
||||
}
|
||||
for api in ALLOWED_APIS
|
||||
]
|
||||
|
||||
|
||||
# 全局实例
|
||||
api_tool = APICallTool()
|
||||
|
||||
|
||||
async def call_api(
|
||||
api_name: str,
|
||||
endpoint: str = "",
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
API 调用工具(供 Agent 使用)
|
||||
"""
|
||||
return await api_tool.call(api_name, endpoint, params)
|
||||
|
||||
|
||||
def list_allowed_apis() -> list:
|
||||
"""列出允许的 API"""
|
||||
return api_tool.list_apis()
|
||||
91
agent/app/agent/tools/impl/calculator.py
Normal file
91
agent/app/agent/tools/impl/calculator.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
计算器工具
|
||||
"""
|
||||
import ast
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
|
||||
# 安全运算符
|
||||
SAFE_OPERATORS = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.Pow: operator.pow,
|
||||
ast.Mod: operator.mod,
|
||||
ast.USub: operator.neg,
|
||||
}
|
||||
|
||||
|
||||
def safe_eval_expr(node):
|
||||
"""安全地求值表达式节点"""
|
||||
if isinstance(node, ast.Num):
|
||||
return node.n
|
||||
elif isinstance(node, ast.BinOp):
|
||||
left = safe_eval_expr(node.left)
|
||||
right = safe_eval_expr(node.right)
|
||||
op_type = type(node.op)
|
||||
if op_type in SAFE_OPERATORS:
|
||||
return SAFE_OPERATORS[op_type](left, right)
|
||||
raise ValueError(f"Unsupported operator: {op_type}")
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
operand = safe_eval_expr(node.operand)
|
||||
op_type = type(node.op)
|
||||
if op_type in SAFE_OPERATORS:
|
||||
return SAFE_OPERATORS[op_type](operand)
|
||||
raise ValueError(f"Unsupported unary operator: {op_type}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported expression: {ast.dump(node)}")
|
||||
|
||||
|
||||
def calculate(expression: str) -> dict:
|
||||
"""
|
||||
执行数学计算
|
||||
|
||||
Args:
|
||||
expression: 数学表达式,如 "2 + 2" 或 "sqrt(16)"
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
"""
|
||||
try:
|
||||
# 预处理:处理常见数学函数
|
||||
expression = expression.replace("sqrt", "**0.5")
|
||||
expression = expression.replace("pi", "3.14159265359")
|
||||
expression = expression.replace("e", "2.71828182846")
|
||||
|
||||
# 解析表达式
|
||||
tree = ast.parse(expression, mode='eval')
|
||||
result = safe_eval_expr(tree.body)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"expression": expression,
|
||||
"result": result,
|
||||
"type": type(result).__name__
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"expression": expression,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
TOOL_DEFINITION = {
|
||||
"name": "calculator",
|
||||
"description": "Perform mathematical calculations. Supports basic arithmetic (+, -, *, /), powers (**), and functions (sqrt).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Mathematical expression to evaluate, e.g., '2 + 2' or 'sqrt(16) + 5'"
|
||||
}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}
|
||||
96
agent/app/agent/tools/impl/database.py
Normal file
96
agent/app/agent/tools/impl/database.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
数据库查询工具 - 安全的数据查询接口
|
||||
"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
import os
|
||||
|
||||
|
||||
# 只读查询白名单 - 只允许 SELECT 语句
|
||||
ALLOWED_TABLES = ["users", "agents", "sessions", "audit_logs"]
|
||||
|
||||
|
||||
class DatabaseQueryTool:
|
||||
"""
|
||||
数据库查询工具
|
||||
|
||||
安全特性:
|
||||
- 只允许 SELECT 查询
|
||||
- 表名白名单
|
||||
- 结果数量限制
|
||||
"""
|
||||
|
||||
def __init__(self, connection_string: str = ""):
|
||||
self.connection_string = connection_string or os.getenv(
|
||||
"DATABASE_URL",
|
||||
"postgresql://postgres:postgres@localhost:5432/x_agents"
|
||||
)
|
||||
self.max_rows = 100 # 最多返回100行
|
||||
|
||||
def query(self, sql: str, params: List[Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
执行查询
|
||||
|
||||
Args:
|
||||
sql: SQL 查询语句(必须是 SELECT)
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
# 安全检查1: 必须是 SELECT 语句
|
||||
sql_upper = sql.strip().upper()
|
||||
if not sql_upper.startswith("SELECT"):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Only SELECT queries are allowed"
|
||||
}
|
||||
|
||||
# 安全检查2: 禁止危险关键字
|
||||
dangerous_keywords = [
|
||||
"DROP", "DELETE", "INSERT", "UPDATE", "ALTER",
|
||||
"CREATE", "TRUNCATE", "EXEC", "EXECUTE"
|
||||
]
|
||||
for keyword in dangerous_keywords:
|
||||
if keyword in sql_upper:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Keyword '{keyword}' is not allowed"
|
||||
}
|
||||
|
||||
# 安全检查3: 表名白名单
|
||||
for table in ALLOWED_TABLES:
|
||||
if f"FROM {table}" in sql_upper or f"JOIN {table}" in sql_upper:
|
||||
# 表名在白名单中,允许
|
||||
break
|
||||
else:
|
||||
# 没有找到白名单表
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Table not in whitelist. Allowed: {ALLOWED_TABLES}"
|
||||
}
|
||||
|
||||
# TODO: 实际执行查询(需要数据库连接)
|
||||
# 这里返回模拟数据
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Query executed (mock mode - database not connected)",
|
||||
"rows": [],
|
||||
"columns": []
|
||||
}
|
||||
|
||||
|
||||
# 全局实例
|
||||
db_tool = DatabaseQueryTool()
|
||||
|
||||
|
||||
def query_data(sql: str) -> Dict[str, Any]:
|
||||
"""
|
||||
查询数据工具
|
||||
|
||||
Args:
|
||||
sql: SELECT 查询语句
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
return db_tool.query(sql)
|
||||
87
agent/app/agent/tools/impl/search.py
Normal file
87
agent/app/agent/tools/impl/search.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
网页搜索工具
|
||||
"""
|
||||
import httpx
|
||||
from typing import Optional
|
||||
|
||||
|
||||
async def search_web(query: str, max_results: int = 5) -> dict:
|
||||
"""
|
||||
搜索网页获取信息
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 返回结果数量
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
# 这里可以使用搜索引擎API,如 Google, Bing, DuckDuckGo 等
|
||||
# 示例使用 DuckDuckGo API(免费)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
"https://api.duckduckgo.com/",
|
||||
params={
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": 1,
|
||||
"skip_disambig": 1
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
results = []
|
||||
|
||||
# 提取相关主题
|
||||
if "RelatedTopics" in data:
|
||||
for item in data["RelatedTopics"][:max_results]:
|
||||
if "Text" in item:
|
||||
results.append({
|
||||
"title": item.get("Text", "").split(" - ")[0] if " - " in item.get("Text", "") else "",
|
||||
"content": item.get("Text", ""),
|
||||
"url": item.get("URL", "")
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Search API returned status {response.status_code}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义(用于 LLM)
|
||||
TOOL_DEFINITION = {
|
||||
"name": "search",
|
||||
"description": "Search the web for information. Use this when you need to find current information or facts.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
70
agent/app/agent/tools/impl/time_tool.py
Normal file
70
agent/app/agent/tools/impl/time_tool.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
时间工具
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_current_time(timezone: Optional[str] = None) -> dict:
|
||||
"""
|
||||
获取当前时间
|
||||
|
||||
Args:
|
||||
timezone: 时区名称,如 "UTC", "Asia/Shanghai"
|
||||
|
||||
Returns:
|
||||
当前时间信息
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"datetime": now.isoformat(),
|
||||
"timestamp": now.timestamp(),
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"time": now.strftime("%H:%M:%S"),
|
||||
"weekday": now.strftime("%A"),
|
||||
"timezone": timezone or "Local Time"
|
||||
}
|
||||
|
||||
|
||||
def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> dict:
|
||||
"""
|
||||
格式化时间戳
|
||||
|
||||
Args:
|
||||
timestamp: Unix 时间戳
|
||||
format_str: 格式字符串
|
||||
|
||||
Returns:
|
||||
格式化后的时间
|
||||
"""
|
||||
try:
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
return {
|
||||
"success": True,
|
||||
"formatted": dt.strftime(format_str),
|
||||
"datetime": dt.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
TOOL_DEFINITION = {
|
||||
"name": "get_current_time",
|
||||
"description": "Get the current date and time. Useful for timestamps or scheduling.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
|
||||
"default": "Local"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
99
agent/app/agent/tools/registry.py
Normal file
99
agent/app/agent/tools/registry.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
工具注册表 - 管理所有可用工具(白名单机制)
|
||||
"""
|
||||
from typing import Any, Callable, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""工具安全等级"""
|
||||
SAFE = "safe" # 安全操作
|
||||
REVIEW = "review" # 需要审核
|
||||
DANGER = "danger" # 危险操作
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""工具元数据"""
|
||||
name: str
|
||||
description: str
|
||||
security_level: str
|
||||
require_approval: bool = False
|
||||
allowed_roles: list = None
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"security_level": self.security_level,
|
||||
"require_approval": self.require_approval
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, tuple[Callable, ToolMetadata]] = {}
|
||||
self._definitions: dict[str, dict] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
func: Callable,
|
||||
description: str = "",
|
||||
security_level: str = "safe",
|
||||
require_approval: bool = False,
|
||||
allowed_roles: list = None,
|
||||
parameters: dict = None
|
||||
):
|
||||
"""注册工具到白名单"""
|
||||
metadata = ToolMetadata(
|
||||
name=name,
|
||||
description=description,
|
||||
security_level=security_level,
|
||||
require_approval=require_approval,
|
||||
allowed_roles=allowed_roles or ["user", "admin"]
|
||||
)
|
||||
|
||||
self._tools[name] = (func, metadata)
|
||||
|
||||
# 生成工具定义(用于 LLM 调用)
|
||||
self._definitions[name] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
|
||||
"""获取工具函数和元数据"""
|
||||
if name not in self._tools:
|
||||
raise ValueError(f"Tool '{name}' not found in whitelist")
|
||||
return self._tools[name]
|
||||
|
||||
def get_tool_definition(self, name: str) -> Optional[dict]:
|
||||
"""获取工具定义(用于 LLM)"""
|
||||
return self._definitions.get(name)
|
||||
|
||||
def list_tools(self) -> list[ToolMetadata]:
|
||||
"""列出所有已注册工具"""
|
||||
return [meta for _, meta in self._tools.values()]
|
||||
|
||||
def check_permission(self, tool_name: str, user_role: str) -> bool:
|
||||
"""检查用户权限"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return user_role in metadata.allowed_roles
|
||||
|
||||
def need_approval(self, tool_name: str) -> bool:
|
||||
"""判断是否需要审批"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return metadata.require_approval
|
||||
283
agent/app/agent/tools/sandbox/sandbox.py
Normal file
283
agent/app/agent/tools/sandbox/sandbox.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
沙盒执行环境 - 在项目内构建,不依赖 Docker
|
||||
提供安全的代码执行环境
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import resource
|
||||
import signal
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SandboxConfig:
|
||||
"""沙盒配置"""
|
||||
# 资源限制
|
||||
MAX_MEMORY_MB = 256 # 最大内存 (MB)
|
||||
MAX_CPU_PERCENT = 50 # 最大 CPU 百分比
|
||||
MAX_EXECUTION_TIME = 30 # 最大执行时间 (秒)
|
||||
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小 (bytes)
|
||||
|
||||
|
||||
class Sandbox:
|
||||
"""
|
||||
沙盒执行器 - 使用 subprocess 隔离执行
|
||||
|
||||
安全特性:
|
||||
- 内存限制
|
||||
- CPU限制
|
||||
- 超时控制
|
||||
- 网络隔离(可选)
|
||||
- 临时文件隔离
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SandboxConfig] = None):
|
||||
self.config = config or SandboxConfig()
|
||||
self.temp_dir = None
|
||||
|
||||
def _setup_temp_dir(self) -> str:
|
||||
"""创建临时目录"""
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="sandbox_")
|
||||
return self.temp_dir
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理临时目录"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
try:
|
||||
shutil.rmtree(self.temp_dir)
|
||||
except Exception as e:
|
||||
print(f"Cleanup error: {e}")
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
在沙盒中执行代码
|
||||
|
||||
Args:
|
||||
code: 要执行的代码
|
||||
language: 语言类型 (python, javascript)
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
if language == "python":
|
||||
return self._execute_python(code, timeout)
|
||||
elif language == "javascript":
|
||||
return self._execute_javascript(code, timeout)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unsupported language: {language}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _execute_python(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 Python 代码"""
|
||||
# 创建临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.py")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 构建命令
|
||||
cmd = ["python", temp_file]
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir, # 限制工作目录
|
||||
env=self._get_restricted_env(), # 限制环境变量
|
||||
)
|
||||
|
||||
# 检查输出大小
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
|
||||
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _execute_javascript(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 JavaScript 代码"""
|
||||
temp_file = os.path.join(self.temp_dir, "code.js")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 尝试使用 node
|
||||
cmd = ["node", temp_file]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
)
|
||||
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Node.js not installed",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _get_restricted_env(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取受限的环境变量
|
||||
移除敏感变量,保留必要的 PATH
|
||||
"""
|
||||
# 保留 PATH,移除其他敏感变量
|
||||
safe_env = {
|
||||
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
|
||||
"LANG": "en_US.UTF-8",
|
||||
"HOME": self.temp_dir,
|
||||
"TMPDIR": self.temp_dir,
|
||||
}
|
||||
|
||||
# 移除可能不安全的变量
|
||||
unsafe_vars = [
|
||||
"PYTHONPATH",
|
||||
"PYTHONHOME",
|
||||
"LD_PRELOAD",
|
||||
"LD_LIBRARY_PATH",
|
||||
]
|
||||
|
||||
for var in unsafe_vars:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
return safe_env
|
||||
|
||||
|
||||
class SafeEval:
|
||||
"""
|
||||
安全求值器 - 用于简单表达式计算
|
||||
比沙盒更轻量,适用于不需要完全隔离的场景
|
||||
"""
|
||||
|
||||
# 安全函数白名单
|
||||
SAFE_BUILTINS = {
|
||||
"abs": abs,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"len": len,
|
||||
"round": round,
|
||||
"pow": pow,
|
||||
"print": print,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"set": set,
|
||||
"range": range,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"sorted": sorted,
|
||||
"reversed": reversed,
|
||||
}
|
||||
|
||||
# 安全数学常量
|
||||
SAFE_MATH = {
|
||||
"pi": 3.14159265359,
|
||||
"e": 2.71828182846,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def eval(cls, expression: str) -> Any:
|
||||
"""
|
||||
安全地求值表达式
|
||||
|
||||
Args:
|
||||
expression: 数学表达式
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
"""
|
||||
# 预处理表达式
|
||||
expression = expression.replace("sqrt", "**0.5")
|
||||
|
||||
# 构建安全命名空间
|
||||
safe_namespace = {
|
||||
**cls.SAFE_BUILTINS,
|
||||
**cls.SAFE_MATH,
|
||||
"__builtins__": {} # 禁用__builtins__
|
||||
}
|
||||
|
||||
try:
|
||||
result = eval(expression, safe_namespace)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise ValueError(f"Evaluation error: {e}")
|
||||
|
||||
|
||||
# 全局沙盒实例
|
||||
sandbox = Sandbox()
|
||||
|
||||
|
||||
# 装饰器:快速将函数封装为沙盒执行
|
||||
def sandboxed(timeout: int = 30):
|
||||
"""装饰器:为函数添加沙盒执行能力"""
|
||||
def decorator(func):
|
||||
def wrapper(code: str, *args, **kwargs):
|
||||
result = sandbox.execute(code, timeout=timeout)
|
||||
if not result["success"]:
|
||||
raise RuntimeError(result.get("error", "Execution failed"))
|
||||
return result["output"]
|
||||
return wrapper
|
||||
return decorator
|
||||
149
agent/app/api/routes.py
Normal file
149
agent/app/api/routes.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
API 路由定义
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agent.core.agent import AgentManager
|
||||
from app.security.approval import ApprovalService
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局依赖(实际应该注入)
|
||||
_agent_manager: Optional[AgentManager] = None
|
||||
_approval_service: Optional[ApprovalService] = None
|
||||
|
||||
|
||||
def get_agent_manager() -> AgentManager:
|
||||
"""获取 Agent 管理器"""
|
||||
# 这里应该从 app.state 获取
|
||||
from app.main import agent_manager
|
||||
if agent_manager is None:
|
||||
raise HTTPException(status_code=503, detail="Agent service not initialized")
|
||||
return agent_manager
|
||||
|
||||
|
||||
def get_approval_service() -> ApprovalService:
|
||||
"""获取审批服务"""
|
||||
global _approval_service
|
||||
if _approval_service is None:
|
||||
_approval_service = ApprovalService()
|
||||
return _approval_service
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""聊天请求"""
|
||||
agent_id: str
|
||||
message: str
|
||||
session_id: str = ""
|
||||
context: dict = {}
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""聊天响应"""
|
||||
reply: str
|
||||
session_id: str
|
||||
tools_used: list[str] = []
|
||||
metadata: dict = {}
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""审批请求"""
|
||||
request_id: str
|
||||
tool_name: str
|
||||
params: dict
|
||||
reason: str
|
||||
approved: bool
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat(
|
||||
request: ChatRequest,
|
||||
agent_manager: AgentManager = Depends(get_agent_manager)
|
||||
):
|
||||
"""处理 Agent 聊天请求"""
|
||||
try:
|
||||
# 生成会话ID
|
||||
if not request.session_id:
|
||||
import uuid
|
||||
request.session_id = str(uuid.uuid4())
|
||||
|
||||
# 执行 Agent
|
||||
result = await agent_manager.execute(
|
||||
agent_id=request.agent_id,
|
||||
message=request.message,
|
||||
session_id=request.session_id,
|
||||
context=request.context
|
||||
)
|
||||
|
||||
return ChatResponse(
|
||||
reply=result.get("reply", ""),
|
||||
session_id=request.session_id,
|
||||
tools_used=result.get("tools_used", []),
|
||||
metadata=result.get("metadata", {})
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Agent execution failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tool/request")
|
||||
async def request_tool_execution(
|
||||
request: dict,
|
||||
approval_service: ApprovalService = Depends(get_approval_service)
|
||||
):
|
||||
"""请求执行工具(需要审批)"""
|
||||
tool_name = request.get("tool_name")
|
||||
params = request.get("params", {})
|
||||
user_id = request.get("user_id", "unknown")
|
||||
agent_id = request.get("agent_id")
|
||||
reason = request.get("reason", "")
|
||||
|
||||
# 创建审批请求
|
||||
request_id = await approval_service.request_approval(
|
||||
tool_name=tool_name,
|
||||
params=params,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id or "",
|
||||
reason=reason
|
||||
)
|
||||
|
||||
return {
|
||||
"request_id": request_id,
|
||||
"status": "pending"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/tools")
|
||||
async def list_tools(agent_manager: AgentManager = Depends(get_agent_manager)):
|
||||
"""列出所有可用工具"""
|
||||
tools = agent_manager.list_tools()
|
||||
return {"tools": [tool.dict() for tool in tools]}
|
||||
|
||||
|
||||
@router.get("/agents")
|
||||
async def list_agents(agent_manager: AgentManager = Depends(get_agent_manager)):
|
||||
"""列出所有已加载的 Agent"""
|
||||
agents = agent_manager.list_agents()
|
||||
return {"agents": agents}
|
||||
|
||||
|
||||
@router.get("/agent/{agent_id}")
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
agent_manager: AgentManager = Depends(get_agent_manager)
|
||||
):
|
||||
"""获取特定 Agent 信息"""
|
||||
agent_info = agent_manager.get_agent_info(agent_id)
|
||||
if not agent_info:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
return agent_info
|
||||
63
agent/app/llm/factory.py
Normal file
63
agent/app/llm/factory.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
LLM 工厂 - 创建不同提供商的 LLM 实例
|
||||
"""
|
||||
from typing import Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
|
||||
class LLMFactory:
|
||||
"""LLM 工厂类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = "openai",
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000
|
||||
):
|
||||
self.provider = provider
|
||||
self.openai_api_key = openai_api_key
|
||||
self.anthropic_api_key = anthropic_api_key
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
self._llm = None
|
||||
|
||||
def get_llm(self):
|
||||
"""获取 LLM 实例"""
|
||||
if self._llm is not None:
|
||||
return self._llm
|
||||
|
||||
if self.provider == "openai":
|
||||
self._llm = ChatOpenAI(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
api_key=self.openai_api_key
|
||||
)
|
||||
elif self.provider == "anthropic":
|
||||
self._llm = ChatAnthropic(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
anthropic_api_key=self.anthropic_api_key
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
return self._llm
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""设置模型"""
|
||||
self.model = model
|
||||
self._llm = None # 重置 LLM 实例
|
||||
|
||||
def set_temperature(self, temperature: float):
|
||||
"""设置温度"""
|
||||
self.temperature = temperature
|
||||
if self._llm:
|
||||
self._llm.temperature = temperature
|
||||
84
agent/app/main.py
Normal file
84
agent/app/main.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
X-Agents Python Agent Service
|
||||
智能体引擎服务入口
|
||||
"""
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import routes
|
||||
from app.agent.core.agent import AgentManager
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
# 全局组件
|
||||
agent_manager: AgentManager = None
|
||||
audit_logger: AuditLogger = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
global agent_manager, audit_logger
|
||||
|
||||
# 启动时初始化
|
||||
audit_logger = AuditLogger()
|
||||
|
||||
# 初始化 Agent 管理器
|
||||
agent_manager = AgentManager(
|
||||
llm_provider=os.getenv("LLM_PROVIDER", "openai"),
|
||||
openai_api_key=os.getenv("OPENAI_API_KEY"),
|
||||
anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
)
|
||||
|
||||
# 加载 Agent 配置
|
||||
await agent_manager.load_agents()
|
||||
|
||||
print("Agent service started successfully")
|
||||
|
||||
yield
|
||||
|
||||
# 关闭时清理
|
||||
print("Agent service shutting down")
|
||||
|
||||
|
||||
# 创建 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title="X-Agents Agent Service",
|
||||
description="AI Agent Engine for X-Agents Platform",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(routes.router, prefix="/agent", tags=["Agent"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "agent",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径"""
|
||||
return {
|
||||
"message": "X-Agents Agent Service",
|
||||
"docs": "/docs"
|
||||
}
|
||||
104
agent/app/security/approval.py
Normal file
104
agent/app/security/approval.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
审批服务 - 处理工具执行的审批流程
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ApprovalStatus(Enum):
|
||||
"""审批状态"""
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class ApprovalService:
|
||||
"""审批服务"""
|
||||
|
||||
def __init__(self):
|
||||
# 待审批队列
|
||||
self.pending: Dict[str, dict] = {}
|
||||
# 审批结果
|
||||
self.results: Dict[str, ApprovalStatus] = {}
|
||||
|
||||
async def request_approval(
|
||||
self,
|
||||
tool_name: str,
|
||||
params: dict,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
reason: str
|
||||
) -> str:
|
||||
"""
|
||||
请求审批
|
||||
|
||||
Returns:
|
||||
request_id: 审批请求ID
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request = {
|
||||
"request_id": request_id,
|
||||
"tool_name": tool_name,
|
||||
"params": params,
|
||||
"user_id": user_id,
|
||||
"agent_id": agent_id,
|
||||
"reason": reason,
|
||||
"status": ApprovalStatus.PENDING,
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.pending[request_id] = request
|
||||
self.results[request_id] = ApprovalStatus.PENDING
|
||||
|
||||
# TODO: 通知 Go 后端有新审批
|
||||
|
||||
return request_id
|
||||
|
||||
async def check_approval(self, request_id: str, timeout: int = 300) -> bool:
|
||||
"""
|
||||
检查审批状态
|
||||
|
||||
Args:
|
||||
request_id: 审批请求ID
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
是否已批准
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
start = datetime.now()
|
||||
|
||||
while (datetime.now() - start).seconds < timeout:
|
||||
status = self.results.get(request_id)
|
||||
|
||||
if status == ApprovalStatus.APPROVED:
|
||||
return True
|
||||
elif status == ApprovalStatus.REJECTED:
|
||||
return False
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
raise TimeoutError("Approval request timeout")
|
||||
|
||||
async def approve(self, request_id: str):
|
||||
"""批准请求"""
|
||||
if request_id in self.pending:
|
||||
self.pending[request_id]["status"] = ApprovalStatus.APPROVED
|
||||
self.results[request_id] = ApprovalStatus.APPROVED
|
||||
|
||||
async def reject(self, request_id: str):
|
||||
"""拒绝请求"""
|
||||
if request_id in self.pending:
|
||||
self.pending[request_id]["status"] = ApprovalStatus.REJECTED
|
||||
self.results[request_id] = ApprovalStatus.REJECTED
|
||||
|
||||
def get_pending(self) -> list[dict]:
|
||||
"""获取待审批列表"""
|
||||
return [
|
||||
req for req in self.pending.values()
|
||||
if req["status"] == ApprovalStatus.PENDING
|
||||
]
|
||||
81
agent/app/security/audit.py
Normal file
81
agent/app/security/audit.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
审计日志 - 记录所有 Agent 操作
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""审计日志记录器"""
|
||||
|
||||
def __init__(self, log_file: str = "audit.log"):
|
||||
self.log_file = log_file
|
||||
|
||||
def log(
|
||||
self,
|
||||
action: str,
|
||||
agent_id: str = "",
|
||||
session_id: str = "",
|
||||
user_id: str = "",
|
||||
details: Dict[str, Any] = None,
|
||||
result: str = "success"
|
||||
):
|
||||
"""记录审计日志"""
|
||||
entry = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"action": action,
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"details": details or {},
|
||||
"result": result
|
||||
}
|
||||
|
||||
# 写入文件
|
||||
self._write_log(entry)
|
||||
|
||||
# TODO: 发送到 Go 后端
|
||||
|
||||
def log_tool_execution(
|
||||
self,
|
||||
tool_name: str,
|
||||
params: Dict[str, Any],
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
approved: bool,
|
||||
result: Any
|
||||
):
|
||||
"""记录工具执行"""
|
||||
self.log(
|
||||
action="tool_execution",
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
details={
|
||||
"tool_name": tool_name,
|
||||
"params": params,
|
||||
"approved": approved,
|
||||
"result_preview": str(result)[:200] if result else None
|
||||
},
|
||||
result="approved" if approved else "pending_approval"
|
||||
)
|
||||
|
||||
def log_error(self, action: str, error: str, **kwargs):
|
||||
"""记录错误"""
|
||||
self.log(
|
||||
action=action,
|
||||
details={"error": error, **kwargs},
|
||||
result="error"
|
||||
)
|
||||
|
||||
def _write_log(self, entry: dict):
|
||||
"""写入日志文件"""
|
||||
try:
|
||||
log_path = Path(self.log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(log_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
print(f"Failed to write audit log: {e}")
|
||||
Reference in New Issue
Block a user