diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..e06b026 --- /dev/null +++ b/.env.example @@ -0,0 +1,11 @@ +# JWT 配置 +JWT_SECRET=your-secret-key-change-in-production + +# LLM 提供商 (openai / anthropic) +LLM_PROVIDER=openai + +# OpenAI API Key +OPENAI_API_KEY=your-openai-api-key + +# Anthropic API Key +ANTHROPIC_API_KEY=your-anthropic-api-key diff --git a/agent/Dockerfile b/agent/Dockerfile new file mode 100644 index 0000000..4f904ac --- /dev/null +++ b/agent/Dockerfile @@ -0,0 +1,28 @@ +# Python Agent Service Dockerfile + +FROM python:3.11-slim + +WORKDIR /app + +# 安装系统依赖 +RUN apt-get update && apt-get install -y \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# 复制依赖文件 +COPY requirements.txt . + +# 安装 Python 依赖 +RUN pip install --no-cache-dir -r requirements.txt + +# 复制应用代码 +COPY app/ ./app/ + +# 创建数据目录 +RUN mkdir -p /app/data + +# 暴露端口 +EXPOSE 8081 + +# 启动服务 +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8081"] diff --git a/agent/app/agent/core/agent.py b/agent/app/agent/core/agent.py new file mode 100644 index 0000000..38c1cc8 --- /dev/null +++ b/agent/app/agent/core/agent.py @@ -0,0 +1,192 @@ +""" +Agent 核心管理器 +""" +import os +from typing import Any, Optional + +from app.agent.core.executor import AgentExecutor +from app.agent.memory.session import SessionManager +from app.agent.tools.registry import ToolRegistry +from app.llm.factory import LLMFactory +from app.security.audit import AuditLogger + + +class AgentManager: + """Agent 管理器 - 负责加载和管理所有 Agent""" + + def __init__( + self, + llm_provider: str = "openai", + openai_api_key: Optional[str] = None, + anthropic_api_key: Optional[str] = None, + ): + self.llm_provider = llm_provider + self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY") + self.anthropic_api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY") + + # 初始化组件 + self.llm_factory = LLMFactory( + provider=llm_provider, + openai_api_key=self.openai_api_key, + anthropic_api_key=self.anthropic_api_key + ) + self.tool_registry = ToolRegistry() + self.session_manager = SessionManager() + self.audit_logger = AuditLogger() + + # 已加载的 Agent + self.agents: dict[str, dict] = {} + self.executors: dict[str, AgentExecutor] = {} + + # 注册默认工具 + self._register_default_tools() + + def _register_default_tools(self): + """注册默认工具""" + from app.agent.tools.impl import search, calculator, time_tool + from app.agent.tools.impl import sandbox, database, api_client + + # 安全工具 - Safe 级别 + self.tool_registry.register( + name="search", + func=search.search_web, + description="Search the web for information", + security_level="safe" + ) + + self.tool_registry.register( + name="calculator", + func=calculator.calculate, + description="Perform mathematical calculations", + security_level="safe" + ) + + self.tool_registry.register( + name="get_current_time", + func=time_tool.get_current_time, + description="Get current date and time", + security_level="safe" + ) + + # 需要审核的工具 - Review 级别 + self.tool_registry.register( + name="execute_code", + func=sandbox.sandbox.execute, + description="Execute code in sandbox (Python/JavaScript)", + security_level="review", + require_approval=True, + parameters={ + "type": "object", + "properties": { + "code": {"type": "string", "description": "Code to execute"}, + "language": {"type": "string", "default": "python"}, + "timeout": {"type": "integer", "default": 30} + }, + "required": ["code"] + } + ) + + self.tool_registry.register( + name="query_database", + func=database.query_data, + description="Query database (SELECT only)", + security_level="review", + require_approval=True, + parameters={ + "type": "object", + "properties": { + "sql": {"type": "string", "description": "SELECT query"} + }, + "required": ["sql"] + } + ) + + self.tool_registry.register( + name="call_api", + func=api_client.call_api, + description="Call external API (whitelist only)", + security_level="review", + require_approval=True, + parameters={ + "type": "object", + "properties": { + "api_name": {"type": "string"}, + "endpoint": {"type": "string"}, + "params": {"type": "object"} + }, + "required": ["api_name"] + } + ) + + async def load_agents(self): + """加载 Agent 配置""" + # TODO: 从数据库或配置文件加载 + # 这里先注册一些示例 Agent + + self.agents["assistant"] = { + "name": "General Assistant", + "description": "A general purpose assistant", + "system_prompt": "You are a helpful assistant.", + "tools": ["search", "calculator", "get_current_time"] + } + + self.agents["coder"] = { + "name": "Code Assistant", + "description": "Helps with coding tasks", + "system_prompt": "You are a helpful coding assistant. You can write, explain, and debug code.", + "tools": ["search", "calculator"] + } + + # 为每个 Agent 创建执行器 + for agent_id, config in self.agents.items(): + self.executors[agent_id] = AgentExecutor( + agent_id=agent_id, + llm_factory=self.llm_factory, + tool_registry=self.tool_registry, + session_manager=self.session_manager, + audit_logger=self.audit_logger, + config=config + ) + + async def execute( + self, + agent_id: str, + message: str, + session_id: str, + context: dict = None + ) -> dict[str, Any]: + """执行 Agent""" + if agent_id not in self.executors: + raise ValueError(f"Agent '{agent_id}' not found") + + executor = self.executors[agent_id] + + # 执行 + result = await executor.run( + message=message, + session_id=session_id, + context=context or {} + ) + + return result + + def list_tools(self) -> list: + """列出所有可用工具""" + return self.tool_registry.list_tools() + + def list_agents(self) -> list[dict]: + """列出所有 Agent""" + return [ + { + "id": agent_id, + "name": config["name"], + "description": config["description"] + } + for agent_id, config in self.agents.items() + ] + + def get_agent_info(self, agent_id: str) -> Optional[dict]: + """获取 Agent 信息""" + if agent_id not in self.agents: + return None + return self.agents[agent_id] diff --git a/agent/app/agent/core/executor.py b/agent/app/agent/core/executor.py new file mode 100644 index 0000000..9d53195 --- /dev/null +++ b/agent/app/agent/core/executor.py @@ -0,0 +1,163 @@ +""" +Agent 执行器 - 负责执行 Agent 的核心逻辑 +""" +from typing import Any, Optional +from app.llm.factory import LLMFactory +from app.agent.tools.registry import ToolRegistry +from app.agent.memory.session import SessionManager +from app.security.audit import AuditLogger + + +class AgentExecutor: + """Agent 执行器""" + + def __init__( + self, + agent_id: str, + llm_factory: LLMFactory, + tool_registry: ToolRegistry, + session_manager: SessionManager, + audit_logger: AuditLogger, + config: dict + ): + self.agent_id = agent_id + self.llm_factory = llm_factory + self.tool_registry = tool_registry + self.session_manager = session_manager + self.audit_logger = audit_logger + self.config = config + + # 获取 LLM + self.llm = self.llm_factory.get_llm() + + async def run( + self, + message: str, + session_id: str, + context: dict + ) -> dict[str, Any]: + """运行 Agent""" + tools_used = [] + + # 1. 获取会话历史 + history = self.session_manager.get_history(session_id) + + # 2. 构建消息列表 + messages = self._build_messages(message, history) + + # 3. 获取可用工具 + available_tools = self._get_available_tools() + + # 4. 调用 LLM(带工具) + try: + response = await self.llm.agenerate( + messages=messages, + tools=available_tools + ) + + # 检查是否需要调用工具 + response_message = response.generations[0][0] + + # 如果有工具调用 + if hasattr(response_message, "tool_calls") and response_message.tool_calls: + for tool_call in response_message.tool_calls: + tool_name = tool_call.name + tool_args = tool_call.arguments + + # 记录工具使用 + tools_used.append(tool_name) + + # 执行工具 + tool_result = await self._execute_tool(tool_name, tool_args) + + # 添加工具结果到消息 + messages.append(response_message) + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": str(tool_result) + }) + + # 再次调用 LLM 生成最终响应 + final_response = await self.llm.agenerate(messages=messages) + final_message = final_response.generations[0][0].text + + # 保存到历史 + self.session_manager.add_message(session_id, "user", message) + self.session_manager.add_message(session_id, "assistant", final_message) + + return { + "reply": final_message, + "tools_used": tools_used, + "metadata": {} + } + + # 没有工具调用,直接返回 + reply = response_message.text + + # 保存到历史 + self.session_manager.add_message(session_id, "user", message) + self.session_manager.add_message(session_id, "assistant", reply) + + return { + "reply": reply, + "tools_used": tools_used, + "metadata": {} + } + + except Exception as e: + # 记录错误 + self.audit_logger.log( + action="agent_error", + agent_id=self.agent_id, + session_id=session_id, + details={"error": str(e)} + ) + raise + + def _build_messages(self, message: str, history: list) -> list: + """构建消息列表""" + messages = [] + + # 添加系统提示 + system_prompt = self.config.get("system_prompt", "You are a helpful assistant.") + messages.append({"role": "system", "content": system_prompt}) + + # 添加历史 + for msg in history: + messages.append(msg) + + # 添加当前消息 + messages.append({"role": "user", "content": message}) + + return messages + + def _get_available_tools(self) -> list: + """获取可用工具定义""" + agent_tools = self.config.get("tools", []) + tool_defs = [] + + for tool_name in agent_tools: + tool_def = self.tool_registry.get_tool_definition(tool_name) + if tool_def: + tool_defs.append(tool_def) + + return tool_defs + + async def _execute_tool(self, tool_name: str, args: dict) -> Any: + """执行工具""" + # 安全检查 + tool_func, metadata = self.tool_registry.get_tool(tool_name) + + # 如果需要审批,抛出异常 + if metadata.require_approval: + raise PermissionError( + f"Tool '{tool_name}' requires approval before execution" + ) + + # 执行工具 + try: + result = tool_func(**args) + return result + except Exception as e: + return {"error": str(e)} diff --git a/agent/app/agent/memory/session.py b/agent/app/agent/memory/session.py new file mode 100644 index 0000000..62f26a8 --- /dev/null +++ b/agent/app/agent/memory/session.py @@ -0,0 +1,62 @@ +""" +会话管理器 - 管理 Agent 的会话历史 +""" +from typing import Any, Optional +from collections import defaultdict +from datetime import datetime + + +class SessionManager: + """会话管理器""" + + def __init__(self, max_history: int = 10): + """ + 初始化会话管理器 + + Args: + max_history: 每个会话保留的最大历史消息数 + """ + self.max_history = max_history + self.sessions: dict[str, list[dict]] = defaultdict(list) + self.metadata: dict[str, dict] = {} + + def add_message(self, session_id: str, role: str, content: str): + """添加消息到会话""" + self.sessions[session_id].append({ + "role": role, + "content": content, + "timestamp": datetime.now().isoformat() + }) + + # 限制历史长度 + if len(self.sessions[session_id]) > self.max_history: + self.sessions[session_id] = self.sessions[session_id][-self.max_history:] + + def get_history(self, session_id: str) -> list[dict]: + """获取会话历史""" + return self.sessions.get(session_id, []) + + def clear_session(self, session_id: str): + """清除会话""" + if session_id in self.sessions: + del self.sessions[session_id] + if session_id in self.metadata: + del self.metadata[session_id] + + def set_metadata(self, session_id: str, key: str, value: Any): + """设置会话元数据""" + if session_id not in self.metadata: + self.metadata[session_id] = {} + self.metadata[session_id][key] = value + + def get_metadata(self, session_id: str, key: str, default: Any = None) -> Any: + """获取会话元数据""" + return self.metadata.get(session_id, {}).get(key, default) + + def list_sessions(self) -> list[str]: + """列出所有会话ID""" + return list(self.sessions.keys()) + + def get_session_count(self) -> int: + """获取会话数量""" + return len(self.sessions) diff --git a/agent/app/agent/tools/impl/__init__.py b/agent/app/agent/tools/impl/__init__.py new file mode 100644 index 0000000..94ea043 --- /dev/null +++ b/agent/app/agent/tools/impl/__init__.py @@ -0,0 +1,22 @@ +""" +工具实现模块 +""" + +# 基础工具 +from . import search +from . import calculator +from . import time_tool + +# 安全工具 +from . import sandbox +from . import database +from . import api_client + +__all__ = [ + "search", + "calculator", + "time_tool", + "sandbox", + "database", + "api_client", +] diff --git a/agent/app/agent/tools/impl/api_client.py b/agent/app/agent/tools/impl/api_client.py new file mode 100644 index 0000000..d7ec6c0 --- /dev/null +++ b/agent/app/agent/tools/impl/api_client.py @@ -0,0 +1,166 @@ +""" +API 调用工具 - 安全的外部 API 调用 +""" +import httpx +from typing import Dict, Any, Optional +from dataclasses import dataclass +from enum import Enum + + +class APIPermission(Enum): + """API 权限级别""" + PUBLIC = "public" # 公开 API + APPROVED = "approved" # 已审批的 API + ADMIN = "admin" # 管理员 API + + +@dataclass +class APIEndpoint: + """API 端点定义""" + name: str + url: str + method: str + permission: APIPermission + description: str + rate_limit: int = 60 # 每分钟请求次数 + + +# API 白名单 +ALLOWED_APIS = [ + APIEndpoint( + name="weather", + url="https://api.weather.example.com/v1", + method="GET", + permission=APIPermission.PUBLIC, + description="获取天气信息", + rate_limit=30 + ), + APIEndpoint( + name="news", + url="https://newsapi.org/v2", + method="GET", + permission=APIPermission.PUBLIC, + description="获取新闻", + rate_limit=30 + ), + # 可以添加更多已审批的 API +] + + +class APICallTool: + """ + API 调用工具 + + 安全特性: + - 只允许调用白名单中的 API + - 速率限制 + - 请求超时 + - 响应大小限制 + """ + + def __init__(self): + self.allowed_apis = {api.name: api for api in ALLOWED_APIS} + self.request_timeout = 10 # 请求超时(秒) + self.max_response_size = 1024 * 1024 # 最大响应大小(1MB) + + async def call( + self, + api_name: str, + endpoint: str = "", + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + """ + 调用 API + + Args: + api_name: API 名称(必须在白名单中) + endpoint: 具体的端点 + params: 查询参数 + headers: 请求头 + + Returns: + API 响应 + """ + # 安全检查1: API 必须在白名单中 + if api_name not in self.allowed_apis: + return { + "success": False, + "error": f"API '{api_name}' not in whitelist. Allowed: {list(self.allowed_apis.keys())}" + } + + api = self.allowed_apis[api_name] + + # 构建完整 URL + url = f"{api.url}/{endpoint}" if endpoint else api.url + + try: + async with httpx.AsyncClient(timeout=self.request_timeout) as client: + # 根据方法调用 + if api.method == "GET": + response = await client.get(url, params=params, headers=headers) + elif api.method == "POST": + response = await client.post(url, json=params, headers=headers) + else: + return { + "success": False, + "error": f"Method {api.method} not supported" + } + + # 检查响应大小 + if len(response.content) > self.max_response_size: + return { + "success": False, + "error": f"Response too large (max {self.max_response_size} bytes)" + } + + return { + "success": True, + "status_code": response.status_code, + "data": response.json() if response.headers.get("content-type", "").startswith("application/json") else response.text, + "headers": dict(response.headers) + } + + except httpx.TimeoutException: + return { + "success": False, + "error": "Request timeout" + } + except Exception as e: + return { + "success": False, + "error": str(e) + } + + def list_apis(self) -> list: + """列出所有可用的 API""" + return [ + { + "name": api.name, + "description": api.description, + "method": api.method, + "permission": api.permission.value, + "rate_limit": api.rate_limit + } + for api in ALLOWED_APIS + ] + + +# 全局实例 +api_tool = APICallTool() + + +async def call_api( + api_name: str, + endpoint: str = "", + params: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: + """ + API 调用工具(供 Agent 使用) + """ + return await api_tool.call(api_name, endpoint, params) + + +def list_allowed_apis() -> list: + """列出允许的 API""" + return api_tool.list_apis() diff --git a/agent/app/agent/tools/impl/calculator.py b/agent/app/agent/tools/impl/calculator.py new file mode 100644 index 0000000..9b57b67 --- /dev/null +++ b/agent/app/agent/tools/impl/calculator.py @@ -0,0 +1,91 @@ +""" +计算器工具 +""" +import ast +import operator +from typing import Any + + +# 安全运算符 +SAFE_OPERATORS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Pow: operator.pow, + ast.Mod: operator.mod, + ast.USub: operator.neg, +} + + +def safe_eval_expr(node): + """安全地求值表达式节点""" + if isinstance(node, ast.Num): + return node.n + elif isinstance(node, ast.BinOp): + left = safe_eval_expr(node.left) + right = safe_eval_expr(node.right) + op_type = type(node.op) + if op_type in SAFE_OPERATORS: + return SAFE_OPERATORS[op_type](left, right) + raise ValueError(f"Unsupported operator: {op_type}") + elif isinstance(node, ast.UnaryOp): + operand = safe_eval_expr(node.operand) + op_type = type(node.op) + if op_type in SAFE_OPERATORS: + return SAFE_OPERATORS[op_type](operand) + raise ValueError(f"Unsupported unary operator: {op_type}") + else: + raise ValueError(f"Unsupported expression: {ast.dump(node)}") + + +def calculate(expression: str) -> dict: + """ + 执行数学计算 + + Args: + expression: 数学表达式,如 "2 + 2" 或 "sqrt(16)" + + Returns: + 计算结果 + """ + try: + # 预处理:处理常见数学函数 + expression = expression.replace("sqrt", "**0.5") + expression = expression.replace("pi", "3.14159265359") + expression = expression.replace("e", "2.71828182846") + + # 解析表达式 + tree = ast.parse(expression, mode='eval') + result = safe_eval_expr(tree.body) + + return { + "success": True, + "expression": expression, + "result": result, + "type": type(result).__name__ + } + + except Exception as e: + return { + "success": False, + "expression": expression, + "error": str(e) + } + + +# 工具定义 +TOOL_DEFINITION = { + "name": "calculator", + "description": "Perform mathematical calculations. Supports basic arithmetic (+, -, *, /), powers (**), and functions (sqrt).", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate, e.g., '2 + 2' or 'sqrt(16) + 5'" + } + }, + "required": ["expression"] + } +} diff --git a/agent/app/agent/tools/impl/database.py b/agent/app/agent/tools/impl/database.py new file mode 100644 index 0000000..2f2a4ca --- /dev/null +++ b/agent/app/agent/tools/impl/database.py @@ -0,0 +1,96 @@ +""" +数据库查询工具 - 安全的数据查询接口 +""" +from typing import Dict, Any, List, Optional +import os + + +# 只读查询白名单 - 只允许 SELECT 语句 +ALLOWED_TABLES = ["users", "agents", "sessions", "audit_logs"] + + +class DatabaseQueryTool: + """ + 数据库查询工具 + + 安全特性: + - 只允许 SELECT 查询 + - 表名白名单 + - 结果数量限制 + """ + + def __init__(self, connection_string: str = ""): + self.connection_string = connection_string or os.getenv( + "DATABASE_URL", + "postgresql://postgres:postgres@localhost:5432/x_agents" + ) + self.max_rows = 100 # 最多返回100行 + + def query(self, sql: str, params: List[Any] = None) -> Dict[str, Any]: + """ + 执行查询 + + Args: + sql: SQL 查询语句(必须是 SELECT) + params: 查询参数 + + Returns: + 查询结果 + """ + # 安全检查1: 必须是 SELECT 语句 + sql_upper = sql.strip().upper() + if not sql_upper.startswith("SELECT"): + return { + "success": False, + "error": "Only SELECT queries are allowed" + } + + # 安全检查2: 禁止危险关键字 + dangerous_keywords = [ + "DROP", "DELETE", "INSERT", "UPDATE", "ALTER", + "CREATE", "TRUNCATE", "EXEC", "EXECUTE" + ] + for keyword in dangerous_keywords: + if keyword in sql_upper: + return { + "success": False, + "error": f"Keyword '{keyword}' is not allowed" + } + + # 安全检查3: 表名白名单 + for table in ALLOWED_TABLES: + if f"FROM {table}" in sql_upper or f"JOIN {table}" in sql_upper: + # 表名在白名单中,允许 + break + else: + # 没有找到白名单表 + return { + "success": False, + "error": f"Table not in whitelist. Allowed: {ALLOWED_TABLES}" + } + + # TODO: 实际执行查询(需要数据库连接) + # 这里返回模拟数据 + return { + "success": True, + "message": "Query executed (mock mode - database not connected)", + "rows": [], + "columns": [] + } + + +# 全局实例 +db_tool = DatabaseQueryTool() + + +def query_data(sql: str) -> Dict[str, Any]: + """ + 查询数据工具 + + Args: + sql: SELECT 查询语句 + + Returns: + 查询结果 + """ + return db_tool.query(sql) diff --git a/agent/app/agent/tools/impl/search.py b/agent/app/agent/tools/impl/search.py new file mode 100644 index 0000000..598daba --- /dev/null +++ b/agent/app/agent/tools/impl/search.py @@ -0,0 +1,87 @@ +""" +网页搜索工具 +""" +import httpx +from typing import Optional + + +async def search_web(query: str, max_results: int = 5) -> dict: + """ + 搜索网页获取信息 + + Args: + query: 搜索关键词 + max_results: 返回结果数量 + + Returns: + 搜索结果 + """ + # 这里可以使用搜索引擎API,如 Google, Bing, DuckDuckGo 等 + # 示例使用 DuckDuckGo API(免费) + + try: + async with httpx.AsyncClient() as client: + response = await client.get( + "https://api.duckduckgo.com/", + params={ + "q": query, + "format": "json", + "no_html": 1, + "skip_disambig": 1 + }, + timeout=10.0 + ) + + if response.status_code == 200: + data = response.json() + results = [] + + # 提取相关主题 + if "RelatedTopics" in data: + for item in data["RelatedTopics"][:max_results]: + if "Text" in item: + results.append({ + "title": item.get("Text", "").split(" - ")[0] if " - " in item.get("Text", "") else "", + "content": item.get("Text", ""), + "url": item.get("URL", "") + }) + + return { + "success": True, + "query": query, + "results": results, + "count": len(results) + } + else: + return { + "success": False, + "error": f"Search API returned status {response.status_code}" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } + + +# 工具定义(用于 LLM) +TOOL_DEFINITION = { + "name": "search", + "description": "Search the web for information. Use this when you need to find current information or facts.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query" + }, + "max_results": { + "type": "integer", + "description": "Maximum number of results to return", + "default": 5 + } + }, + "required": ["query"] + } +} diff --git a/agent/app/agent/tools/impl/time_tool.py b/agent/app/agent/tools/impl/time_tool.py new file mode 100644 index 0000000..b6671f4 --- /dev/null +++ b/agent/app/agent/tools/impl/time_tool.py @@ -0,0 +1,70 @@ +""" +时间工具 +""" +from datetime import datetime +from typing import Optional + + +def get_current_time(timezone: Optional[str] = None) -> dict: + """ + 获取当前时间 + + Args: + timezone: 时区名称,如 "UTC", "Asia/Shanghai" + + Returns: + 当前时间信息 + """ + now = datetime.now() + + return { + "success": True, + "datetime": now.isoformat(), + "timestamp": now.timestamp(), + "date": now.strftime("%Y-%m-%d"), + "time": now.strftime("%H:%M:%S"), + "weekday": now.strftime("%A"), + "timezone": timezone or "Local Time" + } + + +def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> dict: + """ + 格式化时间戳 + + Args: + timestamp: Unix 时间戳 + format_str: 格式字符串 + + Returns: + 格式化后的时间 + """ + try: + dt = datetime.fromtimestamp(timestamp) + return { + "success": True, + "formatted": dt.strftime(format_str), + "datetime": dt.isoformat() + } + except Exception as e: + return { + "success": False, + "error": str(e) + } + + +# 工具定义 +TOOL_DEFINITION = { + "name": "get_current_time", + "description": "Get the current date and time. Useful for timestamps or scheduling.", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')", + "default": "Local" + } + } + } +} diff --git a/agent/app/agent/tools/registry.py b/agent/app/agent/tools/registry.py new file mode 100644 index 0000000..9fbed3f --- /dev/null +++ b/agent/app/agent/tools/registry.py @@ -0,0 +1,99 @@ +""" +工具注册表 - 管理所有可用工具(白名单机制) +""" +from typing import Any, Callable, Optional +from dataclasses import dataclass, asdict +from enum import Enum + + +class SecurityLevel(Enum): + """工具安全等级""" + SAFE = "safe" # 安全操作 + REVIEW = "review" # 需要审核 + DANGER = "danger" # 危险操作 + + +@dataclass +class ToolMetadata: + """工具元数据""" + name: str + description: str + security_level: str + require_approval: bool = False + allowed_roles: list = None + + def dict(self): + return { + "name": self.name, + "description": self.description, + "security_level": self.security_level, + "require_approval": self.require_approval + } + + +class ToolRegistry: + """工具注册表""" + + def __init__(self): + self._tools: dict[str, tuple[Callable, ToolMetadata]] = {} + self._definitions: dict[str, dict] = {} + + def register( + self, + name: str, + func: Callable, + description: str = "", + security_level: str = "safe", + require_approval: bool = False, + allowed_roles: list = None, + parameters: dict = None + ): + """注册工具到白名单""" + metadata = ToolMetadata( + name=name, + description=description, + security_level=security_level, + require_approval=require_approval, + allowed_roles=allowed_roles or ["user", "admin"] + ) + + self._tools[name] = (func, metadata) + + # 生成工具定义(用于 LLM 调用) + self._definitions[name] = { + "name": name, + "description": description, + "parameters": parameters or { + "type": "object", + "properties": {}, + "required": [] + } + } + + def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]: + """获取工具函数和元数据""" + if name not in self._tools: + raise ValueError(f"Tool '{name}' not found in whitelist") + return self._tools[name] + + def get_tool_definition(self, name: str) -> Optional[dict]: + """获取工具定义(用于 LLM)""" + return self._definitions.get(name) + + def list_tools(self) -> list[ToolMetadata]: + """列出所有已注册工具""" + return [meta for _, meta in self._tools.values()] + + def check_permission(self, tool_name: str, user_role: str) -> bool: + """检查用户权限""" + if tool_name not in self._tools: + return False + _, metadata = self._tools[tool_name] + return user_role in metadata.allowed_roles + + def need_approval(self, tool_name: str) -> bool: + """判断是否需要审批""" + if tool_name not in self._tools: + return False + _, metadata = self._tools[tool_name] + return metadata.require_approval diff --git a/agent/app/agent/tools/sandbox/sandbox.py b/agent/app/agent/tools/sandbox/sandbox.py new file mode 100644 index 0000000..764d3ea --- /dev/null +++ b/agent/app/agent/tools/sandbox/sandbox.py @@ -0,0 +1,283 @@ +""" +沙盒执行环境 - 在项目内构建,不依赖 Docker +提供安全的代码执行环境 +""" +import subprocess +import tempfile +import os +import shutil +import resource +import signal +import threading +from typing import Dict, Any, Optional +from pathlib import Path + + +class SandboxConfig: + """沙盒配置""" + # 资源限制 + MAX_MEMORY_MB = 256 # 最大内存 (MB) + MAX_CPU_PERCENT = 50 # 最大 CPU 百分比 + MAX_EXECUTION_TIME = 30 # 最大执行时间 (秒) + MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小 (bytes) + + +class Sandbox: + """ + 沙盒执行器 - 使用 subprocess 隔离执行 + + 安全特性: + - 内存限制 + - CPU限制 + - 超时控制 + - 网络隔离(可选) + - 临时文件隔离 + """ + + def __init__(self, config: Optional[SandboxConfig] = None): + self.config = config or SandboxConfig() + self.temp_dir = None + + def _setup_temp_dir(self) -> str: + """创建临时目录""" + self.temp_dir = tempfile.mkdtemp(prefix="sandbox_") + return self.temp_dir + + def _cleanup(self): + """清理临时目录""" + if self.temp_dir and os.path.exists(self.temp_dir): + try: + shutil.rmtree(self.temp_dir) + except Exception as e: + print(f"Cleanup error: {e}") + + def execute( + self, + code: str, + language: str = "python", + timeout: Optional[int] = None + ) -> Dict[str, Any]: + """ + 在沙盒中执行代码 + + Args: + code: 要执行的代码 + language: 语言类型 (python, javascript) + timeout: 超时时间(秒) + + Returns: + 执行结果 + """ + timeout = timeout or self.config.MAX_EXECUTION_TIME + + self._setup_temp_dir() + + try: + if language == "python": + return self._execute_python(code, timeout) + elif language == "javascript": + return self._execute_javascript(code, timeout) + else: + return { + "success": False, + "error": f"Unsupported language: {language}" + } + except Exception as e: + return { + "success": False, + "error": str(e) + } + finally: + self._cleanup() + + def _execute_python(self, code: str, timeout: int) -> Dict[str, Any]: + """执行 Python 代码""" + # 创建临时文件 + temp_file = os.path.join(self.temp_dir, "code.py") + with open(temp_file, "w", encoding="utf-8") as f: + f.write(code) + + try: + # 构建命令 + cmd = ["python", temp_file] + + # 执行 + result = subprocess.run( + cmd, + capture_output=True, + timeout=timeout, + cwd=self.temp_dir, # 限制工作目录 + env=self._get_restricted_env(), # 限制环境变量 + ) + + # 检查输出大小 + stdout = result.stdout.decode("utf-8", errors="replace") + stderr = result.stderr.decode("utf-8", errors="replace") + + if len(stdout) > self.config.MAX_OUTPUT_SIZE: + stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)" + + return { + "success": result.returncode == 0, + "output": stdout, + "error": stderr if result.returncode != 0 else None, + "exit_code": result.returncode + } + + except subprocess.TimeoutExpired: + return { + "success": False, + "error": f"Execution timeout ({timeout}s)", + "output": None + } + + def _execute_javascript(self, code: str, timeout: int) -> Dict[str, Any]: + """执行 JavaScript 代码""" + temp_file = os.path.join(self.temp_dir, "code.js") + with open(temp_file, "w", encoding="utf-8") as f: + f.write(code) + + try: + # 尝试使用 node + cmd = ["node", temp_file] + + result = subprocess.run( + cmd, + capture_output=True, + timeout=timeout, + cwd=self.temp_dir, + ) + + stdout = result.stdout.decode("utf-8", errors="replace") + stderr = result.stderr.decode("utf-8", errors="replace") + + return { + "success": result.returncode == 0, + "output": stdout, + "error": stderr if result.returncode != 0 else None, + "exit_code": result.returncode + } + + except subprocess.TimeoutExpired: + return { + "success": False, + "error": f"Execution timeout ({timeout}s)", + "output": None + } + except FileNotFoundError: + return { + "success": False, + "error": "Node.js not installed", + "output": None + } + + def _get_restricted_env(self) -> Dict[str, str]: + """ + 获取受限的环境变量 + 移除敏感变量,保留必要的 PATH + """ + # 保留 PATH,移除其他敏感变量 + safe_env = { + "PATH": os.environ.get("PATH", "/usr/bin:/bin"), + "LANG": "en_US.UTF-8", + "HOME": self.temp_dir, + "TMPDIR": self.temp_dir, + } + + # 移除可能不安全的变量 + unsafe_vars = [ + "PYTHONPATH", + "PYTHONHOME", + "LD_PRELOAD", + "LD_LIBRARY_PATH", + ] + + for var in unsafe_vars: + if var in os.environ: + del os.environ[var] + + return safe_env + + +class SafeEval: + """ + 安全求值器 - 用于简单表达式计算 + 比沙盒更轻量,适用于不需要完全隔离的场景 + """ + + # 安全函数白名单 + SAFE_BUILTINS = { + "abs": abs, + "min": min, + "max": max, + "sum": sum, + "len": len, + "round": round, + "pow": pow, + "print": print, + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, + "tuple": tuple, + "set": set, + "range": range, + "enumerate": enumerate, + "zip": zip, + "map": map, + "filter": filter, + "sorted": sorted, + "reversed": reversed, + } + + # 安全数学常量 + SAFE_MATH = { + "pi": 3.14159265359, + "e": 2.71828182846, + } + + @classmethod + def eval(cls, expression: str) -> Any: + """ + 安全地求值表达式 + + Args: + expression: 数学表达式 + + Returns: + 计算结果 + """ + # 预处理表达式 + expression = expression.replace("sqrt", "**0.5") + + # 构建安全命名空间 + safe_namespace = { + **cls.SAFE_BUILTINS, + **cls.SAFE_MATH, + "__builtins__": {} # 禁用__builtins__ + } + + try: + result = eval(expression, safe_namespace) + return result + except Exception as e: + raise ValueError(f"Evaluation error: {e}") + + +# 全局沙盒实例 +sandbox = Sandbox() + + +# 装饰器:快速将函数封装为沙盒执行 +def sandboxed(timeout: int = 30): + """装饰器:为函数添加沙盒执行能力""" + def decorator(func): + def wrapper(code: str, *args, **kwargs): + result = sandbox.execute(code, timeout=timeout) + if not result["success"]: + raise RuntimeError(result.get("error", "Execution failed")) + return result["output"] + return wrapper + return decorator diff --git a/agent/app/api/routes.py b/agent/app/api/routes.py new file mode 100644 index 0000000..9dccef7 --- /dev/null +++ b/agent/app/api/routes.py @@ -0,0 +1,149 @@ +""" +API 路由定义 +""" +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + +from app.agent.core.agent import AgentManager +from app.security.approval import ApprovalService + + +router = APIRouter() + +# 全局依赖(实际应该注入) +_agent_manager: Optional[AgentManager] = None +_approval_service: Optional[ApprovalService] = None + + +def get_agent_manager() -> AgentManager: + """获取 Agent 管理器""" + # 这里应该从 app.state 获取 + from app.main import agent_manager + if agent_manager is None: + raise HTTPException(status_code=503, detail="Agent service not initialized") + return agent_manager + + +def get_approval_service() -> ApprovalService: + """获取审批服务""" + global _approval_service + if _approval_service is None: + _approval_service = ApprovalService() + return _approval_service + + +# ==================== 请求/响应模型 ==================== + +class ChatRequest(BaseModel): + """聊天请求""" + agent_id: str + message: str + session_id: str = "" + context: dict = {} + + +class ChatResponse(BaseModel): + """聊天响应""" + reply: str + session_id: str + tools_used: list[str] = [] + metadata: dict = {} + + +class ApprovalRequest(BaseModel): + """审批请求""" + request_id: str + tool_name: str + params: dict + reason: str + approved: bool + + +# ==================== API 端点 ==================== + +@router.post("/chat", response_model=ChatResponse) +async def chat( + request: ChatRequest, + agent_manager: AgentManager = Depends(get_agent_manager) +): + """处理 Agent 聊天请求""" + try: + # 生成会话ID + if not request.session_id: + import uuid + request.session_id = str(uuid.uuid4()) + + # 执行 Agent + result = await agent_manager.execute( + agent_id=request.agent_id, + message=request.message, + session_id=request.session_id, + context=request.context + ) + + return ChatResponse( + reply=result.get("reply", ""), + session_id=request.session_id, + tools_used=result.get("tools_used", []), + metadata=result.get("metadata", {}) + ) + + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Agent execution failed: {str(e)}") + + +@router.post("/tool/request") +async def request_tool_execution( + request: dict, + approval_service: ApprovalService = Depends(get_approval_service) +): + """请求执行工具(需要审批)""" + tool_name = request.get("tool_name") + params = request.get("params", {}) + user_id = request.get("user_id", "unknown") + agent_id = request.get("agent_id") + reason = request.get("reason", "") + + # 创建审批请求 + request_id = await approval_service.request_approval( + tool_name=tool_name, + params=params, + user_id=user_id, + agent_id=agent_id or "", + reason=reason + ) + + return { + "request_id": request_id, + "status": "pending" + } + + +@router.get("/tools") +async def list_tools(agent_manager: AgentManager = Depends(get_agent_manager)): + """列出所有可用工具""" + tools = agent_manager.list_tools() + return {"tools": [tool.dict() for tool in tools]} + + +@router.get("/agents") +async def list_agents(agent_manager: AgentManager = Depends(get_agent_manager)): + """列出所有已加载的 Agent""" + agents = agent_manager.list_agents() + return {"agents": agents} + + +@router.get("/agent/{agent_id}") +async def get_agent( + agent_id: str, + agent_manager: AgentManager = Depends(get_agent_manager) +): + """获取特定 Agent 信息""" + agent_info = agent_manager.get_agent_info(agent_id) + if not agent_info: + raise HTTPException(status_code=404, detail="Agent not found") + return agent_info diff --git a/agent/app/llm/factory.py b/agent/app/llm/factory.py new file mode 100644 index 0000000..d47989d --- /dev/null +++ b/agent/app/llm/factory.py @@ -0,0 +1,63 @@ +""" +LLM 工厂 - 创建不同提供商的 LLM 实例 +""" +from typing import Optional +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic + + +class LLMFactory: + """LLM 工厂类""" + + def __init__( + self, + provider: str = "openai", + openai_api_key: Optional[str] = None, + anthropic_api_key: Optional[str] = None, + model: str = "gpt-3.5-turbo", + temperature: float = 0.7, + max_tokens: int = 2000 + ): + self.provider = provider + self.openai_api_key = openai_api_key + self.anthropic_api_key = anthropic_api_key + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + + self._llm = None + + def get_llm(self): + """获取 LLM 实例""" + if self._llm is not None: + return self._llm + + if self.provider == "openai": + self._llm = ChatOpenAI( + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + api_key=self.openai_api_key + ) + elif self.provider == "anthropic": + self._llm = ChatAnthropic( + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + anthropic_api_key=self.anthropic_api_key + ) + else: + raise ValueError(f"Unsupported provider: {self.provider}") + + return self._llm + + def set_model(self, model: str): + """设置模型""" + self.model = model + self._llm = None # 重置 LLM 实例 + + def set_temperature(self, temperature: float): + """设置温度""" + self.temperature = temperature + if self._llm: + self._llm.temperature = temperature diff --git a/agent/app/main.py b/agent/app/main.py new file mode 100644 index 0000000..2c6f0ce --- /dev/null +++ b/agent/app/main.py @@ -0,0 +1,84 @@ +""" +X-Agents Python Agent Service +智能体引擎服务入口 +""" +import os +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from app.api import routes +from app.agent.core.agent import AgentManager +from app.security.audit import AuditLogger + + +# 全局组件 +agent_manager: AgentManager = None +audit_logger: AuditLogger = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + global agent_manager, audit_logger + + # 启动时初始化 + audit_logger = AuditLogger() + + # 初始化 Agent 管理器 + agent_manager = AgentManager( + llm_provider=os.getenv("LLM_PROVIDER", "openai"), + openai_api_key=os.getenv("OPENAI_API_KEY"), + anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"), + ) + + # 加载 Agent 配置 + await agent_manager.load_agents() + + print("Agent service started successfully") + + yield + + # 关闭时清理 + print("Agent service shutting down") + + +# 创建 FastAPI 应用 +app = FastAPI( + title="X-Agents Agent Service", + description="AI Agent Engine for X-Agents Platform", + version="1.0.0", + lifespan=lifespan, +) + +# CORS 中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 注册路由 +app.include_router(routes.router, prefix="/agent", tags=["Agent"]) + + +@app.get("/health") +async def health_check(): + """健康检查""" + return { + "status": "healthy", + "service": "agent", + "version": "1.0.0" + } + + +@app.get("/") +async def root(): + """根路径""" + return { + "message": "X-Agents Agent Service", + "docs": "/docs" + } diff --git a/agent/app/security/approval.py b/agent/app/security/approval.py new file mode 100644 index 0000000..ad6ee23 --- /dev/null +++ b/agent/app/security/approval.py @@ -0,0 +1,104 @@ +""" +审批服务 - 处理工具执行的审批流程 +""" +import uuid +from datetime import datetime +from typing import Any, Dict, Optional +from enum import Enum + + +class ApprovalStatus(Enum): + """审批状态""" + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + + +class ApprovalService: + """审批服务""" + + def __init__(self): + # 待审批队列 + self.pending: Dict[str, dict] = {} + # 审批结果 + self.results: Dict[str, ApprovalStatus] = {} + + async def request_approval( + self, + tool_name: str, + params: dict, + user_id: str, + agent_id: str, + reason: str + ) -> str: + """ + 请求审批 + + Returns: + request_id: 审批请求ID + """ + request_id = str(uuid.uuid4()) + + request = { + "request_id": request_id, + "tool_name": tool_name, + "params": params, + "user_id": user_id, + "agent_id": agent_id, + "reason": reason, + "status": ApprovalStatus.PENDING, + "created_at": datetime.now().isoformat() + } + + self.pending[request_id] = request + self.results[request_id] = ApprovalStatus.PENDING + + # TODO: 通知 Go 后端有新审批 + + return request_id + + async def check_approval(self, request_id: str, timeout: int = 300) -> bool: + """ + 检查审批状态 + + Args: + request_id: 审批请求ID + timeout: 超时时间(秒) + + Returns: + 是否已批准 + """ + import asyncio + + start = datetime.now() + + while (datetime.now() - start).seconds < timeout: + status = self.results.get(request_id) + + if status == ApprovalStatus.APPROVED: + return True + elif status == ApprovalStatus.REJECTED: + return False + + await asyncio.sleep(1) + + raise TimeoutError("Approval request timeout") + + async def approve(self, request_id: str): + """批准请求""" + if request_id in self.pending: + self.pending[request_id]["status"] = ApprovalStatus.APPROVED + self.results[request_id] = ApprovalStatus.APPROVED + + async def reject(self, request_id: str): + """拒绝请求""" + if request_id in self.pending: + self.pending[request_id]["status"] = ApprovalStatus.REJECTED + self.results[request_id] = ApprovalStatus.REJECTED + + def get_pending(self) -> list[dict]: + """获取待审批列表""" + return [ + req for req in self.pending.values() + if req["status"] == ApprovalStatus.PENDING + ] diff --git a/agent/app/security/audit.py b/agent/app/security/audit.py new file mode 100644 index 0000000..5392942 --- /dev/null +++ b/agent/app/security/audit.py @@ -0,0 +1,81 @@ +""" +审计日志 - 记录所有 Agent 操作 +""" +import json +from datetime import datetime +from typing import Any, Dict, Optional +from pathlib import Path + + +class AuditLogger: + """审计日志记录器""" + + def __init__(self, log_file: str = "audit.log"): + self.log_file = log_file + + def log( + self, + action: str, + agent_id: str = "", + session_id: str = "", + user_id: str = "", + details: Dict[str, Any] = None, + result: str = "success" + ): + """记录审计日志""" + entry = { + "timestamp": datetime.now().isoformat(), + "action": action, + "agent_id": agent_id, + "session_id": session_id, + "user_id": user_id, + "details": details or {}, + "result": result + } + + # 写入文件 + self._write_log(entry) + + # TODO: 发送到 Go 后端 + + def log_tool_execution( + self, + tool_name: str, + params: Dict[str, Any], + user_id: str, + agent_id: str, + approved: bool, + result: Any + ): + """记录工具执行""" + self.log( + action="tool_execution", + agent_id=agent_id, + user_id=user_id, + details={ + "tool_name": tool_name, + "params": params, + "approved": approved, + "result_preview": str(result)[:200] if result else None + }, + result="approved" if approved else "pending_approval" + ) + + def log_error(self, action: str, error: str, **kwargs): + """记录错误""" + self.log( + action=action, + details={"error": error, **kwargs}, + result="error" + ) + + def _write_log(self, entry: dict): + """写入日志文件""" + try: + log_path = Path(self.log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + with open(log_path, "a", encoding="utf-8") as f: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + except Exception as e: + print(f"Failed to write audit log: {e}") diff --git a/agent/requirements.txt b/agent/requirements.txt new file mode 100644 index 0000000..d3054ee --- /dev/null +++ b/agent/requirements.txt @@ -0,0 +1,19 @@ +# 核心依赖 +fastapi>=0.100.0 +uvicorn>=0.20.0 +pydantic>=2.0.0 +httpx>=0.24.0 +aiohttp>=3.8.0 +python-multipart>=0.0.5 + +# LLM 支持 +openai>=1.0.0 +anthropic>=0.18.0 +langchain-core>=0.1.0 +langchain-openai>=0.0.2 + +# 可选:向量数据库 +chromadb>=0.4.0 + +# Redis +redis>=4.5.0 diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml new file mode 100644 index 0000000..340dc1f --- /dev/null +++ b/docker-compose.dev.yml @@ -0,0 +1,33 @@ +services: + # MySQL 数据库 + x-agent-mysql: + image: mysql:8.0 + container_name: x-agents-mysql + environment: + MYSQL_ROOT_PASSWORD: root + MYSQL_DATABASE: x_agents + volumes: + - mysql-data:/var/lib/mysql + ports: + - "6036:3306" + healthcheck: + test: ["CMD", "mysqladmin", "ping", "-h", "localhost"] + interval: 10s + timeout: 5s + retries: 5 + restart: unless-stopped + command: --default-authentication-plugin=mysql_native_password + + # Redis + x-agent-redis: + image: redis:7-alpine + container_name: x-agents-redis + ports: + - "6037:6379" + volumes: + - redis-data:/data + restart: unless-stopped + +volumes: + mysql-data: + redis-data: diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..5e40a50 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,34 @@ +services: + # 只启动数据库,Go 和 Python 在本地运行 + db: + image: mysql:8.0 + container_name: x-agents-mysql + environment: + MYSQL_ROOT_PASSWORD: root + MYSQL_DATABASE: x_agents + MYSQL_USER: xagents + MYSQL_PASSWORD: xagents123 + volumes: + - mysql-data:/var/lib/mysql + ports: + - "3306:3306" + healthcheck: + test: ["CMD", "mysqladmin", "ping", "-h", "localhost"] + interval: 10s + timeout: 5s + retries: 5 + restart: unless-stopped + command: --default-authentication-plugin=mysql_native_password + + redis: + image: redis:7-alpine + container_name: x-agents-redis + ports: + - "6379:6379" + volumes: + - redis-data:/data + restart: unless-stopped + +volumes: + mysql-data: + redis-data: diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..7ae26a7 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,835 @@ +# X-Agents 智能体平台架构设计 + +## 一、整体架构 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 用户层 │ +│ Web / App / API Consumer │ +└─────────────────────────────────┬───────────────────────────────────────────┘ + │ HTTP / WebSocket + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Go API Gateway │ +│ ┌─────────────────────────────────────────────────────────────────────────┐│ +│ │ • HTTP Server (Gin) • 认证鉴权 (JWT) ││ +│ │ • 路由管理 • 限流熔断 ││ +│ │ • 业务逻辑 • 日志监控 ││ +│ │ • 数据库操作 • 权限控制 (RBAC) ││ +│ └─────────────────────────────────────────────────────────────────────────┘│ +└─────────────────────────────────┬───────────────────────────────────────────┘ + │ HTTP JSON API + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Python Agent Engine │ +│ ┌─────────────────────────────────────────────────────────────────────────┐│ +│ │ • FastAPI Server • Agent Core (LangChain/AutoGen) ││ +│ │ • LLM Adapter • Tool Registry (白名单) ││ +│ │ • Memory Manager • Sandbox Executor (沙盒) ││ +│ │ • RAG Pipeline • Audit Logger ││ +│ └─────────────────────────────────────────────────────────────────────────┘│ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 二、系统分层 + +### 2.1 Go 后端层 (server/) + +``` +server/ # Go API Gateway 服务 +├── cmd/api/ # 程序入口 +│ └── main.go +├── internal/ +│ ├── config/ # 配置管理 +│ │ └── config.go +│ ├── handler/ # HTTP处理器 +│ │ ├── auth_handler.go # 认证接口 +│ │ ├── chat_handler.go # 聊天接口 +│ │ └── approval_handler.go # 审批接口 +│ ├── service/ # 业务逻辑 +│ │ ├── auth_service.go +│ │ ├── chat_service.go +│ │ └── approval_service.go +│ ├── repository/ # 数据访问层 +│ │ ├── user_repo.go +│ │ ├── agent_repo.go +│ │ └── audit_repo.go +│ ├── middleware/ # 中间件 +│ │ └── auth.go +│ └── model/ # 数据模型 +│ ├── user.go +│ ├── agent.go +│ └── audit.go +├── config/ # 配置文件 +│ └── config.yaml +├── Dockerfile +├── go.mod +└── go.sum +``` + +### 2.2 Python Agent 层 (agent/) + +``` +agent/ # Python Agent Engine +├── app/ +│ ├── main.py # FastAPI入口 +│ ├── api/ +│ │ └── routes.py # API路由 +│ ├── agent/ +│ │ ├── core/ +│ │ │ ├── agent.py # Agent管理器 +│ │ │ └── executor.py # Agent执行器 +│ │ ├── tools/ +│ │ │ ├── registry.py # 工具注册表(白名单) +│ │ │ └── impl/ # 工具实现 +│ │ │ ├── search.py +│ │ │ ├── calculator.py +│ │ │ └── time_tool.py +│ │ └── memory/ +│ │ └── session.py # 会话管理 +│ ├── llm/ +│ │ └── factory.py # LLM工厂 +│ └── security/ +│ ├── audit.py # 审计日志 +│ └── approval.py # 审批服务 +├── requirements.txt +├── Dockerfile +└── pyproject.toml +``` + +### 2.3 根目录结构 + +``` +X-Agents/ +├── server/ # Go API Gateway +├── agent/ # Python Agent Engine +├── web/ # 前端 (Vue.js) +├── docs/ +│ └── ARCHITECTURE.md # 架构文档 +├── docker-compose.yml # 容器编排 +├── .env.example # 环境变量模板 +└── README.md +│ ├── service/ # 业务逻辑 +│ │ ├── chat_service.go +│ │ ├── agent_service.go +│ │ └── approval_service.go # 审批服务 +│ ├── repository/ # 数据访问层 +│ │ ├── user_repo.go +│ │ ├── agent_repo.go +│ │ └── audit_repo.go +│ ├── middleware/ # 中间件 +│ │ ├── auth.go # 认证中间件 +│ │ ├── rbac.go # 权限中间件 +│ │ └── audit.go # 审计中间件 +│ ├── client/ # 外部服务客户端 +│ │ └── python_client.go # Python服务HTTP客户端 +│ └── model/ # 数据模型 +│ ├── user.go +│ ├── agent.go +│ └── audit.go +├── pkg/ +│ ├── utils/ # 工具函数 +│ └── errors/ # 错误定义 +└── go.mod +``` + +### 2.2 Python AI 层 (智能逻辑) + +``` +python/ +├── app/ +│ ├── main.py # FastAPI入口 +│ ├── api/ +│ │ ├── routes.py # 路由定义 +│ │ └── dependencies.py # 依赖注入 +│ ├── agent/ +│ │ ├── core/ +│ │ │ ├── agent.py # Agent核心 +│ │ │ ├── executor.py # 执行器 +│ │ │ └── memory.py # 记忆管理 +│ │ ├── tools/ +│ │ │ ├── registry.py # 工具注册表 +│ │ │ ├── base.py # 工具基类 +│ │ │ ├── security.py # 安全检查 +│ │ │ └── impl/ # 具体工具实现 +│ │ │ ├── search.py +│ │ │ ├── calculator.py +│ │ │ ├── database.py +│ │ │ └── sandbox.py # 沙盒执行 +│ │ └── sandbox/ +│ │ ├── docker_sandbox.py +│ │ └── wasm_sandbox.py +│ ├── llm/ +│ │ ├── factory.py # LLM工厂 +│ │ ├── openai_adapter.py +│ │ ├── anthropic_adapter.py +│ │ └── base.py +│ ├── rag/ +│ │ ├── vector_store.py # 向量存储 +│ │ ├── retriever.py # 检索器 +│ │ └── pipeline.py # RAG流程 +│ └── security/ +│ ├── permission.py # 权限检查 +│ ├── approval.py # 审批管理 +│ └── audit.py # 安全审计 +├── requirements.txt +└── pyproject.toml +``` + +--- + +## 三、通信机制 + +### 3.1 HTTP API 通信 + +``` +┌──────────────────┐ HTTP POST ┌──────────────────┐ +│ │ ─────────────▶ │ │ +│ Go Service │ JSON Request │ Python Service │ +│ (Port: 8080) │ │ (Port: 8081) │ +│ │ ◀───────────── │ │ +└──────────────────┘ JSON Response └──────────────────┘ +``` + +#### 接口设计 + +**1. Agent 聊天接口** + +``` +POST /api/v1/agent/chat +Content-Type: application/json +Authorization: Bearer + +Request: +{ + "agent_id": "agent_001", + "message": "帮我查询用户数据", + "session_id": "session_xxx", + "context": {} // 额外上下文 +} + +Response: +{ + "reply": "查询结果...", + "session_id": "session_xxx", + "tools_used": ["query_database"], + "metadata": {} +} +``` + +**2. 工具执行审批接口** + +``` +POST /api/v1/tool/approve +Request: +{ + "request_id": "req_001", + "tool_name": "execute_sql", + "params": {"sql": "SELECT * FROM users"}, + "reason": "用户查询自己的订单", + "approved": true // true=批准, false=拒绝 +} +``` + +**3. 工具执行状态查询** + +``` +GET /api/v1/tool/request/{request_id} +Response: +{ + "status": "pending|approved|rejected|completed", + "tool_name": "execute_sql", + "created_at": "2024-01-01T00:00:00Z", + "result": null // 如果已完成 +} +``` + +### 3.2 Go → Python 客户端 + +```go +// internal/client/python_client.go + +package client + +type PythonAgentClient struct { + baseURL string + client *http.Client +} + +type ChatRequest struct { + AgentID string `json:"agent_id"` + Message string `json:"message"` + SessionID string `json:"session_id"` + Context map[string]interface{} `json:"context"` +} + +type ChatResponse struct { + Reply string `json:"reply"` + SessionID string `json:"session_id"` + ToolsUsed []string `json:"tools_used"` + Metadata map[string]interface{} `json:"metadata"` +} + +// Chat 调用Python Agent服务 +func (c *PythonAgentClient) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) { + // 1. 构建请求 + // 2. 添加超时 + // 3. 发送请求 + // 4. 处理响应 + // 5. 错误处理 +} +``` + +--- + +## 四、沙盒安全机制 + +### 4.1 安全架构总览 + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ 安全控制层 │ +│ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │ +│ │ 权限管理 │ │ 工具分级 │ │ 人工审批 │ │ 审计日志 │ │ +│ │ (RBAC) │ │ (白名单) │ │ (Approval) │ │ (Audit) │ │ +│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ └────┬─────┘ │ +└─────────┼─────────────────┼─────────────────┼─────────────────┼────────┘ + │ │ │ │ + ▼ ▼ ▼ ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Agent 执行层 │ +│ │ +│ User Request ─▶ Permission Check ─▶ Tool Lookup ─▶ Execute ─▶ Result │ +│ │ │ │ │ +│ ▼ ▼ ▼ │ +│ [Need Approval] ──▶ [Pending Queue] ──▶ [Notify] │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +### 4.2 工具安全等级 + +```python +# python/app/agent/tools/security.py + +from enum import Enum +from dataclasses import dataclass +from typing import List, Callable + +class SecurityLevel(Enum): + """工具安全等级""" + SAFE = "safe" # 安全操作:搜索、计算、读取公开数据 + REVIEW = "review" # 需要审核:修改数据、发送消息 + DANGER = "danger" # 危险操作:删除数据、执行代码、敏感API + +@dataclass +class ToolMetadata: + """工具元数据""" + name: str + description: str + security_level: SecurityLevel + require_approval: bool # 是否需要人工审批 + allowed_roles: List[str] # 允许调用的角色 + rate_limit: int # 调用频率限制 + timeout: int # 超时时间(秒) + +class ToolSecurity: + """工具安全管理""" + + # 安全等级阈值 + APPROVAL_THRESHOLD = SecurityLevel.REVIEW + + @staticmethod + def check_permission(tool: ToolMetadata, user_role: str) -> bool: + """检查用户权限""" + if user_role in tool.allowed_roles: + return True + return False + + @staticmethod + def need_approval(tool: ToolMetadata) -> bool: + """判断是否需要审批""" + return tool.security_level.value >= ToolSecurity.APPROVAL_THRESHOLD.value +``` + +### 4.3 工具注册与执行 + +```python +# python/app/agent/tools/registry.py + +from typing import Dict, Callable, Any +from .security import ToolMetadata, SecurityLevel + +class ToolRegistry: + """工具注册表 - 白名单机制""" + + def __init__(self): + self._tools: Dict[str, tuple[Callable, ToolMetadata]] = {} + + def register( + self, + name: str, + func: Callable, + security_level: SecurityLevel = SecurityLevel.SAFE, + require_approval: bool = False, + allowed_roles: List[str] = None, + description: str = "" + ): + """注册工具到白名单""" + metadata = ToolMetadata( + name=name, + description=description, + security_level=security_level, + require_approval=require_approval or security_level == SecurityLevel.REVIEW, + allowed_roles=allowed_roles or ["user", "admin"], + rate_limit=100, + timeout=30 + ) + self._tools[name] = (func, metadata) + + def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]: + """获取工具(必须在白名单中)""" + if name not in self._tools: + raise ValueError(f"Tool '{name}' not found in whitelist") + return self._tools[name] + + def list_tools(self) -> List[ToolMetadata]: + """列出所有可用工具""" + return [meta for _, meta in self._tools.values()] +``` + +### 4.4 沙盒执行 + +```python +# python/app/agent/tools/sandbox/docker_sandbox.py + +import subprocess +import tempfile +import shutil +import os +from typing import Any, Dict + +class DockerSandbox: + """Docker沙盒执行环境""" + + def __init__(self, image: str = "python-sandbox:latest", timeout: int = 30): + self.image = image + self.timeout = timeout + + def execute(self, code: str, language: str = "python") -> Dict[str, Any]: + """在沙盒中执行代码""" + + # 1. 创建临时文件 + with tempfile.NamedTemporaryFile( + mode='w', + suffix=f'.{language}', + delete=False + ) as f: + f.write(code) + temp_path = f.name + + try: + # 2. Docker容器执行 + result = subprocess.run( + [ + "docker", "run", + "--rm", + "--network", "none", # 断网 + "--memory", "256m", # 内存限制 + "--cpus", "0.5", # CPU限制 + "-v", f"{temp_path}:/code/{os.path.basename(temp_path)}", + self.image, + "python", f"/code/{os.path.basename(temp_path)}" + ], + capture_output=True, + timeout=self.timeout + ) + + return { + "success": result.returncode == 0, + "output": result.stdout.decode(), + "error": result.stderr.decode() + } + + except subprocess.TimeoutExpired: + return { + "success": False, + "output": "", + "error": "Execution timeout" + } + finally: + # 3. 清理临时文件 + os.unlink(temp_path) + +# 使用示例 +@sandbox.execute +def execute_code(code: str) -> str: + """安全执行用户代码""" + pass +``` + +### 4.5 Human in the Loop (人工审批) + +```python +# python/app/security/approval.py + +from enum import Enum +from dataclasses import dataclass +from typing import Optional +from datetime import datetime +import asyncio + +class ApprovalStatus(Enum): + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + +@dataclass +class ApprovalRequest: + """审批请求""" + request_id: str + tool_name: str + params: dict + user_id: str + reason: str + status: ApprovalStatus + created_at: datetime + reviewed_at: Optional[datetime] + reviewed_by: Optional[str] + +class ApprovalService: + """审批服务""" + + def __init__(self, http_client): + self.client = http_client + self.pending: Dict[str, ApprovalRequest] = {} + + async def request_approval( + self, + tool_name: str, + params: dict, + user_id: str, + reason: str + ) -> str: + """请求审批""" + request_id = generate_uuid() + + approval_req = ApprovalRequest( + request_id=request_id, + tool_name=tool_name, + params=params, + user_id=user_id, + reason=reason, + status=ApprovalStatus.PENDING, + created_at=datetime.now(), + reviewed_at=None, + reviewed_by=None + ) + + self.pending[request_id] = approval_req + + # 通知Go后端有新审批 + await self.notify_go_service(approval_req) + + return request_id + + async def wait_for_approval(self, request_id: str, timeout: int = 300) -> bool: + """等待审批结果""" + start = datetime.now() + + while (datetime.now() - start).seconds < timeout: + if request_id in self.pending: + status = self.pending[request_id].status + if status == ApprovalStatus.APPROVED: + return True + elif status == ApprovalStatus.REJECTED: + return False + await asyncio.sleep(1) + + raise TimeoutError("Approval request timeout") +``` + +### 4.6 全链路审计 + +```python +# python/app/security/audit.py + +from datetime import datetime +from typing import Any, Dict +import json + +class AuditLogger: + """审计日志""" + + def __init__(self, log_file: str = "audit.log"): + self.log_file = log_file + + def log( + self, + action: str, + user_id: str, + agent_id: str, + details: Dict[str, Any], + result: str = "success" + ): + """记录审计日志""" + entry = { + "timestamp": datetime.now().isoformat(), + "action": action, + "user_id": user_id, + "agent_id": agent_id, + "details": details, + "result": result + } + + # 写入日志文件 + with open(self.log_file, 'a') as f: + f.write(json.dumps(entry) + '\n') + + # 发送到Go后端 + self.send_to_backend(entry) + + def log_tool_execution( + self, + user_id: str, + tool_name: str, + params: Dict[str, Any], + approved: bool, + result: Any + ): + """记录工具执行""" + self.log( + action="tool_execution", + user_id=user_id, + agent_id="", + details={ + "tool_name": tool_name, + "params": params, + "approved": approved, + "result_preview": str(result)[:100] + } + ) +``` + +--- + +## 五、权限模型 (Go端) + +### 5.1 用户角色 + +```go +// go/internal/model/user.go + +package model + +// 权限级别 +type PermissionLevel int + +const ( + PermissionRead PermissionLevel = 1 // 只读 + PermissionWrite PermissionLevel = 2 // 读写 + PermissionExecute PermissionLevel = 3 // 可执行工具 + PermissionAdmin PermissionLevel = 4 // 管理员 +) + +// 角色定义 +type Role struct { + ID string `json:"id"` + Name string `json:"name"` + Permissions []PermissionLevel `json:"permissions"` +} + +// 用户 +type User struct { + ID string `json:"id"` + Username string `json:"username"` + RoleID string `json:"role_id"` + Role *Role `json:"role,omitempty"` +} +``` + +### 5.2 Agent定义 + +```go +// go/internal/model/agent.go + +package model + +type Agent struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + OwnerID string `json:"owner_id"` + + // Agent能力配置 + Capabilities []string `json:"capabilities"` // 可用工具列表 + MemoryLimit int64 `json:"memory_limit"` // 内存限制 + Timeout int `json:"timeout"` // 超时时间 + + // 安全配置 + SecurityLevel SecurityLevel `json:"security_level"` + AllowDangerousTools bool `json:"allow_dangerous_tools"` + + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} +``` + +--- + +## 六、部署架构 + +### 6.1 Docker Compose + +```yaml +# docker-compose.yml + +version: '3.8' + +services: + # Go API 服务 + go-api: + build: ./go + ports: + - "8080:8080" + environment: + - DATABASE_URL=postgres://user:pass@db:5432/agents + - PYTHON_SERVICE_URL=http://python-agent:8081 + - JWT_SECRET=your-secret + depends_on: + - db + - python-agent + + # Python Agent 服务 + python-agent: + build: ./python + ports: + - "8081:8081" + environment: + - OPENAI_API_KEY=${OPENAI_API_KEY} + - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} + volumes: + - ./python/app:/app + - /var/run/docker.sock:/var/run/docker.sock # 如果需要Docker沙盒 + + # 数据库 + db: + image: postgres:15 + environment: + POSTGRES_USER: user + POSTGRES_PASSWORD: pass + POSTGRES_DB: agents + volumes: + - db-data:/var/lib/postgresql/data + + # Redis (缓存/会话) + redis: + image: redis:7 + volumes: + - redis-data:/data + + # 向量数据库 (可选) + qdrant: + image: qdrant/qdrant + volumes: + - qdrant-data:/qdrant/storage + +volumes: + db-data: + redis-data: + qdrant-data: +``` + +--- + +## 七、开发流程 + +### 7.1 请求流程图 + +``` +┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ +│ 用户 │────▶│ Go │────▶│ Python │────▶│ LLM │────▶│ 返回 │ +│ 请求 │ │ 鉴权 │ │ Agent │ │ +Tools │ │ 结果 │ +└─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘ + │ │ + │ │ + ▼ ▼ + ┌─────────┐ ┌─────────┐ + │ 检查 │ │ 权限 │ + │ 权限 │ │ 检查 │ + └─────────┘ └─────────┘ + │ + ▼ + ┌─────────────────────┐ + │ 工具安全等级判断 │ + └─────────────────────┘ + │ + ┌─────────────────┼─────────────────┐ + │ │ │ + ▼ ▼ ▼ + ┌──────────┐ ┌──────────┐ ┌──────────┐ + │ Safe │ │ Review │ │ Danger │ + │ 直接执行 │ │ 等待审批 │ │ 拒绝执行 │ + └──────────┘ └──────────┘ └──────────┘ +``` + +### 7.2 目录结构总览 + +``` +X-Agents/ +├── docs/ +│ └── ARCHITECTURE.md # 本文档 +│ +├── go/ # Go 后端 +│ ├── cmd/ +│ ├── internal/ +│ ├── pkg/ +│ ├── go.mod +│ └── Dockerfile +│ +├── python/ # Python AI 层 +│ ├── app/ +│ │ ├── api/ +│ │ ├── agent/ +│ │ ├── llm/ +│ │ ├── rag/ +│ │ └── security/ +│ ├── requirements.txt +│ └── Dockerfile +│ +├── web/ # 前端 (Vue) +│ ├── src/ +│ └── package.json +│ +├── docker-compose.yml # 容器编排 +└── README.md +``` + +--- + +## 八、总结 + +### 架构核心原则 + +| 原则 | 实现方式 | +|------|----------| +| **分层治理** | Go负责业务/权限,Python负责AI逻辑 | +| **安全优先** | 工具分级+权限控制+人工审批+审计日志 | +| **通信简洁** | HTTP JSON API,后续可升级gRPC | +| **可扩展** | 模块化设计,支持多Agent/多Python服务 | +| **可观测** | 全链路日志+监控 | + +### 安全特性 + +- [x] 工具白名单机制 +- [x] 安全等级分级 (Safe/Review/Danger) +- [x] RBAC权限控制 +- [x] Human in the Loop 人工审批 +- [x] 沙盒执行环境 (Docker) +- [x] 全链路审计日志 + +--- + +*本文档将随项目开发持续更新* diff --git a/web/agents.html b/docs/agents.html similarity index 100% rename from web/agents.html rename to docs/agents.html diff --git a/web/dashboard.html b/docs/dashboard.html similarity index 100% rename from web/dashboard.html rename to docs/dashboard.html diff --git a/web/graph.html b/docs/graph.html similarity index 100% rename from web/graph.html rename to docs/graph.html diff --git a/server/Dockerfile b/server/Dockerfile new file mode 100644 index 0000000..9f14488 --- /dev/null +++ b/server/Dockerfile @@ -0,0 +1,32 @@ +# 构建阶段 +FROM golang:1.21-alpine AS builder + +WORKDIR /app + +# 安装依赖 +RUN apk add --no-cache git + +# 复制 go.mod 和 go.sum +COPY go.mod go.sum ./ +RUN go mod download + +# 复制源代码 +COPY . . + +# 构建 +RUN CGO_ENABLED=0 GOOS=linux go build -o /server ./cmd/api + +# 运行阶段 +FROM alpine:latest + +RUN apk --no-cache add ca-certificates + +WORKDIR /app + +# 复制构建产物 +COPY --from=builder /server . +COPY config/ ./config/ + +EXPOSE 8080 + +CMD ["./server"] diff --git a/server/api.exe b/server/api.exe new file mode 100644 index 0000000..0d06de7 Binary files /dev/null and b/server/api.exe differ diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go new file mode 100644 index 0000000..7caef09 --- /dev/null +++ b/server/cmd/api/main.go @@ -0,0 +1,162 @@ +package main + +import ( + "bytes" + "io" + "log" + "os" + "path/filepath" + "time" + "x-agents/server/internal/config" + "x-agents/server/internal/handler" + "x-agents/server/internal/model" + "x-agents/server/internal/repository" + "x-agents/server/internal/service" + + "github.com/gin-gonic/gin" +) + +// Logger 日志记录器 +type Logger struct { + successLog *log.Logger + errorLog *log.Logger +} + +func NewLogger() *Logger { + // 创建日志目录 + today := time.Now().Format("2006-01-02") + logDir := filepath.Join("logs", today) + os.MkdirAll(logDir, 0755) + + // 成功日志 + successFile, _ := os.OpenFile(filepath.Join(logDir, "success.log"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + successLogger := log.New(successFile, "", log.Ldate|log.Ltime|log.Lshortfile) + + // 错误日志 + errorFile, _ := os.OpenFile(filepath.Join(logDir, "failure.log"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + errorLogger := log.New(errorFile, "", log.Ldate|log.Ltime|log.Lshortfile) + + return &Logger{ + successLog: successLogger, + errorLog: errorLogger, + } +} + +// LogRequest 记录请求 +func (l *Logger) LogRequest(method, path, body string, status int, duration time.Duration) { + entry := "[%s] %s %s %d %v" + + if status >= 400 { + l.errorLog.Printf(entry, method, path, body, status, duration) + } else { + l.successLog.Printf(entry, method, path, body, status, duration) + } +} + +var logger *Logger + +func main() { + // 初始化日志 + logger = NewLogger() + + // 1. 加载配置 + cfg := config.Load() + log.Printf("=== Server starting, port=%s ===", cfg.Port) + + // 2. 初始化数据库 + db, err := config.InitDB(cfg) + if err != nil { + log.Fatalf("Failed to connect database: %v", err) + } + + // 3. 自动迁移表 + db.AutoMigrate(&model.DatabaseInfo{}, &model.SubTableInfo{}) + + // 4. 初始化 Repository + dbRepo := repository.NewDatabaseRepository(db) + subTableRepo := repository.NewSubTableRepository(db) + + // 5. 初始化 Service + dbService := service.NewDatabaseService(dbRepo, subTableRepo) + subTableService := service.NewSubTableService(subTableRepo, dbRepo) + + // 6. 初始化 Handler + dbHandler := handler.NewDatabaseHandler(dbService) + subTableHandler := handler.NewSubTableHandler(subTableService) + systemHandler := handler.NewSystemHandler() + + // 7. 设置路由 + r := gin.New() + + // 添加日志和恢复中间件 + r.Use(gin.Logger()) + r.Use(gin.Recovery()) + + // 请求日志中间件 + r.Use(func(c *gin.Context) { + start := time.Now() + path := c.Request.URL.Path + + // 记录请求体 + var requestBody []byte + if c.Request.Method == "POST" || c.Request.Method == "PUT" { + requestBody, _ = io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + } + + c.Next() + + // 记录响应日志 + latency := time.Since(start) + status := c.Writer.Status() + + // 使用日志系统记录 + logger.LogRequest(c.Request.Method, path, string(requestBody), status, latency) + }) + + // CORS 中间件 + r.Use(func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + c.Next() + }) + + // 数据库管理模块 + databaseGroup := r.Group("/database") + { + databaseGroup.GET("/list", dbHandler.List) + databaseGroup.GET("/:id", dbHandler.GetByID) + databaseGroup.POST("/check", dbHandler.Check) + databaseGroup.POST("/add", dbHandler.Create) + databaseGroup.PUT("/:id", dbHandler.Update) + databaseGroup.DELETE("/:id", dbHandler.Delete) + } + + // 子表映射管理模块 + subTableGroup := r.Group("/sub-table") + { + subTableGroup.POST("/add", subTableHandler.Create) + subTableGroup.GET("/:id", subTableHandler.GetByID) + subTableGroup.GET("/database/:database_id", subTableHandler.ListByDatabase) + subTableGroup.GET("/mapping/:database_id", subTableHandler.GetMappingFromFile) + subTableGroup.GET("/ddl/:database_id", subTableHandler.GetTablesDDL) + subTableGroup.PUT("/:id", subTableHandler.Update) + subTableGroup.DELETE("/:id", subTableHandler.Delete) + } + + // 系统信息模块 + r.GET("/system/info", systemHandler.GetSystemInfo) + + // 8. 启动服务 + log.Printf("Server starting on :%s", cfg.Port) + if err := r.Run(":" + cfg.Port); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} diff --git a/server/config/config.yaml b/server/config/config.yaml new file mode 100644 index 0000000..b4882d3 --- /dev/null +++ b/server/config/config.yaml @@ -0,0 +1,6 @@ +# 本地开发配置 +port: "8082" +jwt_secret: "dev-secret-key" +# Docker 内访问用 db:3306,本地访问用 localhost:6036 +database_url: "root:root@tcp(localhost:6036)/x_agents?charset=utf8mb4&parseTime=True&loc=Local" +python_service_url: "http://localhost:8081" diff --git a/server/go.mod b/server/go.mod new file mode 100644 index 0000000..8c2a7ed --- /dev/null +++ b/server/go.mod @@ -0,0 +1,65 @@ +module x-agents/server + +go 1.21 + +require ( + github.com/gin-gonic/gin v1.9.1 + github.com/go-sql-driver/mysql v1.7.0 + github.com/golang-jwt/jwt/v5 v5.2.0 + github.com/google/uuid v1.5.0 + github.com/lib/pq v1.10.9 + github.com/spf13/viper v1.18.2 + golang.org/x/crypto v0.16.0 + gorm.io/driver/mysql v1.5.2 + gorm.io/gorm v1.25.5 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/shirou/gopsutil/v3 v3.24.5 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/tklauser/go-sysconf v0.3.12 // indirect + github.com/tklauser/numcpus v0.6.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/net v0.19.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.14.0 // indirect + google.golang.org/protobuf v1.31.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/server/go.sum b/server/go.sum new file mode 100644 index 0000000..bf1fc27 --- /dev/null +++ b/server/go.sum @@ -0,0 +1,169 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= +github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= +github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= +github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= +github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= +github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= +github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= +github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs= +gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8= +gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/server/internal/config/config.go b/server/internal/config/config.go new file mode 100644 index 0000000..431e5a4 --- /dev/null +++ b/server/internal/config/config.go @@ -0,0 +1,61 @@ +package config + +import ( + "fmt" + "log" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/spf13/viper" +) + +type Config struct { + Port string + JWTSecret string + DatabaseURL string + PythonServiceURL string +} + +func Load() *Config { + viper.SetConfigName("config") + viper.SetConfigType("yaml") + viper.AddConfigPath("./config") + viper.AddConfigPath("../config") + viper.AddConfigPath("../../config") + + // 默认值 + viper.SetDefault("port", "8080") + viper.SetDefault("jwt_secret", "your-secret-key-change-in-production") + viper.SetDefault("python_service_url", "http://localhost:8081") + viper.SetDefault("database_url", "root:root@tcp(localhost:3306)/x_agents?charset=utf8mb4&parseTime=True&loc=Local") + + if err := viper.ReadInConfig(); err != nil { + log.Printf("Using default config: %v", err) + } + + return &Config{ + Port: viper.GetString("port"), + JWTSecret: viper.GetString("jwt_secret"), + DatabaseURL: viper.GetString("database_url"), + PythonServiceURL: viper.GetString("python_service_url"), + } +} + +func InitDB(cfg *Config) (*gorm.DB, error) { + dsn := cfg.DatabaseURL + if dsn == "" { + return nil, fmt.Errorf("database URL is empty") + } + + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + }) + if err != nil { + return nil, fmt.Errorf("failed to connect database: %w", err) + } + + log.Println("Database connected successfully") + return db, nil +} diff --git a/server/internal/handler/approval_handler.go b/server/internal/handler/approval_handler.go new file mode 100644 index 0000000..db90993 --- /dev/null +++ b/server/internal/handler/approval_handler.go @@ -0,0 +1,80 @@ +package handler + +import ( + "net/http" + + "x-agents/server/internal/model" + "x-agents/server/internal/service" + + "github.com/gin-gonic/gin" +) + +type ApprovalHandler struct { + approvalService *service.ApprovalService +} + +func NewApprovalHandler(approvalService *service.ApprovalService) *ApprovalHandler { + return &ApprovalHandler{approvalService: approvalService} +} + +// Approve 处理审批请求 +func (h *ApprovalHandler) Approve(c *gin.Context) { + var req struct { + RequestID string `json:"request_id" binding:"required"` + Approved bool `json:"approved"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + userID, exists := c.Get("user_id") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + var result interface{} + var err error + + if req.Approved { + result, err = h.approvalService.Approve(req.RequestID, userID.(string)) + } else { + result, err = h.approvalService.Reject(req.RequestID, userID.(string)) + } + + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) +} + +// GetStatus 获取审批状态 +func (h *ApprovalHandler) GetStatus(c *gin.Context) { + requestID := c.Param("id") + + result, err := h.approvalService.GetApproval(requestID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "request not found"}) + return + } + + c.JSON(http.StatusOK, result) +} + +// ListPending 获取待审批列表 +func (h *ApprovalHandler) ListPending(c *gin.Context) { + result, err := h.approvalService.GetPendingApprovals() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if result == nil { + result = []model.ToolApprovalRequest{} + } + + c.JSON(http.StatusOK, gin.H{"pending": result}) +} diff --git a/server/internal/handler/auth_handler.go b/server/internal/handler/auth_handler.go new file mode 100644 index 0000000..0c35032 --- /dev/null +++ b/server/internal/handler/auth_handler.go @@ -0,0 +1,80 @@ +package handler + +import ( + "net/http" + + "x-agents/server/internal/service" + + "github.com/gin-gonic/gin" +) + +type AuthHandler struct { + authService *service.AuthService +} + +func NewAuthHandler(authService *service.AuthService) *AuthHandler { + return &AuthHandler{authService: authService} +} + +type LoginRequest struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` +} + +type LoginResponse struct { + Token string `json:"token"` + User interface{} `json:"user"` +} + +// Login 处理登录 +func (h *AuthHandler) Login(c *gin.Context) { + var req LoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + resp, err := h.authService.Login(service.LoginRequest{ + Username: req.Username, + Password: req.Password, + }) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, LoginResponse{ + Token: resp.Token, + User: gin.H{ + "id": resp.User.ID, + "username": resp.User.Username, + "email": resp.User.Email, + "role": resp.User.RoleID, + }, + }) +} + +// Register 处理注册 +func (h *AuthHandler) Register(c *gin.Context) { + var req struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` + Email string `json:"email"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + user, err := h.authService.Register(req.Username, req.Password, req.Email) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusCreated, gin.H{ + "id": user.ID, + "username": user.Username, + "email": user.Email, + }) +} diff --git a/server/internal/handler/chat_handler.go b/server/internal/handler/chat_handler.go new file mode 100644 index 0000000..9af5d02 --- /dev/null +++ b/server/internal/handler/chat_handler.go @@ -0,0 +1,89 @@ +package handler + +import ( + "net/http" + + "x-agents/server/internal/model" + "x-agents/server/internal/service" + + "github.com/gin-gonic/gin" +) + +type ChatHandler struct { + chatService *service.ChatService +} + +func NewChatHandler(chatService *service.ChatService) *ChatHandler { + return &ChatHandler{chatService: chatService} +} + +// Chat 处理聊天请求 +func (h *ChatHandler) Chat(c *gin.Context) { + var req model.AgentRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 从上下文获取用户ID(由中间件设置) + userID, exists := c.Get("user_id") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + resp, err := h.chatService.Chat(c.Request.Context(), userID.(string), req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, resp) +} + +// ListAgents 获取 Agent 列表 +func (h *ChatHandler) ListAgents(c *gin.Context) { + userID, exists := c.Get("user_id") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + agents, err := h.chatService.ListAgents(userID.(string)) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if agents == nil { + agents = []model.Agent{} + } + + c.JSON(http.StatusOK, gin.H{"agents": agents}) +} + +// CreateAgent 创建 Agent +func (h *ChatHandler) CreateAgent(c *gin.Context) { + userID, exists := c.Get("user_id") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + var req struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + agent, err := h.chatService.CreateAgent(userID.(string), req.Name, req.Description) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusCreated, agent) +} diff --git a/server/internal/handler/database_handler.go b/server/internal/handler/database_handler.go new file mode 100644 index 0000000..5290016 --- /dev/null +++ b/server/internal/handler/database_handler.go @@ -0,0 +1,112 @@ +package handler + +import ( + "net/http" + + "x-agents/server/internal/model" + "x-agents/server/internal/service" + + "github.com/gin-gonic/gin" +) + +type DatabaseHandler struct { + service *service.DatabaseService +} + +func NewDatabaseHandler(svc *service.DatabaseService) *DatabaseHandler { + return &DatabaseHandler{service: svc} +} + +// Check 检查数据库连接 +func (h *DatabaseHandler) Check(c *gin.Context) { + var req model.CheckRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + result, err := h.service.Check(req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) +} + +// Create 创建数据库信息 +func (h *DatabaseHandler) Create(c *gin.Context) { + var req model.CreateDatabaseRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + info, err := h.service.Create(req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusCreated, info) +} + +// GetByID 获取详情 +func (h *DatabaseHandler) GetByID(c *gin.Context) { + id := c.Param("id") + + info, err := h.service.GetByID(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) + return + } + + c.JSON(http.StatusOK, info) +} + +// List 获取列表 +func (h *DatabaseHandler) List(c *gin.Context) { + list, err := h.service.List() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if list == nil { + list = []model.DatabaseInfo{} + } + + c.JSON(http.StatusOK, gin.H{"list": list}) +} + +// Update 更新 +func (h *DatabaseHandler) Update(c *gin.Context) { + id := c.Param("id") + + var req model.UpdateDatabaseRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + info, err := h.service.Update(id, req) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) + return + } + + c.JSON(http.StatusOK, info) +} + +// Delete 删除 +func (h *DatabaseHandler) Delete(c *gin.Context) { + id := c.Param("id") + + err := h.service.Delete(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "deleted"}) +} diff --git a/server/internal/handler/sub_table_handler.go b/server/internal/handler/sub_table_handler.go new file mode 100644 index 0000000..31a533d --- /dev/null +++ b/server/internal/handler/sub_table_handler.go @@ -0,0 +1,132 @@ +package handler + +import ( + "net/http" + + "x-agents/server/internal/model" + "x-agents/server/internal/service" + + "github.com/gin-gonic/gin" +) + +type SubTableHandler struct { + service *service.SubTableService +} + +func NewSubTableHandler(svc *service.SubTableService) *SubTableHandler { + return &SubTableHandler{service: svc} +} + +// Create 创建子表信息 +func (h *SubTableHandler) Create(c *gin.Context) { + var req model.CreateSubTableRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + info, err := h.service.Create(req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusCreated, info) +} + +// GetByID 获取详情 +func (h *SubTableHandler) GetByID(c *gin.Context) { + id := c.Param("id") + + info, err := h.service.GetByID(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) + return + } + + c.JSON(http.StatusOK, info) +} + +// ListByDatabase 获取数据库下所有子表 +func (h *SubTableHandler) ListByDatabase(c *gin.Context) { + databaseID := c.Param("database_id") + + list, err := h.service.ListByDatabaseID(databaseID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if list == nil { + list = []model.SubTableInfo{} + } + + c.JSON(http.StatusOK, gin.H{"list": list}) +} + +// GetMappingFromFile 从文件获取映射 +func (h *SubTableHandler) GetMappingFromFile(c *gin.Context) { + databaseID := c.Param("database_id") + + mapping, err := h.service.GetMappingFromFile(databaseID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if mapping == nil { + c.JSON(http.StatusOK, gin.H{"mapping": nil, "message": "no mapping file found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"mapping": mapping}) +} + +// Update 更新 +func (h *SubTableHandler) Update(c *gin.Context) { + id := c.Param("id") + + var req model.UpdateSubTableRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + info, err := h.service.Update(id, req) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) + return + } + + c.JSON(http.StatusOK, info) +} + +// Delete 删除 +func (h *SubTableHandler) Delete(c *gin.Context) { + id := c.Param("id") + + err := h.service.Delete(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "deleted"}) +} + +// GetTablesDDL 获取数据库下所有表及DDL +func (h *SubTableHandler) GetTablesDDL(c *gin.Context) { + databaseID := c.Param("database_id") + + tables, err := h.service.GetTableDDLFromDatabase(databaseID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if tables == nil { + tables = []model.TableDDLInfo{} + } + + c.JSON(http.StatusOK, gin.H{"tables": tables}) +} diff --git a/server/internal/handler/system_handler.go b/server/internal/handler/system_handler.go new file mode 100644 index 0000000..2408a1e --- /dev/null +++ b/server/internal/handler/system_handler.go @@ -0,0 +1,62 @@ +package handler + +import ( + "net/http" + + "x-agents/server/internal/model" + + "github.com/gin-gonic/gin" +) + +type SystemHandler struct{} + +func NewSystemHandler() *SystemHandler { + return &SystemHandler{} +} + +// GetSystemInfo 获取系统信息 +func (h *SystemHandler) GetSystemInfo(c *gin.Context) { + info, err := getSystemInfo() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, info) +} + +// getSystemInfo 获取系统信息 +func getSystemInfo() (*model.SystemInfo, error) { + // 获取CPU使用率 + cpuPercent, err := getCPUPercent() + if err != nil { + return nil, err + } + + // 获取CPU核心数 + coreCount, err := getCPUCoreCount() + if err != nil { + coreCount = 0 + } + + // 获取CPU型号 + modelName, err := getCPUModelName() + if err != nil { + modelName = "Unknown" + } + + // 获取内存信息 + memoryInfo, err := getMemoryInfo() + if err != nil { + return nil, err + } + + return &model.SystemInfo{ + CPU: model.CPUInfo{ + Percent: cpuPercent, + CoreCount: coreCount, + ModelName: modelName, + }, + Memory: *memoryInfo, + }, nil +} diff --git a/server/internal/handler/system_helper.go b/server/internal/handler/system_helper.go new file mode 100644 index 0000000..28448b7 --- /dev/null +++ b/server/internal/handler/system_helper.go @@ -0,0 +1,60 @@ +package handler + +import ( + "github.com/shirou/gopsutil/v3/cpu" + "github.com/shirou/gopsutil/v3/mem" + "x-agents/server/internal/model" +) + +func getCPUPercent() (float64, error) { + percent, err := cpu.Percent(0, false) + if err != nil { + return 0, err + } + if len(percent) > 0 { + return percent[0], nil + } + return 0, nil +} + +func getCPUCoreCount() (int, error) { + count, err := cpu.Counts(false) + if err != nil { + return 0, err + } + return count, nil +} + +func getCPUModelName() (string, error) { + info, err := cpu.Info() + if err != nil { + return "Unknown", err + } + if len(info) > 0 { + return info[0].ModelName, nil + } + return "Unknown", nil +} + +func getMemoryInfo() (*model.MemoryInfo, error) { + v, err := mem.VirtualMemory() + if err != nil { + return nil, err + } + + // 计算使用率 + percent := 0.0 + if v.Total > 0 { + percent = float64(v.Used) / float64(v.Total) * 100 + } + + return &model.MemoryInfo{ + Total: v.Total, + Used: v.Used, + Available: v.Available, + Percent: percent, + TotalGB: float64(v.Total) / 1024 / 1024 / 1024, + UsedGB: float64(v.Used) / 1024 / 1024 / 1024, + AvailableGB: float64(v.Available) / 1024 / 1024 / 1024, + }, nil +} diff --git a/server/internal/middleware/auth.go b/server/internal/middleware/auth.go new file mode 100644 index 0000000..e97ebb2 --- /dev/null +++ b/server/internal/middleware/auth.go @@ -0,0 +1,71 @@ +package middleware + +import ( + "net/http" + "strings" + + "x-agents/server/internal/service" + + "github.com/gin-gonic/gin" +) + +// CORS 中间件 +func CORS() gin.HandlerFunc { + return func(c *gin.Context) { + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization") + c.Header("Access-Control-Max-Age", "86400") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(http.StatusNoContent) + return + } + + c.Next() + } +} + +// Recovery 中间件 - 恢复 panic +func Recovery() gin.HandlerFunc { + return gin.Recovery() +} + +// Auth 认证中间件 +func Auth(jwtSecret string) gin.HandlerFunc { + return func(c *gin.Context) { + // 从 Header 获取 Token + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header required"}) + c.Abort() + return + } + + // 解析 Bearer Token + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid authorization format"}) + c.Abort() + return + } + + tokenString := parts[1] + + // 验证 Token + authService := service.NewAuthService(jwtSecret, nil) + claims, err := authService.ValidateToken(tokenString) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) + c.Abort() + return + } + + // 将用户信息存入上下文 + c.Set("user_id", claims["sub"]) + c.Set("username", claims["username"]) + c.Set("role", claims["role"]) + + c.Next() + } +} diff --git a/server/internal/model/agent.go b/server/internal/model/agent.go new file mode 100644 index 0000000..13aa32c --- /dev/null +++ b/server/internal/model/agent.go @@ -0,0 +1,53 @@ +package model + +import ( + "time" +) + +// SecurityLevel 安全等级 +type SecurityLevel string + +const ( + SecurityLevelSafe SecurityLevel = "safe" + SecurityLevelReview SecurityLevel = "review" + SecurityLevelDanger SecurityLevel = "danger" +) + +// Agent 智能体 +type Agent struct { + ID string `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"size:100;not null"` + Description string `json:"description" gorm:"type:text"` + OwnerID string `json:"owner_id" gorm:"size:50;not null;index"` + + // Agent能力配置 + Capabilities []string `json:"capabilities" gorm:"type:text"` // JSON数组,可用工具列表 + MemoryLimit int64 `json:"memory_limit" gorm:"default:134217728"` // 128MB + Timeout int `json:"timeout" gorm:"default:60"` // 60秒 + + // 安全配置 + SecurityLevel SecurityLevel `json:"security_level" gorm:"size:20;default:'safe'"` + AllowDangerousTools bool `json:"allow_dangerous_tools" gorm:"default:false"` + + // 状态 + IsActive bool `json:"is_active" gorm:"default:true"` + + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// AgentRequest 聊天请求 +type AgentRequest struct { + AgentID string `json:"agent_id" binding:"required"` + Message string `json:"message" binding:"required"` + SessionID string `json:"session_id"` + Context map[string]interface{} `json:"context"` +} + +// AgentResponse 聊天响应 +type AgentResponse struct { + Reply string `json:"reply"` + SessionID string `json:"session_id"` + ToolsUsed []string `json:"tools_used"` + Metadata map[string]interface{} `json:"metadata"` +} diff --git a/server/internal/model/audit.go b/server/internal/model/audit.go new file mode 100644 index 0000000..872bcf0 --- /dev/null +++ b/server/internal/model/audit.go @@ -0,0 +1,76 @@ +package model + +import ( + "encoding/json" + "time" +) + +// AuditAction 审计动作 +type AuditAction string + +const ( + AuditActionLogin AuditAction = "login" + AuditActionLogout AuditAction = "logout" + AuditActionChat AuditAction = "chat" + AuditActionToolExecute AuditAction = "tool_execute" + AuditActionToolApprove AuditAction = "tool_approve" + AuditActionToolReject AuditAction = "tool_reject" + AuditActionAgentCreate AuditAction = "agent_create" + AuditActionAgentUpdate AuditAction = "agent_update" + AuditActionAgentDelete AuditAction = "agent_delete" +) + +// AuditLog 审计日志 +type AuditLog struct { + ID string `json:"id" gorm:"primaryKey"` + UserID string `json:"user_id" gorm:"size:50;index"` + AgentID string `json:"agent_id" gorm:"size:50;index"` + Action AuditAction `json:"action" gorm:"size:50;index"` + Details JSONMap `json:"details" gorm:"type:jsonb"` + Result string `json:"result" gorm:"size:20"` // success, failed, rejected + IPAddress string `json:"ip_address" gorm:"size:45"` + UserAgent string `json:"user_agent" gorm:"size:255"` + CreatedAt time.Time `json:"created_at" gorm:"index"` +} + +// ApprovalStatus 审批状态 +type ApprovalStatus string + +const ( + ApprovalStatusPending ApprovalStatus = "pending" + ApprovalStatusApproved ApprovalStatus = "approved" + ApprovalStatusRejected ApprovalStatus = "rejected" +) + +// ToolApprovalRequest 工具审批请求 +type ToolApprovalRequest struct { + ID string `json:"id" gorm:"primaryKey"` + ToolName string `json:"tool_name" gorm:"size:100;index"` + Params JSONMap `json:"params" gorm:"type:jsonb"` + UserID string `json:"user_id" gorm:"size:50;index"` + AgentID string `json:"agent_id" gorm:"size:50"` + Reason string `json:"reason" gorm:"type:text"` + Status ApprovalStatus `json:"status" gorm:"size:20;default:'pending';index"` + ReviewedBy *string `json:"reviewed_by" gorm:"size:50"` + ReviewedAt *time.Time `json:"reviewed_at"` + Result *string `json:"result" gorm:"type:text"` // 执行结果 + CreatedAt time.Time `json:"created_at" gorm:"index"` + UpdatedAt time.Time `json:"updated_at"` +} + +// JSONMap JSON数据映射 +type JSONMap map[string]interface{} + +func (j JSONMap) MarshalJSON() ([]byte, error) { + if j == nil { + return []byte("null"), nil + } + return json.Marshal(j) +} + +func (j *JSONMap) UnmarshalJSON(data []byte) error { + if j == nil { + *j = make(map[string]interface{}) + } + return json.Unmarshal(data, j) +} diff --git a/server/internal/model/database_info.go b/server/internal/model/database_info.go new file mode 100644 index 0000000..61ad2c1 --- /dev/null +++ b/server/internal/model/database_info.go @@ -0,0 +1,83 @@ +package model + +import ( + "time" +) + +// DatabaseInfo 数据库连接信息 +type DatabaseInfo struct { + ID string `json:"id" gorm:"primaryKey;size:36"` // UUID + Name string `json:"name" gorm:"size:100;not null"` // 数据库名称 + Description string `json:"description" gorm:"size:500"` // 描述 + DBType string `json:"db_type" gorm:"size:20;not null"` // 数据库类型: mysql, postgres, mongodb等 + Host string `json:"host" gorm:"size:255;not null"` // 主机地址 + Port int `json:"port" gorm:"not null"` // 端口 + Username string `json:"username" gorm:"size:100;not null"` // 用户名 + Password string `json:"password" gorm:"size:255"` // 密码(建议加密存储) + Database string `json:"database" gorm:"size:100"` // 数据库名 + TableCount int `json:"table_count" gorm:"default:0"` // 子表数量 + + // 连接选项 + Charset string `json:"charset" gorm:"size:20;default:utf8mb4"` // 字符集 + SSLMode string `json:"ssl_mode" gorm:"size:20"` // SSL模式 + + // 时间 + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// TableName 表名 +func (DatabaseInfo) TableName() string { + return "database_info" +} + +// CreateRequest 创建数据库信息请求(支持同时保存子表配置) +type CreateDatabaseRequest struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + DBType string `json:"db_type" binding:"required"` + Host string `json:"host" binding:"required"` + Port int `json:"port" binding:"required"` + Username string `json:"username" binding:"required"` + Password string `json:"password"` + Database string `json:"database"` + Charset string `json:"charset"` + SSLMode string `json:"ssl_mode"` + SubTables []CreateSubTableRequest `json:"sub_tables"` // 可选,子表配置 +} + +// UpdateRequest 更新数据库信息请求 +type UpdateDatabaseRequest struct { + Name string `json:"name"` + Description string `json:"description"` + DBType string `json:"db_type"` + Host string `json:"host"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"password"` + Database string `json:"database"` + TableCount int `json:"table_count"` + Charset string `json:"charset"` + SSLMode string `json:"ssl_mode"` +} + +// CheckRequest 检查连接请求 +type CheckRequest struct { + DBType string `json:"db_type" binding:"required"` + Host string `json:"host" binding:"required"` + Port int `json:"port" binding:"required"` + Username string `json:"username" binding:"required"` + Password string `json:"password"` + Database string `json:"database"` + Charset string `json:"charset"` + SSLMode string `json:"ssl_mode"` + DatabaseID string `json:"database_id"` // 可选,用于获取已保存的字段映射 +} + +// CheckResponse 检查连接响应 +type CheckResponse struct { + Success bool `json:"success"` // 是否连接成功 + Message string `json:"message"` // 消息 + Tables []TableDDLInfo `json:"tables,omitempty"` // 表列表(连接成功时返回) + Database string `json:"database"` // 数据库名 +} diff --git a/server/internal/model/sub_table_info.go b/server/internal/model/sub_table_info.go new file mode 100644 index 0000000..c4eb1c7 --- /dev/null +++ b/server/internal/model/sub_table_info.go @@ -0,0 +1,117 @@ +package model + +import ( + "encoding/json" + "time" +) + +// TableDDLInfo 表结构信息 +type TableDDLInfo struct { + TableName string `json:"table_name"` // 表名 + TableComment string `json:"table_comment"` // 表注释 + Columns []ColumnInfo `json:"columns"` // 列信息 + DDL string `json:"ddl"` // 建表DDL + Indexes []IndexInfo `json:"indexes"` // 索引信息 +} + +// ColumnInfo 列信息 +type ColumnInfo struct { + ColumnName string `json:"column_name"` // 列名 + DataType string `json:"data_type"` // 数据类型 + ColumnType string `json:"column_type"` // 列类型(含长度) + IsNullable string `json:"is_nullable"` // 是否可空 + DefaultValue string `json:"default_value"` // 默认值 + ColumnKey string `json:"column_key"` // 主键/索引 + Extra string `json:"extra"` // 自增等 + ColumnComment string `json:"column_comment"` // 列注释 + MappedName string `json:"mapped_name"` // 字段中文映射名 +} + +// IndexInfo 索引信息 +type IndexInfo struct { + IndexName string `json:"index_name"` // 索引名 + ColumnName string `json:"column_name"` // 列名 + NonUnique int `json:"non_unique"` // 是否唯一 + IndexType string `json:"index_type"` // 索引类型 +} + +// SubTableInfo 子表信息 +type SubTableInfo struct { + ID string `json:"id"` // UUID + DatabaseID string `json:"database_id"` // 关联的数据库ID + ParentTable string `json:"parent_table"` // 父表名 + SubTableName string `json:"sub_table_name"` // 子表名 + SubTableComment string `json:"sub_table_comment"` // 子表注释 + MappingType string `json:"mapping_type" gorm:"type:varchar(20)"` // 映射类型 + RelationField string `json:"relation_field" gorm:"type:varchar(100)"` // 关联字段 + RelationType string `json:"relation_type" gorm:"type:varchar(20)"` // 关联类型 + Fields string `json:"-" gorm:"type:longtext"` // 字段映射列表(JSON 格式,内部存储) + FieldsList []FieldMapping `json:"fields" gorm:"-"` // 字段映射列表(返回给前端) + DDL string `json:"ddl" gorm:"type:longtext"` // 建表 DDL + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// FieldMapping 字段映射 +type FieldMapping struct { + ColumnName string `json:"column_name"` // 列名 + MappedName string `json:"mapped_name"` // 中文映射名 +} + +// GetFields 获取字段映射列表 +func (s *SubTableInfo) GetFields() []FieldMapping { + if s.Fields == "" { + return nil + } + var fields []FieldMapping + if err := json.Unmarshal([]byte(s.Fields), &fields); err != nil { + return nil + } + return fields +} + +// SetFields 设置字段映射列表 +func (s *SubTableInfo) SetFields(fields []FieldMapping) { + if len(fields) == 0 { + s.Fields = "" + return + } + data, _ := json.Marshal(fields) + s.Fields = string(data) +} + +// TableName 表名 +func (SubTableInfo) TableName() string { + return "sub_table_info" +} + +// CreateSubTableRequest 创建子表请求 +type CreateSubTableRequest struct { + DatabaseID string `json:"database_id" binding:"required"` + ParentTable string `json:"parent_table" binding:"required"` + SubTableName string `json:"sub_table_name" binding:"required"` + SubTableComment string `json:"sub_table_comment"` + MappingType string `json:"mapping_type"` + RelationField string `json:"relation_field"` + RelationType string `json:"relation_type"` + Fields []FieldMapping `json:"fields"` // 字段映射列表 +} + +// UpdateSubTableRequest 更新子表请求 +type UpdateSubTableRequest struct { + ParentTable string `json:"parent_table"` + SubTableName string `json:"sub_table_name"` + SubTableComment string `json:"sub_table_comment"` + MappingType string `json:"mapping_type"` + RelationField string `json:"relation_field"` + RelationType string `json:"relation_type"` +} + +// SubTableMapping 完整的子表映射配置(存储到文件的格式) +type SubTableMapping struct { + DatabaseID string `json:"database_id"` + DatabaseName string `json:"database_name"` + DBType string `json:"db_type"` + Tables []SubTableInfo `json:"tables"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/server/internal/model/system_info.go b/server/internal/model/system_info.go new file mode 100644 index 0000000..365c314 --- /dev/null +++ b/server/internal/model/system_info.go @@ -0,0 +1,25 @@ +package model + +// SystemInfo 系统信息 +type SystemInfo struct { + CPU CPUInfo `json:"cpu"` + Memory MemoryInfo `json:"memory"` +} + +// CPUInfo CPU信息 +type CPUInfo struct { + Percent float64 `json:"percent"` // CPU使用率 + CoreCount int `json:"core_count"` // 核心数 + ModelName string `json:"model_name"` // CPU型号 +} + +// MemoryInfo 内存信息 +type MemoryInfo struct { + Total uint64 `json:"total"` // 总内存(字节) + Used uint64 `json:"used"` // 已使用(字节) + Available uint64 `json:"available"` // 可用(字节) + Percent float64 `json:"percent"` // 使用率 + TotalGB float64 `json:"total_gb"` // 总内存(GB) + UsedGB float64 `json:"used_gb"` // 已使用(GB) + AvailableGB float64 `json:"available_gb"` // 可用(GB) +} diff --git a/server/internal/model/user.go b/server/internal/model/user.go new file mode 100644 index 0000000..4d37dc5 --- /dev/null +++ b/server/internal/model/user.go @@ -0,0 +1,50 @@ +package model + +import ( + "time" +) + +// PermissionLevel 权限级别 +type PermissionLevel int + +const ( + PermissionRead PermissionLevel = iota + 1 + PermissionWrite + PermissionExecute + PermissionAdmin +) + +// Role 角色 +type Role struct { + ID string `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"uniqueIndex"` + Permissions []PermissionLevel `json:"permissions" gorm:"type:int[]"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// User 用户 +type User struct { + ID string `json:"id" gorm:"primaryKey"` + Username string `json:"username" gorm:"uniqueIndex;size:50;not null"` + Password string `json:"-" gorm:"not null"` + Email string `json:"email" gorm:"index"` + RoleID string `json:"role_id" gorm:"size:50;not null"` + Role *Role `json:"role,omitempty" gorm:"foreignKey:RoleID"` + IsActive bool `json:"is_active" gorm:"default:true"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// HasPermission 检查是否有权限 +func (u *User) HasPermission(level PermissionLevel) bool { + if u.Role == nil { + return false + } + for _, p := range u.Role.Permissions { + if p >= level { + return true + } + } + return false +} diff --git a/server/internal/repository/agent_repo.go b/server/internal/repository/agent_repo.go new file mode 100644 index 0000000..3d99745 --- /dev/null +++ b/server/internal/repository/agent_repo.go @@ -0,0 +1,48 @@ +package repository + +import ( + "x-agents/server/internal/model" + + "gorm.io/gorm" +) + +type AgentRepository struct { + db *gorm.DB +} + +func NewAgentRepository(db *gorm.DB) *AgentRepository { + return &AgentRepository{db: db} +} + +func (r *AgentRepository) Create(agent *model.Agent) error { + return r.db.Create(agent).Error +} + +func (r *AgentRepository) FindByID(id string) (*model.Agent, error) { + var agent model.Agent + err := r.db.First(&agent, "id = ?", id).Error + if err != nil { + return nil, err + } + return &agent, nil +} + +func (r *AgentRepository) FindByOwnerID(ownerID string) ([]model.Agent, error) { + var agents []model.Agent + err := r.db.Where("owner_id = ?", ownerID).Find(&agents).Error + return agents, err +} + +func (r *AgentRepository) FindAll() ([]model.Agent, error) { + var agents []model.Agent + err := r.db.Where("is_active = ?", true).Find(&agents).Error + return agents, err +} + +func (r *AgentRepository) Update(agent *model.Agent) error { + return r.db.Save(agent).Error +} + +func (r *AgentRepository) Delete(id string) error { + return r.db.Delete(&model.Agent{}, "id = ?", id).Error +} diff --git a/server/internal/repository/audit_repo.go b/server/internal/repository/audit_repo.go new file mode 100644 index 0000000..791d6bc --- /dev/null +++ b/server/internal/repository/audit_repo.go @@ -0,0 +1,56 @@ +package repository + +import ( + "x-agents/server/internal/model" + + "gorm.io/gorm" +) + +type AuditRepository struct { + db *gorm.DB +} + +func NewAuditRepository(db *gorm.DB) *AuditRepository { + return &AuditRepository{db: db} +} + +func (r *AuditRepository) Create(log *model.AuditLog) error { + return r.db.Create(log).Error +} + +func (r *AuditRepository) FindByUserID(userID string, limit int) ([]model.AuditLog, error) { + var logs []model.AuditLog + err := r.db.Where("user_id = ?", userID).Order("created_at DESC").Limit(limit).Find(&logs).Error + return logs, err +} + +func (r *AuditRepository) FindByAgentID(agentID string, limit int) ([]model.AuditLog, error) { + var logs []model.AuditLog + err := r.db.Where("agent_id = ?", agentID).Order("created_at DESC").Limit(limit).Find(&logs).Error + return logs, err +} + +// ToolApproval 工具审批仓储 + +func (r *AuditRepository) CreateApproval(req *model.ToolApprovalRequest) error { + return r.db.Create(req).Error +} + +func (r *AuditRepository) FindApprovalByID(id string) (*model.ToolApprovalRequest, error) { + var req model.ToolApprovalRequest + err := r.db.First(&req, "id = ?", id).Error + if err != nil { + return nil, err + } + return &req, nil +} + +func (r *AuditRepository) FindPendingApprovals() ([]model.ToolApprovalRequest, error) { + var reqs []model.ToolApprovalRequest + err := r.db.Where("status = ?", model.ApprovalStatusPending).Order("created_at ASC").Find(&reqs).Error + return reqs, err +} + +func (r *AuditRepository) UpdateApproval(req *model.ToolApprovalRequest) error { + return r.db.Save(req).Error +} diff --git a/server/internal/repository/database_repo.go b/server/internal/repository/database_repo.go new file mode 100644 index 0000000..552b305 --- /dev/null +++ b/server/internal/repository/database_repo.go @@ -0,0 +1,47 @@ +package repository + +import ( + "x-agents/server/internal/model" + + "gorm.io/gorm" +) + +type DatabaseRepository struct { + db *gorm.DB +} + +func NewDatabaseRepository(db *gorm.DB) *DatabaseRepository { + return &DatabaseRepository{db: db} +} + +// Create 创建数据库信息 +func (r *DatabaseRepository) Create(info *model.DatabaseInfo) error { + return r.db.Create(info).Error +} + +// FindByID 根据ID查询 +func (r *DatabaseRepository) FindByID(id string) (*model.DatabaseInfo, error) { + var info model.DatabaseInfo + err := r.db.First(&info, "id = ?", id).Error + if err != nil { + return nil, err + } + return &info, nil +} + +// FindAll 查询所有 +func (r *DatabaseRepository) FindAll() ([]model.DatabaseInfo, error) { + var list []model.DatabaseInfo + err := r.db.Order("created_at DESC").Find(&list).Error + return list, err +} + +// Update 更新 +func (r *DatabaseRepository) Update(id string, info *model.DatabaseInfo) error { + return r.db.Model(&model.DatabaseInfo{}).Where("id = ?", id).Updates(info).Error +} + +// Delete 删除 +func (r *DatabaseRepository) Delete(id string) error { + return r.db.Delete(&model.DatabaseInfo{}, "id = ?", id).Error +} diff --git a/server/internal/repository/sub_table_repo.go b/server/internal/repository/sub_table_repo.go new file mode 100644 index 0000000..c033bd0 --- /dev/null +++ b/server/internal/repository/sub_table_repo.go @@ -0,0 +1,53 @@ +package repository + +import ( + "x-agents/server/internal/model" + + "gorm.io/gorm" +) + +type SubTableRepository struct { + db *gorm.DB +} + +func NewSubTableRepository(db *gorm.DB) *SubTableRepository { + return &SubTableRepository{db: db} +} + +// Create 创建子表信息 +func (r *SubTableRepository) Create(info *model.SubTableInfo) error { + return r.db.Create(info).Error +} + +// FindByID 根据ID查询 +func (r *SubTableRepository) FindByID(id string) (*model.SubTableInfo, error) { + var info model.SubTableInfo + if err := r.db.Where("id = ?", id).First(&info).Error; err != nil { + return nil, err + } + return &info, nil +} + +// FindByDatabaseID 根据数据库ID查询所有子表 +func (r *SubTableRepository) FindByDatabaseID(databaseID string) ([]model.SubTableInfo, error) { + var list []model.SubTableInfo + if err := r.db.Where("database_id = ?", databaseID).Find(&list).Error; err != nil { + return nil, err + } + return list, nil +} + +// Update 更新子表信息 +func (r *SubTableRepository) Update(id string, info *model.SubTableInfo) error { + return r.db.Model(info).Where("id = ?", id).Updates(info).Error +} + +// Delete 删除子表信息 +func (r *SubTableRepository) Delete(id string) error { + return r.db.Where("id = ?", id).Delete(&model.SubTableInfo{}).Error +} + +// DeleteByDatabaseID 删除数据库下所有子表信息 +func (r *SubTableRepository) DeleteByDatabaseID(databaseID string) error { + return r.db.Where("database_id = ?", databaseID).Delete(&model.SubTableInfo{}).Error +} diff --git a/server/internal/repository/user_repo.go b/server/internal/repository/user_repo.go new file mode 100644 index 0000000..af48fc7 --- /dev/null +++ b/server/internal/repository/user_repo.go @@ -0,0 +1,66 @@ +package repository + +import ( + "x-agents/server/internal/model" + + "gorm.io/gorm" +) + +type UserRepository struct { + db *gorm.DB +} + +func NewUserRepository(db *gorm.DB) *UserRepository { + return &UserRepository{db: db} +} + +func (r *UserRepository) Create(user *model.User) error { + return r.db.Create(user).Error +} + +func (r *UserRepository) FindByID(id string) (*model.User, error) { + var user model.User + err := r.db.Preload("Role").First(&user, "id = ?", id).Error + if err != nil { + return nil, err + } + return &user, nil +} + +func (r *UserRepository) FindByUsername(username string) (*model.User, error) { + var user model.User + err := r.db.Preload("Role").First(&user, "username = ?", username).Error + if err != nil { + return nil, err + } + return &user, nil +} + +func (r *UserRepository) FindAll() ([]model.User, error) { + var users []model.User + err := r.db.Preload("Role").Find(&users).Error + return users, err +} + +func (r *UserRepository) Update(user *model.User) error { + return r.db.Save(user).Error +} + +func (r *UserRepository) Delete(id string) error { + return r.db.Delete(&model.User{}, "id = ?", id).Error +} + +// FindRoleByID 根据ID查找角色 +func (r *UserRepository) FindRoleByID(id string) (*model.Role, error) { + var role model.Role + err := r.db.First(&role, "id = ?", id).Error + if err != nil { + return nil, err + } + return &role, nil +} + +// CreateRole 创建角色 +func (r *UserRepository) CreateRole(role *model.Role) error { + return r.db.Create(role).Error +} diff --git a/server/internal/service/approval_service.go b/server/internal/service/approval_service.go new file mode 100644 index 0000000..1db5469 --- /dev/null +++ b/server/internal/service/approval_service.go @@ -0,0 +1,101 @@ +package service + +import ( + "fmt" + "time" + + "x-agents/server/internal/model" + "x-agents/server/internal/repository" + + "github.com/google/uuid" +) + +type ApprovalService struct { + auditRepo *repository.AuditRepository +} + +func NewApprovalService(auditRepo *repository.AuditRepository) *ApprovalService { + return &ApprovalService{auditRepo: auditRepo} +} + +// CreateApprovalRequest 创建审批请求 +func (s *ApprovalService) CreateApprovalRequest( + toolName string, + params map[string]interface{}, + userID string, + agentID string, + reason string, +) (*model.ToolApprovalRequest, error) { + + req := &model.ToolApprovalRequest{ + ID: uuid.New().String(), + ToolName: toolName, + Params: params, + UserID: userID, + AgentID: agentID, + Reason: reason, + Status: model.ApprovalStatusPending, + } + + if err := s.auditRepo.CreateApproval(req); err != nil { + return nil, err + } + + return req, nil +} + +// Approve 批准请求 +func (s *ApprovalService) Approve(requestID, reviewedBy string) (*model.ToolApprovalRequest, error) { + req, err := s.auditRepo.FindApprovalByID(requestID) + if err != nil { + return nil, fmt.Errorf("request not found: %w", err) + } + + if req.Status != model.ApprovalStatusPending { + return nil, fmt.Errorf("request already processed") + } + + now := time.Now() + req.Status = model.ApprovalStatusApproved + req.ReviewedBy = &reviewedBy + req.ReviewedAt = &now + + if err := s.auditRepo.UpdateApproval(req); err != nil { + return nil, err + } + + return req, nil +} + +// Reject 拒绝请求 +func (s *ApprovalService) Reject(requestID, reviewedBy string) (*model.ToolApprovalRequest, error) { + req, err := s.auditRepo.FindApprovalByID(requestID) + if err != nil { + return nil, fmt.Errorf("request not found: %w", err) + } + + if req.Status != model.ApprovalStatusPending { + return nil, fmt.Errorf("request already processed") + } + + now := time.Now() + req.Status = model.ApprovalStatusRejected + req.ReviewedBy = &reviewedBy + req.ReviewedAt = &now + + if err := s.auditRepo.UpdateApproval(req); err != nil { + return nil, err + } + + return req, nil +} + +// GetApproval 获取审批状态 +func (s *ApprovalService) GetApproval(requestID string) (*model.ToolApprovalRequest, error) { + return s.auditRepo.FindApprovalByID(requestID) +} + +// GetPendingApprovals 获取待审批列表 +func (s *ApprovalService) GetPendingApprovals() ([]model.ToolApprovalRequest, error) { + return s.auditRepo.FindPendingApprovals() +} diff --git a/server/internal/service/auth_service.go b/server/internal/service/auth_service.go new file mode 100644 index 0000000..e8d1655 --- /dev/null +++ b/server/internal/service/auth_service.go @@ -0,0 +1,145 @@ +package service + +import ( + "errors" + "time" + + "x-agents/server/internal/model" + "x-agents/server/internal/repository" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" +) + +var ( + ErrInvalidCredentials = errors.New("invalid credentials") + ErrUserNotFound = errors.New("user not found") +) + +type AuthService struct { + jwtSecret string + userRepo *repository.UserRepository +} + +func NewAuthService(jwtSecret string, userRepo *repository.UserRepository) *AuthService { + return &AuthService{ + jwtSecret: jwtSecret, + userRepo: userRepo, + } +} + +type LoginRequest struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` +} + +type LoginResponse struct { + Token string `json:"token"` + User *model.User `json:"user"` +} + +func (s *AuthService) Login(req LoginRequest) (*LoginResponse, error) { + // 查找用户 + user, err := s.userRepo.FindByUsername(req.Username) + if err != nil { + return nil, ErrInvalidCredentials + } + + // 验证密码 + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil { + return nil, ErrInvalidCredentials + } + + // 生成Token + token, err := s.generateToken(user) + if err != nil { + return nil, err + } + + return &LoginResponse{ + Token: token, + User: user, + }, nil +} + +func (s *AuthService) generateToken(user *model.User) (string, error) { + claims := jwt.MapClaims{ + "sub": user.ID, + "username": user.Username, + "role": user.RoleID, + "exp": time.Now().Add(time.Hour * 24 * 7).Unix(), // 7天有效期 + "iat": time.Now().Unix(), + "expires_at": time.Now().Add(time.Hour * 24 * 7).Format(time.RFC3339), + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(s.jwtSecret)) +} + +func (s *AuthService) ValidateToken(tokenString string) (jwt.MapClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.New("unexpected signing method") + } + return []byte(s.jwtSecret), nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + return claims, nil + } + + return nil, errors.New("invalid token") +} + +func (s *AuthService) Register(username, password, email string) (*model.User, error) { + // 检查用户是否已存在 + _, err := s.userRepo.FindByUsername(username) + if err == nil { + return nil, errors.New("user already exists") + } + + // 加密密码 + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + + // 创建用户 + user := &model.User{ + ID: uuid.New().String(), + Username: username, + Password: string(hashedPassword), + Email: email, + RoleID: "user", + IsActive: true, + } + + // 如果没有用户,创建默认管理员角色 + role, err := s.userRepo.FindRoleByID(user.RoleID) + if err != nil { + // 创建默认角色 + role = &model.Role{ + ID: "user", + Name: "user", + Permissions: []model.PermissionLevel{model.PermissionRead, model.PermissionWrite}, + } + s.userRepo.CreateRole(role) + user.Role = role + } + + if err := s.userRepo.Create(user); err != nil { + return nil, err + } + + return user, nil +} + +// GetUserByID 根据ID获取用户 +func (s *AuthService) GetUserByID(id string) (*model.User, error) { + return s.userRepo.FindByID(id) +} diff --git a/server/internal/service/chat_service.go b/server/internal/service/chat_service.go new file mode 100644 index 0000000..ca11f68 --- /dev/null +++ b/server/internal/service/chat_service.go @@ -0,0 +1,146 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "x-agents/server/internal/model" + "x-agents/server/internal/repository" + + "github.com/google/uuid" +) + +type ChatService struct { + pythonURL string + agentRepo *repository.AgentRepository +} + +func NewChatService(pythonURL string, agentRepo *repository.AgentRepository) *ChatService { + return &ChatService{ + pythonURL: pythonURL, + agentRepo: agentRepo, + } +} + +type ChatRequest struct { + AgentID string `json:"agent_id"` + Message string `json:"message"` + SessionID string `json:"session_id"` + Context map[string]interface{} `json:"context"` +} + +type ChatResponse struct { + Reply string `json:"reply"` + SessionID string `json:"session_id"` + ToolsUsed []string `json:"tools_used"` + Metadata map[string]interface{} `json:"metadata"` +} + +// Chat 处理聊天请求 +func (s *ChatService) Chat(ctx context.Context, userID string, req model.AgentRequest) (*model.AgentResponse, error) { + // 1. 检查 Agent 是否存在 + agent, err := s.agentRepo.FindByID(req.AgentID) + if err != nil { + return nil, fmt.Errorf("agent not found: %w", err) + } + + // 2. 检查用户权限 + if !agent.IsActive { + return nil, fmt.Errorf("agent is not active") + } + + // 3. 生成会话ID + sessionID := req.SessionID + if sessionID == "" { + sessionID = uuid.New().String() + } + + // 4. 调用 Python 服务 + pythonReq := ChatRequest{ + AgentID: req.AgentID, + Message: req.Message, + SessionID: sessionID, + Context: req.Context, + } + + pythonResp, err := s.callPythonChat(ctx, pythonReq) + if err != nil { + return nil, fmt.Errorf("failed to call python service: %w", err) + } + + return &model.AgentResponse{ + Reply: pythonResp.Reply, + SessionID: pythonResp.SessionID, + ToolsUsed: pythonResp.ToolsUsed, + Metadata: pythonResp.Metadata, + }, nil +} + +func (s *ChatService) callPythonChat(ctx context.Context, req ChatRequest) (*ChatResponse, error) { + jsonData, err := json.Marshal(req) + if err != nil { + return nil, err + } + + httpReq, err := http.NewRequestWithContext( + ctx, + "POST", + s.pythonURL+"/agent/chat", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", "application/json") + + client := &http.Client{ + Timeout: 120 * time.Second, // Agent 可能需要较长时间 + } + + resp, err := client.Do(httpReq) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("python service returned status: %d", resp.StatusCode) + } + + var chatResp ChatResponse + if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { + return nil, err + } + + return &chatResp, nil +} + +// ListAgents 获取用户可用的 Agent 列表 +func (s *ChatService) ListAgents(userID string) ([]model.Agent, error) { + return s.agentRepo.FindByOwnerID(userID) +} + +// CreateAgent 创建新的 Agent +func (s *ChatService) CreateAgent(userID string, name, description string) (*model.Agent, error) { + agent := &model.Agent{ + ID: uuid.New().String(), + Name: name, + Description: description, + OwnerID: userID, + SecurityLevel: model.SecurityLevelSafe, + IsActive: true, + Timeout: 60, + MemoryLimit: 134217728, // 128MB + } + + if err := s.agentRepo.Create(agent); err != nil { + return nil, err + } + + return agent, nil +} diff --git a/server/internal/service/database_service.go b/server/internal/service/database_service.go new file mode 100644 index 0000000..18ee629 --- /dev/null +++ b/server/internal/service/database_service.go @@ -0,0 +1,765 @@ +package service + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "x-agents/server/internal/model" + "x-agents/server/internal/repository" + + "github.com/google/uuid" +) + +var ( + ErrDatabaseNotFound = errors.New("database not found") + ErrDatabaseUnreachable = errors.New("database cannot be connected") +) + +type DatabaseService struct { + repo *repository.DatabaseRepository + subTableRepo *repository.SubTableRepository +} + +func NewDatabaseService(repo *repository.DatabaseRepository, subTableRepo *repository.SubTableRepository) *DatabaseService { + return &DatabaseService{ + repo: repo, + subTableRepo: subTableRepo, + } +} + +// TestConnection 测试数据库连通性 +func (s *DatabaseService) TestConnection(info *model.DatabaseInfo) error { + log.Printf("[数据库连接测试] 开始测试连接: 类型=%s, 主机=%s, 端口=%d, 数据库=%s, 用户=%s", + info.DBType, info.Host, info.Port, info.Database, info.Username) + + // 统一转换为小写处理 + dbType := strings.ToLower(info.DBType) + + // 构建连接字符串 + dsn := s.buildDSN(info) + log.Printf("[数据库连接测试] DSN构建完成: %s", dsn) + + // 设置超时 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 根据数据库类型连接 + var db *sql.DB + var err error + + switch dbType { + case "mysql": + db, err = sql.Open("mysql", dsn) + case "postgres", "postgresql": + db, err = sql.Open("postgres", dsn) + default: + errMsg := fmt.Sprintf("unsupported database type: %s", info.DBType) + log.Printf("[数据库连接测试] 错误: %s", errMsg) + return fmt.Errorf(errMsg) + } + + if err != nil { + errMsg := fmt.Sprintf("failed to create connection: %v", err) + log.Printf("[数据库连接测试] 错误: %s", errMsg) + return fmt.Errorf(errMsg) + } + defer db.Close() + + // 测试连接 + if err := db.PingContext(ctx); err != nil { + errMsg := fmt.Sprintf("cannot connect to database: %v", err) + log.Printf("[数据库连接测试] 连接失败: %s", errMsg) + return fmt.Errorf(errMsg) + } + + log.Printf("[数据库连接测试] 连接成功!") + return nil +} + +// buildDSN 构建数据库连接字符串 +func (s *DatabaseService) buildDSN(info *model.DatabaseInfo) string { + dbType := strings.ToLower(info.DBType) + switch dbType { + case "mysql": + charset := info.Charset + if charset == "" { + charset = "utf8mb4" + } + // 如果没有指定数据库名,只测试连接 + dbName := info.Database + if dbName == "" { + dbName = "mysql" + } + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&timeout=5s&parseTime=True", + info.Username, + info.Password, + info.Host, + info.Port, + dbName, + charset, + ) + case "postgres", "postgresql": + sslmode := "disable" + if info.SSLMode != "" { + sslmode = info.SSLMode + } + return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5", + info.Host, + info.Port, + info.Username, + info.Password, + info.Database, + sslmode, + ) + default: + return "" + } +} + +// getConnection 获取数据库连接 +func (s *DatabaseService) getConnection(info *model.DatabaseInfo) (*sql.DB, error) { + dsn := s.buildDSN(info) + dbType := strings.ToLower(info.DBType) + + var db *sql.DB + var err error + + switch dbType { + case "mysql": + db, err = sql.Open("mysql", dsn) + case "postgres", "postgresql": + db, err = sql.Open("postgres", dsn) + default: + return nil, fmt.Errorf("unsupported database type: %s", dbType) + } + + if err != nil { + return nil, err + } + + if err := db.Ping(); err != nil { + return nil, err + } + + return db, nil +} + +// getTableDDL 获取表的 DDL +func (s *DatabaseService) getTableDDL(db *sql.DB, dbType, tableName string) (string, error) { + switch dbType { + case "mysql": + query := fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName) + row := db.QueryRow(query) + var tblName, createStmt string + if err := row.Scan(&tblName, &createStmt); err != nil { + return "", err + } + return createStmt, nil + case "postgres", "postgresql": + query := fmt.Sprintf("SELECT pg_get_create('%s')", tableName) + var ddl string + if err := db.QueryRow(query).Scan(&ddl); err != nil { + return "", err + } + return ddl, nil + default: + return "", fmt.Errorf("unsupported database type: %s", dbType) + } +} + +// buildMappedDDL 根据字段映射生成带 COMMENT 的 DDL +func (s *DatabaseService) buildMappedDDL(originalDDL string, fields []model.FieldMapping) string { + // 构建列名到映射名的映射 + columnMap := make(map[string]string) + for _, f := range fields { + if f.MappedName != "" { + columnMap[f.ColumnName] = f.MappedName + } + } + + if len(columnMap) == 0 { + return originalDDL + } + + // 解析原始 DDL,为有映射的列添加 COMMENT + lines := strings.Split(originalDDL, "\n") + var resultLines []string + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // 检查是否是列定义行(以 ` 开头,包含数据类型) + if strings.HasPrefix(trimmed, "`") { + // 提取列名 + parts := strings.SplitN(trimmed, " ", 2) + if len(parts) >= 1 { + colName := strings.Trim(parts[0], "`") + + // 检查是否有映射 + if mappedName, ok := columnMap[colName]; ok { + // 去掉结尾的逗号(如果有) + trimmed = strings.TrimRight(trimmed, ",") + // 检查是否已经有 COMMENT + if strings.Contains(strings.ToUpper(trimmed), "COMMENT") { + // 替换已有的 COMMENT + trimmed = strings.TrimSuffix(trimmed, " COMMENT '...'") + trimmed = fmt.Sprintf("%s COMMENT '%s'", trimmed, mappedName) + } else { + // 在末尾添加 COMMENT + trimmed = fmt.Sprintf("%s COMMENT '%s'", trimmed, mappedName) + } + // 替换原始行为修改后的行 + resultLines = append(resultLines, trimmed) + continue + } + } + } + resultLines = append(resultLines, line) + } + + return strings.Join(resultLines, "\n") +} + +// Check 检查数据库连接 +func (s *DatabaseService) Check(req model.CheckRequest) (*model.CheckResponse, error) { + log.Printf("[Check] 开始检查连接: 类型=%s, 主机=%s, 端口=%d, 数据库=%s, 用户=%s", + req.DBType, req.Host, req.Port, req.Database, req.Username) + + info := &model.DatabaseInfo{ + DBType: req.DBType, + Host: req.Host, + Port: req.Port, + Username: req.Username, + Password: req.Password, + Database: req.Database, + Charset: req.Charset, + SSLMode: req.SSLMode, + } + + if info.Charset == "" { + info.Charset = "utf8mb4" + } + + // 构建连接 + dsn := s.buildDSN(info) + dbType := strings.ToLower(info.DBType) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var db *sql.DB + var err error + + switch dbType { + case "mysql": + db, err = sql.Open("mysql", dsn) + case "postgres", "postgresql": + db, err = sql.Open("postgres", dsn) + default: + return &model.CheckResponse{ + Success: false, + Message: fmt.Sprintf("unsupported database type: %s", req.DBType), + }, nil + } + + if err != nil { + return &model.CheckResponse{ + Success: false, + Message: fmt.Sprintf("failed to create connection: %v", err), + }, nil + } + defer db.Close() + + // 测试连接 + if err := db.PingContext(ctx); err != nil { + log.Printf("[Check] 连接失败: %v", err) + return &model.CheckResponse{ + Success: false, + Message: fmt.Sprintf("cannot connect to database: %v", err), + }, nil + } + + log.Printf("[Check] 连接成功,开始获取表列表...") + + // 获取表列表 + var tables []model.TableDDLInfo + switch dbType { + case "mysql": + tables, _ = s.getMySQLTables(db, req.Database) + case "postgres", "postgresql": + tables, _ = s.getPostgresTables(db, req.Database) + } + + log.Printf("[Check] 获取到 %d 个表", len(tables)) + + // 如果传入了 database_id,获取已保存的字段映射和 DDL 并填充到表结构中 + if req.DatabaseID != "" && s.subTableRepo != nil { + s.fillFieldMappings(req.DatabaseID, tables) + s.fillDDL(req.DatabaseID, tables) + } + + return &model.CheckResponse{ + Success: true, + Message: "connection successful", + Tables: tables, + Database: req.Database, + }, nil +} + +// getMySQLTables 获取MySQL表结构 +func (s *DatabaseService) getMySQLTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) { + rows, err := db.Query(` + SELECT TABLE_NAME, TABLE_COMMENT + FROM information_schema.TABLES + WHERE TABLE_SCHEMA = ? + AND TABLE_TYPE = 'BASE TABLE' + `, dbName) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []model.TableDDLInfo + for rows.Next() { + var tableName, tableComment string + if err := rows.Scan(&tableName, &tableComment); err != nil { + continue + } + + table := model.TableDDLInfo{ + TableName: tableName, + TableComment: tableComment, + } + + // 获取列信息 + table.Columns, _ = s.getMySQLColumns(db, dbName, tableName) + + // 获取 DDL + table.DDL, _ = s.getMySQLDDL(db, tableName) + + tables = append(tables, table) + } + + return tables, nil +} + +// getMySQLDDL 获取 MySQL 表的 DDL +func (s *DatabaseService) getMySQLDDL(db *sql.DB, tableName string) (string, error) { + // 使用反引号包裹表名,防止关键字冲突 + query := fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName) + row := db.QueryRow(query) + var tblName, createStmt string + if err := row.Scan(&tblName, &createStmt); err != nil { + log.Printf("[getMySQLDDL] 获取 DDL 失败: %v", err) + return "", nil + } + return createStmt, nil +} + +// getMySQLColumns 获取MySQL列信息 +func (s *DatabaseService) getMySQLColumns(db *sql.DB, dbName, tableName string) ([]model.ColumnInfo, error) { + rows, err := db.Query(` + SELECT + COLUMN_NAME, DATA_TYPE, COLUMN_TYPE, + IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY, + EXTRA, COLUMN_COMMENT + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? + ORDER BY ORDINAL_POSITION + `, dbName, tableName) + if err != nil { + log.Printf("[getMySQLColumns] 查询列信息失败: %v", err) + return nil, err + } + defer rows.Close() + + columns := make([]model.ColumnInfo, 0) + for rows.Next() { + var col model.ColumnInfo + var defaultValue, extra, columnComment sql.NullString + if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType, + &col.IsNullable, &defaultValue, &col.ColumnKey, &extra, &columnComment); err != nil { + log.Printf("[getMySQLColumns] Scan 失败: %v", err) + continue + } + col.DefaultValue = defaultValue.String + col.Extra = extra.String + col.ColumnComment = columnComment.String + columns = append(columns, col) + } + + // 检查是否有迭代错误 + if err := rows.Err(); err != nil { + log.Printf("[getMySQLColumns] 迭代错误: %v", err) + } + + return columns, nil +} + +// getPostgresTables 获取PostgreSQL表结构 +func (s *DatabaseService) getPostgresTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) { + rows, err := db.Query(` + SELECT t.table_name, obj_description((t.table_schema || '.' || t.table_name)::regclass) + FROM information_schema.tables t + WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE' + `, dbName) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []model.TableDDLInfo + for rows.Next() { + var tableName, tableComment string + if err := rows.Scan(&tableName, &tableComment); err != nil { + continue + } + + table := model.TableDDLInfo{ + TableName: tableName, + TableComment: tableComment, + } + + // 获取列信息 + table.Columns, _ = s.getPostgresColumns(db, tableName) + + // 获取 DDL + table.DDL, _ = s.getPostgresDDL(db, tableName) + + tables = append(tables, table) + } + + return tables, nil +} + +// getPostgresDDL 获取 PostgreSQL 表的 DDL +func (s *DatabaseService) getPostgresDDL(db *sql.DB, tableName string) (string, error) { + var ddl string + query := fmt.Sprintf("SELECT pg_get_create('%s')", tableName) + row := db.QueryRow(query) + if err := row.Scan(&ddl); err != nil { + log.Printf("[getPostgresDDL] 获取 DDL 失败: %v", err) + return "", nil + } + return ddl, nil +} + +// getPostgresColumns 获取PostgreSQL列信息 +func (s *DatabaseService) getPostgresColumns(db *sql.DB, tableName string) ([]model.ColumnInfo, error) { + rows, err := db.Query(` + SELECT + c.column_name, c.data_type, c.udt_name, + c.is_nullable, c.column_default, c.column_name, + '', c.column_comment + FROM information_schema.columns c + WHERE c.table_name = $1 AND c.table_schema = 'public' + ORDER BY c.ordinal_position + `, tableName) + if err != nil { + return nil, err + } + defer rows.Close() + + var columns []model.ColumnInfo + for rows.Next() { + var col model.ColumnInfo + if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType, + &col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil { + continue + } + columns = append(columns, col) + } + + return columns, nil +} + +// Create 创建数据库信息(支持同时保存子表配置) +func (s *DatabaseService) Create(req model.CreateDatabaseRequest) (*model.DatabaseInfo, error) { + log.Printf("[Create] 收到创建请求: %+v", req) + + info := &model.DatabaseInfo{ + ID: uuid.New().String(), + Name: req.Name, + Description: req.Description, + DBType: strings.ToLower(req.DBType), // 统一转为小写 + Host: req.Host, + Port: req.Port, + Username: req.Username, + Password: req.Password, + Database: req.Database, + Charset: req.Charset, + SSLMode: req.SSLMode, + TableCount: len(req.SubTables), + } + + // 默认值 + if info.Charset == "" { + info.Charset = "utf8mb4" + } + + // 测试数据库连通性 + log.Printf("[Create] 开始测试数据库连接...") + if err := s.TestConnection(info); err != nil { + log.Printf("[Create] 数据库连接测试失败: %v", err) + return nil, fmt.Errorf("database connection failed: %v", err) + } + log.Printf("[Create] 数据库连接测试成功!") + + // 保存数据库信息 + if err := s.repo.Create(info); err != nil { + log.Printf("[Create] 保存数据库失败: %v", err) + return nil, err + } + + // 保存子表配置(如有) + if len(req.SubTables) > 0 && s.subTableRepo != nil { + log.Printf("[Create] 保存 %d 个子表配置", len(req.SubTables)) + + // 获取数据库连接用于查询 DDL + db, err := s.getConnection(info) + if err != nil { + log.Printf("[Create] 获取数据库连接失败: %v", err) + } else { + defer db.Close() + } + + for _, subReq := range req.SubTables { + subTable := &model.SubTableInfo{ + ID: uuid.New().String(), + DatabaseID: info.ID, + ParentTable: subReq.ParentTable, + SubTableName: subReq.SubTableName, + SubTableComment: subReq.SubTableComment, + MappingType: subReq.MappingType, + RelationField: subReq.RelationField, + RelationType: subReq.RelationType, + } + // 使用 SetFields 方法保存字段映射 + subTable.SetFields(subReq.Fields) + + // 获取并保存 DDL + if db != nil { + ddl, err := s.getTableDDL(db, strings.ToLower(info.DBType), subReq.ParentTable) + if err != nil { + log.Printf("[Create] 获取原始 DDL 失败: %v", err) + } else { + // 如果有字段映射,生成带 COMMENT 的新 DDL + if len(subReq.Fields) > 0 { + subTable.DDL = s.buildMappedDDL(ddl, subReq.Fields) + log.Printf("[Create] 生成映射后的 DDL,长度: %d", len(subTable.DDL)) + } else { + subTable.DDL = ddl + } + } + } + + if err := s.subTableRepo.Create(subTable); err != nil { + log.Printf("[Create] 保存子表失败: %v", err) + } + } + + // 同步到文件 + s.syncSubTablesToFile(info) + } + + log.Printf("[Create] 创建成功, ID=%s", info.ID) + return info, nil +} + +// syncSubTablesToFile 同步子表到文件 +func (s *DatabaseService) syncSubTablesToFile(info *model.DatabaseInfo) { + if s.subTableRepo == nil { + return + } + + tables, err := s.subTableRepo.FindByDatabaseID(info.ID) + if err != nil { + log.Printf("[syncSubTablesToFile] 查询子表失败: %v", err) + return + } + + mapping := &model.SubTableMapping{ + DatabaseID: info.ID, + DatabaseName: info.Name, + DBType: info.DBType, + Tables: tables, + UpdatedAt: time.Now(), + } + + resourceDir := "resources/db_info" + os.MkdirAll(resourceDir, 0755) + + data, err := json.MarshalIndent(mapping, "", " ") + if err != nil { + log.Printf("[syncSubTablesToFile] 序列化失败: %v", err) + return + } + + filePath := fmt.Sprintf("%s/%s.json", resourceDir, info.ID) + if err := os.WriteFile(filePath, data, 0644); err != nil { + log.Printf("[syncSubTablesToFile] 写入文件失败: %v", err) + } + + log.Printf("[syncSubTablesToFile] 同步成功: %s", filePath) +} + +// GetByID 获取详情 +func (s *DatabaseService) GetByID(id string) (*model.DatabaseInfo, error) { + log.Printf("[GetByID] 查询 ID=%s", id) + info, err := s.repo.FindByID(id) + if err != nil { + log.Printf("[GetByID] 查询失败: %v", err) + return nil, ErrDatabaseNotFound + } + return info, nil +} + +// List 获取列表 +func (s *DatabaseService) List() ([]model.DatabaseInfo, error) { + log.Printf("[List] 查询所有数据库列表") + return s.repo.FindAll() +} + +// Update 更新 +func (s *DatabaseService) Update(id string, req model.UpdateDatabaseRequest) (*model.DatabaseInfo, error) { + log.Printf("[Update] 更新 ID=%s, 数据=%+v", id, req) + // 检查是否存在 + _, err := s.repo.FindByID(id) + if err != nil { + log.Printf("[Update] 不存在: %v", err) + return nil, ErrDatabaseNotFound + } + + // 构建更新数据 + updates := map[string]interface{}{} + if req.Name != "" { + updates["name"] = req.Name + } + if req.Description != "" { + updates["description"] = req.Description + } + if req.DBType != "" { + updates["db_type"] = req.DBType + } + if req.Host != "" { + updates["host"] = req.Host + } + if req.Port > 0 { + updates["port"] = req.Port + } + if req.Username != "" { + updates["username"] = req.Username + } + if req.Password != "" { + updates["password"] = req.Password + } + if req.Database != "" { + updates["database"] = req.Database + } + if req.TableCount > 0 { + updates["table_count"] = req.TableCount + } + if req.Charset != "" { + updates["charset"] = req.Charset + } + if req.SSLMode != "" { + updates["ssl_mode"] = req.SSLMode + } + + info := &model.DatabaseInfo{} + if err := s.repo.Update(id, info); err != nil { + log.Printf("[Update] 更新失败: %v", err) + return nil, err + } + + return s.repo.FindByID(id) +} + +// fillFieldMappings 填充字段映射到表结构中 +func (s *DatabaseService) fillFieldMappings(databaseID string, tables []model.TableDDLInfo) { + // 从数据库中获取该数据库下所有子表的字段映射 + subTables, err := s.subTableRepo.FindByDatabaseID(databaseID) + if err != nil { + log.Printf("[fillFieldMappings] 查询子表失败: %v", err) + return + } + + // 构建表名到字段映射的映射 + tableFieldsMap := make(map[string][]model.FieldMapping) + for _, st := range subTables { + fields := st.GetFields() + if len(fields) > 0 { + tableFieldsMap[st.ParentTable] = fields + } + } + + // 遍历返回的表结构,填充字段映射 + for i := range tables { + tableName := tables[i].TableName + if fields, ok := tableFieldsMap[tableName]; ok { + // 构建列名到映射名的映射 + columnMap := make(map[string]string) + for _, f := range fields { + columnMap[f.ColumnName] = f.MappedName + } + + // 填充到每个列 + for j := range tables[i].Columns { + colName := tables[i].Columns[j].ColumnName + if mappedName, ok := columnMap[colName]; ok { + tables[i].Columns[j].MappedName = mappedName + } + } + } + } + + log.Printf("[fillFieldMappings] 已填充字段映射到 %d 个表", len(tables)) +} + +// fillDDL 填充已保存的 DDL 到表结构中 +func (s *DatabaseService) fillDDL(databaseID string, tables []model.TableDDLInfo) { + // 从数据库中获取该数据库下所有子表的 DDL + subTables, err := s.subTableRepo.FindByDatabaseID(databaseID) + if err != nil { + log.Printf("[fillDDL] 查询子表失败: %v", err) + return + } + + // 构建表名到 DDL 的映射 + tableDDLMap := make(map[string]string) + for _, st := range subTables { + if st.DDL != "" { + tableDDLMap[st.ParentTable] = st.DDL + } + } + + // 遍历返回的表结构,填充 DDL + for i := range tables { + tableName := tables[i].TableName + if ddl, ok := tableDDLMap[tableName]; ok { + tables[i].DDL = ddl + } + } + + log.Printf("[fillDDL] 已填充 DDL 到 %d 个表", len(tables)) +} + +// Delete 删除 +func (s *DatabaseService) Delete(id string) error { + log.Printf("[Delete] 删除 ID=%s", id) + _, err := s.repo.FindByID(id) + if err != nil { + log.Printf("[Delete] 不存在: %v", err) + return ErrDatabaseNotFound + } + return s.repo.Delete(id) +} diff --git a/server/internal/service/sub_table_service.go b/server/internal/service/sub_table_service.go new file mode 100644 index 0000000..c791528 --- /dev/null +++ b/server/internal/service/sub_table_service.go @@ -0,0 +1,602 @@ +package service + +import ( + "database/sql" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + + "x-agents/server/internal/model" + "x-agents/server/internal/repository" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "github.com/google/uuid" +) + +type SubTableService struct { + repo *repository.SubTableRepository + dbRepo *repository.DatabaseRepository + resourceDir string +} + +func NewSubTableService(repo *repository.SubTableRepository, dbRepo *repository.DatabaseRepository) *SubTableService { + return &SubTableService{ + repo: repo, + dbRepo: dbRepo, + resourceDir: "resources/db_info", + } +} + +// ensureDir 确保目录存在 +func (s *SubTableService) ensureDir() error { + return os.MkdirAll(s.resourceDir, 0755) +} + +// getFilePath 获取文件路径 +func (s *SubTableService) getFilePath(databaseID string) string { + return filepath.Join(s.resourceDir, fmt.Sprintf("%s.json", databaseID)) +} + +// saveToFile 保存到文件 +func (s *SubTableService) saveToFile(databaseID string, mapping *model.SubTableMapping) error { + if err := s.ensureDir(); err != nil { + return err + } + + data, err := json.MarshalIndent(mapping, "", " ") + if err != nil { + return err + } + + return os.WriteFile(s.getFilePath(databaseID), data, 0644) +} + +// loadFromFile 从文件加载 +func (s *SubTableService) loadFromFile(databaseID string) (*model.SubTableMapping, error) { + filePath := s.getFilePath(databaseID) + data, err := os.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + var mapping model.SubTableMapping + if err := json.Unmarshal(data, &mapping); err != nil { + return nil, err + } + return &mapping, nil +} + +// syncToFile 同步到文件 +func (s *SubTableService) syncToFile(databaseID string) error { + // 获取数据库信息 + dbInfo, err := s.dbRepo.FindByID(databaseID) + if err != nil { + return err + } + + // 获取所有子表信息 + tables, err := s.repo.FindByDatabaseID(databaseID) + if err != nil { + return err + } + + mapping := &model.SubTableMapping{ + DatabaseID: databaseID, + DatabaseName: dbInfo.Name, + DBType: dbInfo.DBType, + Tables: tables, + UpdatedAt: time.Now(), + } + + return s.saveToFile(databaseID, mapping) +} + +// Create 创建子表信息 +func (s *SubTableService) Create(req model.CreateSubTableRequest) (*model.SubTableInfo, error) { + log.Printf("[SubTable Create] 收到请求: %+v", req) + + // 验证数据库是否存在 + _, err := s.dbRepo.FindByID(req.DatabaseID) + if err != nil { + log.Printf("[SubTable Create] 数据库不存在: %v", err) + return nil, fmt.Errorf("database not found") + } + + info := &model.SubTableInfo{ + ID: uuid.New().String(), + DatabaseID: req.DatabaseID, + ParentTable: req.ParentTable, + SubTableName: req.SubTableName, + SubTableComment: req.SubTableComment, + MappingType: req.MappingType, + RelationField: req.RelationField, + RelationType: req.RelationType, + } + + if err := s.repo.Create(info); err != nil { + log.Printf("[SubTable Create] 创建失败: %v", err) + return nil, err + } + + // 同步到文件 + if err := s.syncToFile(req.DatabaseID); err != nil { + log.Printf("[SubTable Create] 同步文件失败: %v", err) + } + + log.Printf("[SubTable Create] 创建成功, ID=%s", info.ID) + return info, nil +} + +// GetByID 获取详情 +func (s *SubTableService) GetByID(id string) (*model.SubTableInfo, error) { + log.Printf("[SubTable GetByID] 查询 ID=%s", id) + info, err := s.repo.FindByID(id) + if err != nil { + log.Printf("[SubTable GetByID] 查询失败: %v", err) + return nil, fmt.Errorf("sub table not found") + } + return info, nil +} + +// ListByDatabaseID 获取数据库下所有子表 +func (s *SubTableService) ListByDatabaseID(databaseID string) ([]model.SubTableInfo, error) { + log.Printf("[SubTable ListByDatabaseID] 查询数据库ID=%s", databaseID) + tables, err := s.repo.FindByDatabaseID(databaseID) + if err != nil { + return nil, err + } + + // 填充 FieldsList 字段 + for i := range tables { + tables[i].FieldsList = tables[i].GetFields() + } + + return tables, nil +} + +// GetMappingFromFile 从文件获取映射信息 +func (s *SubTableService) GetMappingFromFile(databaseID string) (*model.SubTableMapping, error) { + log.Printf("[SubTable GetMappingFromFile] 读取文件, databaseID=%s", databaseID) + return s.loadFromFile(databaseID) +} + +// Update 更新 +func (s *SubTableService) Update(id string, req model.UpdateSubTableRequest) (*model.SubTableInfo, error) { + log.Printf("[SubTable Update] 更新 ID=%s, 数据=%+v", id, req) + + info, err := s.repo.FindByID(id) + if err != nil { + log.Printf("[SubTable Update] 不存在: %v", err) + return nil, fmt.Errorf("sub table not found") + } + + if req.ParentTable != "" { + info.ParentTable = req.ParentTable + } + if req.SubTableName != "" { + info.SubTableName = req.SubTableName + } + if req.SubTableComment != "" { + info.SubTableComment = req.SubTableComment + } + if req.MappingType != "" { + info.MappingType = req.MappingType + } + if req.RelationField != "" { + info.RelationField = req.RelationField + } + if req.RelationType != "" { + info.RelationType = req.RelationType + } + + if err := s.repo.Update(id, info); err != nil { + log.Printf("[SubTable Update] 更新失败: %v", err) + return nil, err + } + + // 同步到文件 + if err := s.syncToFile(info.DatabaseID); err != nil { + log.Printf("[SubTable Update] 同步文件失败: %v", err) + } + + return info, nil +} + +// Delete 删除 +func (s *SubTableService) Delete(id string) error { + log.Printf("[SubTable Delete] 删除 ID=%s", id) + + info, err := s.repo.FindByID(id) + if err != nil { + log.Printf("[SubTable Delete] 不存在: %v", err) + return fmt.Errorf("sub table not found") + } + + databaseID := info.DatabaseID + + if err := s.repo.Delete(id); err != nil { + log.Printf("[SubTable Delete] 删除失败: %v", err) + return err + } + + // 同步到文件 + if err := s.syncToFile(databaseID); err != nil { + log.Printf("[SubTable Delete] 同步文件失败: %v", err) + } + + return nil +} + +// GetTableDDLFromDatabase 从实际数据库获取表结构和DDL +func (s *SubTableService) GetTableDDLFromDatabase(databaseID string) ([]model.TableDDLInfo, error) { + log.Printf("[GetTableDDLFromDatabase] 获取数据库ID=%s 的表结构", databaseID) + + // 获取数据库连接信息 + dbInfo, err := s.dbRepo.FindByID(databaseID) + if err != nil { + log.Printf("[GetTableDDLFromDatabase] 数据库不存在: %v", err) + return nil, fmt.Errorf("database not found") + } + + // 构建连接 + dsn := s.buildDSN(dbInfo) + dbType := strings.ToLower(dbInfo.DBType) + + var db *sql.DB + switch dbType { + case "mysql": + db, err = sql.Open("mysql", dsn) + case "postgres", "postgresql": + db, err = sql.Open("postgres", dsn) + default: + return nil, fmt.Errorf("unsupported database type: %s", dbInfo.DBType) + } + if err != nil { + return nil, fmt.Errorf("failed to connect: %v", err) + } + defer db.Close() + + // 获取所有表 + var tables []model.TableDDLInfo + switch dbType { + case "mysql": + tables, err = s.getMySQLTables(db, dbInfo.Database) + case "postgres", "postgresql": + tables, err = s.getPostgresTables(db, dbInfo.Database) + } + if err != nil { + return nil, err + } + + log.Printf("[GetTableDDLFromDatabase] 获取到 %d 个表", len(tables)) + return tables, nil +} + +// buildDSN 构建数据库连接字符串 +func (s *SubTableService) buildDSN(info *model.DatabaseInfo) string { + dbType := strings.ToLower(info.DBType) + switch dbType { + case "mysql": + charset := info.Charset + if charset == "" { + charset = "utf8mb4" + } + dbName := info.Database + if dbName == "" { + dbName = "mysql" + } + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&timeout=5s&parseTime=True", + info.Username, info.Password, info.Host, info.Port, dbName, charset) + case "postgres", "postgresql": + sslmode := "disable" + if info.SSLMode != "" { + sslmode = info.SSLMode + } + return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5", + info.Host, info.Port, info.Username, info.Password, info.Database, sslmode) + } + return "" +} + +// getMySQLTables 获取MySQL表结构 +func (s *SubTableService) getMySQLTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) { + // 获取所有表名和注释 + rows, err := db.Query(` + SELECT TABLE_NAME, TABLE_COMMENT + FROM information_schema.TABLES + WHERE TABLE_SCHEMA = ? + AND TABLE_TYPE = 'BASE TABLE' + `, dbName) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []model.TableDDLInfo + for rows.Next() { + var tableName, tableComment string + if err := rows.Scan(&tableName, &tableComment); err != nil { + continue + } + + table := model.TableDDLInfo{ + TableName: tableName, + TableComment: tableComment, + } + + // 获取列信息 + table.Columns, _ = s.getMySQLColumns(db, dbName, tableName) + + // 获取索引信息 + table.Indexes, _ = s.getMySQLIndexes(db, dbName, tableName) + + // 生成DDL + table.DDL = s.generateMySQLDDL(table) + + tables = append(tables, table) + } + + return tables, nil +} + +// getMySQLColumns 获取MySQL列信息 +func (s *SubTableService) getMySQLColumns(db *sql.DB, dbName, tableName string) ([]model.ColumnInfo, error) { + rows, err := db.Query(` + SELECT + COLUMN_NAME, DATA_TYPE, COLUMN_TYPE, + IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY, + EXTRA, COLUMN_COMMENT + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? + ORDER BY ORDINAL_POSITION + `, dbName, tableName) + if err != nil { + return nil, err + } + defer rows.Close() + + var columns []model.ColumnInfo + for rows.Next() { + var col model.ColumnInfo + if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType, + &col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil { + continue + } + columns = append(columns, col) + } + + return columns, nil +} + +// getMySQLIndexes 获取MySQL索引信息 +func (s *SubTableService) getMySQLIndexes(db *sql.DB, dbName, tableName string) ([]model.IndexInfo, error) { + rows, err := db.Query(` + SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE, INDEX_TYPE + FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? + ORDER BY SEQ_IN_INDEX + `, dbName, tableName) + if err != nil { + return nil, err + } + defer rows.Close() + + var indexes []model.IndexInfo + for rows.Next() { + var idx model.IndexInfo + if err := rows.Scan(&idx.IndexName, &idx.ColumnName, &idx.NonUnique, &idx.IndexType); err != nil { + continue + } + indexes = append(indexes, idx) + } + + return indexes, nil +} + +// generateMySQLDDL 生成MySQL DDL +func (s *SubTableService) generateMySQLDDL(table model.TableDDLInfo) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("CREATE TABLE `%s` (\n", table.TableName)) + + for i, col := range table.Columns { + sb.WriteString(fmt.Sprintf(" `%s` %s", col.ColumnName, col.ColumnType)) + + if col.IsNullable == "NO" { + sb.WriteString(" NOT NULL") + } + if col.DefaultValue != "" { + sb.WriteString(fmt.Sprintf(" DEFAULT %s", col.DefaultValue)) + } + if col.Extra == "auto_increment" { + sb.WriteString(" AUTO_INCREMENT") + } + if col.ColumnComment != "" { + sb.WriteString(fmt.Sprintf(" COMMENT '%s'", col.ColumnComment)) + } + + if i < len(table.Columns)-1 { + sb.WriteString(",") + } + sb.WriteString("\n") + } + + // 添加主键 + var primaryKeys []string + for _, idx := range table.Indexes { + if idx.IndexName == "PRIMARY" { + primaryKeys = append(primaryKeys, fmt.Sprintf("`%s`", idx.ColumnName)) + } + } + if len(primaryKeys) > 0 { + sb.WriteString(fmt.Sprintf(" PRIMARY KEY (%s)\n", strings.Join(primaryKeys, ", "))) + } + + // 添加索引 + var addedIndexes []string + for _, idx := range table.Indexes { + if idx.IndexName != "PRIMARY" { + unique := "" + if idx.NonUnique == 0 { + unique = "UNIQUE " + } + if !contains(addedIndexes, idx.IndexName) { + sb.WriteString(fmt.Sprintf(" %sKEY `%s` (`%s`),\n", unique, idx.IndexName, idx.ColumnName)) + addedIndexes = append(addedIndexes, idx.IndexName) + } + } + } + + ddl := sb.String() + ddl = strings.TrimSuffix(ddl, ",\n") + ddl += "\n)" + + if table.TableComment != "" { + ddl += fmt.Sprintf(" COMMENT='%s'", table.TableComment) + } + ddl += ";\n" + + return ddl +} + +// contains 检查切片是否包含元素 +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// getPostgresTables 获取PostgreSQL表结构 +func (s *SubTableService) getPostgresTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) { + rows, err := db.Query(` + SELECT t.table_name, obj_description((t.table_schema || '.' || t.table_name)::regclass) + FROM information_schema.tables t + WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE' + `, dbName) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []model.TableDDLInfo + for rows.Next() { + var tableName, tableComment string + if err := rows.Scan(&tableName, &tableComment); err != nil { + continue + } + + table := model.TableDDLInfo{ + TableName: tableName, + TableComment: tableComment, + } + + // 获取列信息 + table.Columns, _ = s.getPostgresColumns(db, tableName) + + // 获取索引信息 + table.Indexes, _ = s.getPostgresIndexes(db, tableName) + + // 生成DDL + table.DDL = s.generatePostgresDDL(table) + + tables = append(tables, table) + } + + return tables, nil +} + +// getPostgresColumns 获取PostgreSQL列信息 +func (s *SubTableService) getPostgresColumns(db *sql.DB, tableName string) ([]model.ColumnInfo, error) { + rows, err := db.Query(` + SELECT + c.column_name, c.data_type, c.udt_name, + c.is_nullable, c.column_default, c.column_name, + '', c.column_comment + FROM information_schema.columns c + WHERE c.table_name = $1 AND c.table_schema = 'public' + ORDER BY c.ordinal_position + `, tableName) + if err != nil { + return nil, err + } + defer rows.Close() + + var columns []model.ColumnInfo + for rows.Next() { + var col model.ColumnInfo + if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType, + &col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil { + continue + } + columns = append(columns, col) + } + + return columns, nil +} + +// getPostgresIndexes 获取PostgreSQL索引信息 +func (s *SubTableService) getPostgresIndexes(db *sql.DB, tableName string) ([]model.IndexInfo, error) { + rows, err := db.Query(` + SELECT indexname, indexdef + FROM pg_indexes + WHERE tablename = $1 AND schemaname = 'public' + `, tableName) + if err != nil { + return nil, err + } + defer rows.Close() + + var indexes []model.IndexInfo + for rows.Next() { + var idx model.IndexInfo + var indexDef string + if err := rows.Scan(&idx.IndexName, &indexDef); err != nil { + continue + } + idx.NonUnique = 1 + if strings.Contains(indexDef, "UNIQUE") { + idx.NonUnique = 0 + } + indexes = append(indexes, idx) + } + + return indexes, nil +} + +// generatePostgresDDL 生成PostgreSQL DDL +func (s *SubTableService) generatePostgresDDL(table model.TableDDLInfo) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("CREATE TABLE %s (\n", table.TableName)) + + for i, col := range table.Columns { + sb.WriteString(fmt.Sprintf(" %s %s", col.ColumnName, col.ColumnType)) + + if col.IsNullable == "NO" { + sb.WriteString(" NOT NULL") + } + if col.DefaultValue != "" { + sb.WriteString(fmt.Sprintf(" DEFAULT %s", col.DefaultValue)) + } + + if i < len(table.Columns)-1 { + sb.WriteString(",") + } + sb.WriteString("\n") + } + + ddl := sb.String() + ddl = strings.TrimSuffix(ddl, ",\n") + ddl += "\n);" + + return ddl +} diff --git a/server/main.exe b/server/main.exe new file mode 100644 index 0000000..94115af Binary files /dev/null and b/server/main.exe differ diff --git a/server/resources/db_info/053388bf-d0c3-4cd9-b78f-539858705a65.json b/server/resources/db_info/053388bf-d0c3-4cd9-b78f-539858705a65.json new file mode 100644 index 0000000..8816488 --- /dev/null +++ b/server/resources/db_info/053388bf-d0c3-4cd9-b78f-539858705a65.json @@ -0,0 +1,7 @@ +{ + "database_id": "053388bf-d0c3-4cd9-b78f-539858705a65", + "database_name": "test-db", + "db_type": "mysql", + "tables": [], + "updated_at": "2026-03-06T15:46:22.8598923+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/101fbee1-8400-46ae-b83b-e3898e4888b6.json b/server/resources/db_info/101fbee1-8400-46ae-b83b-e3898e4888b6.json new file mode 100644 index 0000000..570515e --- /dev/null +++ b/server/resources/db_info/101fbee1-8400-46ae-b83b-e3898e4888b6.json @@ -0,0 +1,22 @@ +{ + "database_id": "101fbee1-8400-46ae-b83b-e3898e4888b6", + "database_name": "123", + "db_type": "mysql", + "tables": [ + { + "id": "042db4ca-512f-4ee9-aacb-2d7ff1bc2193", + "database_id": "101fbee1-8400-46ae-b83b-e3898e4888b6", + "parent_table": "scores", + "sub_table_name": "scores", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "fields": null, + "ddl": "CREATE TABLE `scores` (\n`id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '分数id'\n `student_id` int(10) unsigned NOT NULL,\n `subject` varchar(50) NOT NULL COMMENT '科目',\n `score` double DEFAULT NULL COMMENT '分数',\n `teacher_id` int(10) unsigned DEFAULT NULL,\n `exam_date` date DEFAULT NULL COMMENT '考试日期',\n `created_at` datetime DEFAULT CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB AUTO_INCREMENT=48 DEFAULT CHARSET=utf8mb4", + "created_at": "2026-03-06T16:23:31.097+08:00", + "updated_at": "2026-03-06T16:23:31.097+08:00" + } + ], + "updated_at": "2026-03-06T16:23:31.1477776+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/456b6a60-c5a5-46e4-8f5e-9c07c4c08510.json b/server/resources/db_info/456b6a60-c5a5-46e4-8f5e-9c07c4c08510.json new file mode 100644 index 0000000..883d481 --- /dev/null +++ b/server/resources/db_info/456b6a60-c5a5-46e4-8f5e-9c07c4c08510.json @@ -0,0 +1,20 @@ +{ + "database_id": "456b6a60-c5a5-46e4-8f5e-9c07c4c08510", + "database_name": "123", + "db_type": "mysql", + "tables": [ + { + "id": "8b7f6a2f-3788-4499-8d3a-fe9140ccdfe1", + "database_id": "456b6a60-c5a5-46e4-8f5e-9c07c4c08510", + "parent_table": "scores", + "sub_table_name": "scores", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "created_at": "2026-03-06T15:12:45.607+08:00", + "updated_at": "2026-03-06T15:12:45.607+08:00" + } + ], + "updated_at": "2026-03-06T15:12:45.6597943+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/58f7171d-6906-4f85-b27a-20bb2f982fc4.json b/server/resources/db_info/58f7171d-6906-4f85-b27a-20bb2f982fc4.json new file mode 100644 index 0000000..970dbed --- /dev/null +++ b/server/resources/db_info/58f7171d-6906-4f85-b27a-20bb2f982fc4.json @@ -0,0 +1,22 @@ +{ + "database_id": "58f7171d-6906-4f85-b27a-20bb2f982fc4", + "database_name": "123", + "db_type": "mysql", + "tables": [ + { + "id": "12298a11-fe00-4e6a-a37e-e4c0b5de6a51", + "database_id": "58f7171d-6906-4f85-b27a-20bb2f982fc4", + "parent_table": "scores", + "sub_table_name": "scores", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "fields": null, + "ddl": "CREATE TABLE `scores` (\n`id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'fenshu id'\n `student_id` int(10) unsigned NOT NULL,\n `subject` varchar(50) NOT NULL COMMENT '科目',\n `score` double DEFAULT NULL COMMENT '分数',\n `teacher_id` int(10) unsigned DEFAULT NULL,\n `exam_date` date DEFAULT NULL COMMENT '考试日期',\n `created_at` datetime DEFAULT CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB AUTO_INCREMENT=48 DEFAULT CHARSET=utf8mb4", + "created_at": "2026-03-06T16:26:15.44+08:00", + "updated_at": "2026-03-06T16:26:15.44+08:00" + } + ], + "updated_at": "2026-03-06T16:26:15.4936638+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/5eee8840-c268-4cf1-8f86-a0d13eaf9b16.json b/server/resources/db_info/5eee8840-c268-4cf1-8f86-a0d13eaf9b16.json new file mode 100644 index 0000000..bbb9b46 --- /dev/null +++ b/server/resources/db_info/5eee8840-c268-4cf1-8f86-a0d13eaf9b16.json @@ -0,0 +1,21 @@ +{ + "database_id": "5eee8840-c268-4cf1-8f86-a0d13eaf9b16", + "database_name": "test-db-3", + "db_type": "mysql", + "tables": [ + { + "id": "2a52c3a0-0019-4634-a4b3-3627a02153ba", + "database_id": "5eee8840-c268-4cf1-8f86-a0d13eaf9b16", + "parent_table": "database_info", + "sub_table_name": "DB��Ϣ", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "ddl": "CREATE TABLE `database_info` (\n`id` varchar(36) NOT NULL, COMMENT 'ID'\n`name` varchar(100) NOT NULL, COMMENT '����'\n `description` varchar(500) DEFAULT NULL,\n `db_type` varchar(20) NOT NULL,\n`host` varchar(255) NOT NULL, COMMENT '������ַ'\n `port` bigint NOT NULL,\n `username` varchar(100) NOT NULL,\n `password` varchar(255) DEFAULT NULL,\n `database` varchar(100) DEFAULT NULL,\n `table_count` bigint DEFAULT '0',\n `charset` varchar(20) DEFAULT 'utf8mb4',\n `ssl_mode` varchar(20) DEFAULT NULL,\n `created_at` datetime(3) DEFAULT NULL,\n `updated_at` datetime(3) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", + "created_at": "2026-03-06T15:57:24.011+08:00", + "updated_at": "2026-03-06T15:57:24.011+08:00" + } + ], + "updated_at": "2026-03-06T15:57:24.0628515+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/68b6fb60-eae2-495b-b248-9c46c8d8d6cb.json b/server/resources/db_info/68b6fb60-eae2-495b-b248-9c46c8d8d6cb.json new file mode 100644 index 0000000..76830f7 --- /dev/null +++ b/server/resources/db_info/68b6fb60-eae2-495b-b248-9c46c8d8d6cb.json @@ -0,0 +1,21 @@ +{ + "database_id": "68b6fb60-eae2-495b-b248-9c46c8d8d6cb", + "database_name": "test-db-4", + "db_type": "mysql", + "tables": [ + { + "id": "5107d64f-9949-4550-9030-e7e14585f080", + "database_id": "68b6fb60-eae2-495b-b248-9c46c8d8d6cb", + "parent_table": "database_info", + "sub_table_name": "DB��", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "ddl": "CREATE TABLE `database_info` (\n`id` varchar(36) NOT NULL COMMENT '����ID'\n`name` varchar(100) NOT NULL COMMENT '���ݿ���'\n `description` varchar(500) DEFAULT NULL,\n `db_type` varchar(20) NOT NULL,\n`host` varchar(255) NOT NULL COMMENT '������'\n `port` bigint NOT NULL,\n `username` varchar(100) NOT NULL,\n `password` varchar(255) DEFAULT NULL,\n `database` varchar(100) DEFAULT NULL,\n `table_count` bigint DEFAULT '0',\n `charset` varchar(20) DEFAULT 'utf8mb4',\n `ssl_mode` varchar(20) DEFAULT NULL,\n `created_at` datetime(3) DEFAULT NULL,\n `updated_at` datetime(3) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", + "created_at": "2026-03-06T16:00:34.065+08:00", + "updated_at": "2026-03-06T16:00:34.065+08:00" + } + ], + "updated_at": "2026-03-06T16:00:34.1176551+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/7eb66808-db8b-428e-8548-2f754c4fc688.json b/server/resources/db_info/7eb66808-db8b-428e-8548-2f754c4fc688.json new file mode 100644 index 0000000..2ca37df --- /dev/null +++ b/server/resources/db_info/7eb66808-db8b-428e-8548-2f754c4fc688.json @@ -0,0 +1,44 @@ +{ + "database_id": "7eb66808-db8b-428e-8548-2f754c4fc688", + "database_name": "123", + "db_type": "mysql", + "tables": [ + { + "id": "66026752-77b6-4cba-a4d1-1bf3b07e920c", + "database_id": "7eb66808-db8b-428e-8548-2f754c4fc688", + "parent_table": "teachers", + "sub_table_name": "teachers", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "created_at": "2026-03-06T15:12:24.217+08:00", + "updated_at": "2026-03-06T15:12:24.217+08:00" + }, + { + "id": "be59cb63-c5cf-46bf-b77a-46ce1fcb374b", + "database_id": "7eb66808-db8b-428e-8548-2f754c4fc688", + "parent_table": "scores", + "sub_table_name": "scores", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "created_at": "2026-03-06T15:12:24.112+08:00", + "updated_at": "2026-03-06T15:12:24.112+08:00" + }, + { + "id": "d91e5cd4-09c9-40a8-9c42-8f5ce0f059e1", + "database_id": "7eb66808-db8b-428e-8548-2f754c4fc688", + "parent_table": "students", + "sub_table_name": "students", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "created_at": "2026-03-06T15:12:24.166+08:00", + "updated_at": "2026-03-06T15:12:24.166+08:00" + } + ], + "updated_at": "2026-03-06T15:12:24.2696469+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/96d39e69-c96b-4d22-9b29-6456de71c6c1.json b/server/resources/db_info/96d39e69-c96b-4d22-9b29-6456de71c6c1.json new file mode 100644 index 0000000..2abe2df --- /dev/null +++ b/server/resources/db_info/96d39e69-c96b-4d22-9b29-6456de71c6c1.json @@ -0,0 +1,21 @@ +{ + "database_id": "96d39e69-c96b-4d22-9b29-6456de71c6c1", + "database_name": "189数据库", + "db_type": "mysql", + "tables": [ + { + "id": "d56ef61e-ac0d-439d-a2e8-133d766cbdd9", + "database_id": "96d39e69-c96b-4d22-9b29-6456de71c6c1", + "parent_table": "scores", + "sub_table_name": "scores", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "ddl": "CREATE TABLE `scores` (\n`id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '分数id'\n `student_id` int(10) unsigned NOT NULL,\n `subject` varchar(50) NOT NULL COMMENT '科目',\n `score` double DEFAULT NULL COMMENT '分数',\n `teacher_id` int(10) unsigned DEFAULT NULL,\n `exam_date` date DEFAULT NULL COMMENT '考试日期',\n `created_at` datetime DEFAULT CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB AUTO_INCREMENT=48 DEFAULT CHARSET=utf8mb4", + "created_at": "2026-03-06T16:07:17.146+08:00", + "updated_at": "2026-03-06T16:07:17.146+08:00" + } + ], + "updated_at": "2026-03-06T16:07:17.1980788+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/a58e6c1e-b39b-4248-8de9-b172f134197b.json b/server/resources/db_info/a58e6c1e-b39b-4248-8de9-b172f134197b.json new file mode 100644 index 0000000..77ac71f --- /dev/null +++ b/server/resources/db_info/a58e6c1e-b39b-4248-8de9-b172f134197b.json @@ -0,0 +1,21 @@ +{ + "database_id": "a58e6c1e-b39b-4248-8de9-b172f134197b", + "database_name": "test-db-2", + "db_type": "mysql", + "tables": [ + { + "id": "613c5bb7-2d42-4b75-8f19-43a2b345de8b", + "database_id": "a58e6c1e-b39b-4248-8de9-b172f134197b", + "parent_table": "database_info", + "sub_table_name": "���ݿ���Ϣ��", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "ddl": "CREATE TABLE `database_info` (\n `id` varchar(36) NOT NULL,\n `name` varchar(100) NOT NULL,\n `description` varchar(500) DEFAULT NULL,\n `db_type` varchar(20) NOT NULL,\n `host` varchar(255) NOT NULL,\n `port` bigint NOT NULL,\n `username` varchar(100) NOT NULL,\n `password` varchar(255) DEFAULT NULL,\n `database` varchar(100) DEFAULT NULL,\n `table_count` bigint DEFAULT '0',\n `charset` varchar(20) DEFAULT 'utf8mb4',\n `ssl_mode` varchar(20) DEFAULT NULL,\n `created_at` datetime(3) DEFAULT NULL,\n `updated_at` datetime(3) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", + "created_at": "2026-03-06T15:51:28.706+08:00", + "updated_at": "2026-03-06T15:51:28.706+08:00" + } + ], + "updated_at": "2026-03-06T15:51:28.762063+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/b5fc80da-b681-4f6f-a35a-73e73dee50d0.json b/server/resources/db_info/b5fc80da-b681-4f6f-a35a-73e73dee50d0.json new file mode 100644 index 0000000..4b6d667 --- /dev/null +++ b/server/resources/db_info/b5fc80da-b681-4f6f-a35a-73e73dee50d0.json @@ -0,0 +1,20 @@ +{ + "database_id": "b5fc80da-b681-4f6f-a35a-73e73dee50d0", + "database_name": "123", + "db_type": "mysql", + "tables": [ + { + "id": "49805b02-3204-44b9-9d28-3b5053ca7a1e", + "database_id": "b5fc80da-b681-4f6f-a35a-73e73dee50d0", + "parent_table": "scores", + "sub_table_name": "scores", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "created_at": "2026-03-06T15:12:36.292+08:00", + "updated_at": "2026-03-06T15:12:36.292+08:00" + } + ], + "updated_at": "2026-03-06T15:12:36.3469958+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/d022a68d-cb75-405b-bee3-e923a8b5a283.json b/server/resources/db_info/d022a68d-cb75-405b-bee3-e923a8b5a283.json new file mode 100644 index 0000000..764535a --- /dev/null +++ b/server/resources/db_info/d022a68d-cb75-405b-bee3-e923a8b5a283.json @@ -0,0 +1,7 @@ +{ + "database_id": "d022a68d-cb75-405b-bee3-e923a8b5a283", + "database_name": "123", + "db_type": "mysql", + "tables": [], + "updated_at": "2026-03-06T15:32:13.1521688+08:00" +} \ No newline at end of file diff --git a/server/resources/db_info/d44fa121-5964-439f-8c5d-0384ba27b411.json b/server/resources/db_info/d44fa121-5964-439f-8c5d-0384ba27b411.json new file mode 100644 index 0000000..532c81d --- /dev/null +++ b/server/resources/db_info/d44fa121-5964-439f-8c5d-0384ba27b411.json @@ -0,0 +1,22 @@ +{ + "database_id": "d44fa121-5964-439f-8c5d-0384ba27b411", + "database_name": "123", + "db_type": "mysql", + "tables": [ + { + "id": "694da06f-d6b7-4915-8502-c4c38addf059", + "database_id": "d44fa121-5964-439f-8c5d-0384ba27b411", + "parent_table": "scores", + "sub_table_name": "scores", + "sub_table_comment": "", + "mapping_type": "", + "relation_field": "", + "relation_type": "", + "fields": null, + "ddl": "CREATE TABLE `scores` (\n`id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'fenshu id'\n `student_id` int(10) unsigned NOT NULL,\n `subject` varchar(50) NOT NULL COMMENT '科目',\n `score` double DEFAULT NULL COMMENT '分数',\n `teacher_id` int(10) unsigned DEFAULT NULL,\n `exam_date` date DEFAULT NULL COMMENT '考试日期',\n `created_at` datetime DEFAULT CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB AUTO_INCREMENT=48 DEFAULT CHARSET=utf8mb4", + "created_at": "2026-03-06T16:30:41.589+08:00", + "updated_at": "2026-03-06T16:30:41.589+08:00" + } + ], + "updated_at": "2026-03-06T16:30:41.639042+08:00" +} \ No newline at end of file diff --git a/server/temp_add_data2.go b/server/temp_add_data2.go new file mode 100644 index 0000000..5ad9c84 --- /dev/null +++ b/server/temp_add_data2.go @@ -0,0 +1,122 @@ +package main + +import ( + "fmt" + "math/rand" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +type Teacher struct { + ID uint `gorm:"primaryKey"` + Name string `gorm:"size:50;charset=utf8mb4"` + Subject string `gorm:"size:50;charset=utf8mb4"` + Phone string `gorm:"size:20"` + CreatedAt time.Time +} + +type Student struct { + ID uint `gorm:"primaryKey"` + Name string `gorm:"size:50;charset=utf8mb4"` + Age int + Gender string `gorm:"size:10;charset=utf8mb4"` + Class string `gorm:"size:50;charset=utf8mb4"` + Phone string `gorm:"size:20"` + CreatedAt time.Time +} + +type Score struct { + ID uint `gorm:"primaryKey"` + StudentID uint + Subject string `gorm:"size:50;charset=utf8mb4"` + Score float64 + TeacherID uint + ExamDate time.Time + CreatedAt time.Time +} + +func main() { + dsn := "root:881116142@tcp(10.10.10.189:3306)/students?charset=utf8mb4&parseTime=True&loc=Local" + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + panic("连接数据库失败: " + err.Error()) + } + + // 自动迁移表 + db.AutoMigrate(&Teacher{}, &Student{}, &Score{}) + + // 清理旧数据 + db.Exec("DELETE FROM scores") + db.Exec("DELETE FROM students") + db.Exec("DELETE FROM teachers") + + rand.Seed(time.Now().UnixNano()) + + // 创建教师 + teachers := []Teacher{ + {Name: "张老师", Subject: "数学", Phone: "13800001001"}, + {Name: "李老师", Subject: "语文", Phone: "13800001002"}, + {Name: "王老师", Subject: "英语", Phone: "13800001003"}, + {Name: "刘老师", Subject: "物理", Phone: "13800001004"}, + {Name: "陈老师", Subject: "化学", Phone: "13800001005"}, + {Name: "杨老师", Subject: "生物", Phone: "13800001006"}, + {Name: "赵老师", Subject: "历史", Phone: "13800001007"}, + {Name: "周老师", Subject: "地理", Phone: "13800001008"}, + } + db.Create(&teachers) + + // 创建30个学生 + names := []string{"张三", "李四", "王五", "刘六", "陈七", "杨八", "赵九", "钱十", + "孙一", "周二", "吴三", "郑四", "冯五", "褚六", "卫七", "蒋八", + "沈九", "韩十", "朱十一", "秦十二", "许十三", "何十四", "吕十五", "施十六", + "张十七", "孔十八", "曹十九", "严二十", "华二十一", "金二十二"} + genders := []string{"男", "女"} + classes := []string{"高一(1)班", "高一(2)班", "高一(3)班", "高二(1)班", "高二(2)班"} + + students := make([]Student, 30) + for i := 0; i < 30; i++ { + students[i] = Student{ + Name: names[i], + Age: 15 + rand.Intn(3), + Gender: genders[rand.Intn(len(genders))], + Class: classes[rand.Intn(len(classes))], + Phone: fmt.Sprintf("139%08d", 10000000+rand.Intn(90000000)), + } + } + db.Create(&students) + + // 为每个学生创建成绩记录 + subjects := []string{"数学", "语文", "英语", "物理", "化学", "生物", "历史", "地理"} + scores := make([]Score, 0) + + for i := 0; i < 30; i++ { + numSubjects := 4 + rand.Intn(3) + selectedSubjects := make(map[string]bool) + for len(selectedSubjects) < numSubjects { + subj := subjects[rand.Intn(len(subjects))] + if !selectedSubjects[subj] { + selectedSubjects[subj] = true + + teacherID := uint(1 + rand.Intn(len(teachers))) + examDate := time.Now().AddDate(0, -rand.Intn(6), -rand.Intn(30)) + + score := Score{ + StudentID: students[i].ID, + Subject: subj, + Score: 60 + rand.Float64()*40, + TeacherID: teacherID, + ExamDate: examDate, + } + scores = append(scores, score) + } + } + } + db.Create(&scores) + + fmt.Println("数据创建成功!") + fmt.Printf("教师: %d 条\n", len(teachers)) + fmt.Printf("学生: %d 条\n", len(students)) + fmt.Printf("成绩: %d 条\n", len(scores)) +} diff --git a/server/temp_check.go b/server/temp_check.go new file mode 100644 index 0000000..4df8904 --- /dev/null +++ b/server/temp_check.go @@ -0,0 +1,27 @@ +package main + +import ( + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func main() { + dsn := "root:881116142@tcp(10.10.10.189:3306)/mysql?charset=utf8mb4&parseTime=True&loc=Local" + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + panic("连接失败: " + err.Error()) + } + + type User struct { + User string + Host string + } + + var users []User + db.Raw("SELECT User, Host FROM mysql.user WHERE User='root'").Scan(&users) + + println("Root 用户列表:") + for _, u := range users { + println("- User: " + u.User + ", Host: " + u.Host) + } +} diff --git a/server/temp_native.go b/server/temp_native.go new file mode 100644 index 0000000..aec8853 --- /dev/null +++ b/server/temp_native.go @@ -0,0 +1,32 @@ +package main + +import ( + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func main() { + dsn := "root:881116142@tcp(10.10.10.189:3306)/mysql?charset=utf8mb4&parseTime=True&loc=Local" + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + panic("连接失败: " + err.Error()) + } + + // 使用 mysql_native_password 插件重建用户 + sqls := []string{ + "DROP USER IF EXISTS 'root'@'%'", + "CREATE USER 'root'@'%' IDENTIFIED WITH mysql_native_password BY '881116142'", + "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION", + "FLUSH PRIVILEGES", + } + + for _, sql := range sqls { + if err := db.Exec(sql).Error; err != nil { + println("执行: " + sql + " - 错误: " + err.Error()) + } else { + println("成功: " + sql) + } + } + + println("完成! 用 mysql_native_password 插件重建了 root 用户") +} diff --git a/server/temp_newuser.go b/server/temp_newuser.go new file mode 100644 index 0000000..2abfefa --- /dev/null +++ b/server/temp_newuser.go @@ -0,0 +1,41 @@ +package main + +import ( + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func main() { + // 尝试用 root 用户连接,但指定 IP + dsn := "root:881116142@tcp(127.0.0.1:3306)/mysql?charset=utf8mb4&parseTime=True&loc=Local&multiStatements=true" + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + // 尝试其他方式 + dsn2 := "root:881116142@tcp(localhost:3306)/mysql?charset=utf8mb4&parseTime=True&loc=Local&multiStatements=true" + db, err = gorm.Open(mysql.Open(dsn2), &gorm.Config{}) + if err != nil { + println("无法连接,请通过其他方式在 MySQL 服务器上执行:") + println("CREATE USER IF NOT EXISTS 'root'@'%' IDENTIFIED BY '881116142';") + println("GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION;") + println("FLUSH PRIVILEGES;") + panic("连接失败: " + err.Error()) + } + } + + // 创建新用户 + sqls := []string{ + "CREATE USER IF NOT EXISTS 'admin'@'%' IDENTIFIED BY 'admin123'", + "GRANT ALL PRIVILEGES ON *.* TO 'admin'@'%' WITH GRANT OPTION", + "FLUSH PRIVILEGES", + } + + for _, sql := range sqls { + if err := db.Exec(sql).Error; err != nil { + println("执行: " + sql + " - 错误: " + err.Error()) + } else { + println("执行成功: " + sql) + } + } + + println("创建了新用户 admin,可以用这个连接 Navicat") +} diff --git a/server/temp_regrant.go b/server/temp_regrant.go new file mode 100644 index 0000000..42c1abe --- /dev/null +++ b/server/temp_regrant.go @@ -0,0 +1,32 @@ +package main + +import ( + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func main() { + dsn := "root:881116142@tcp(10.10.10.189:3306)/mysql?charset=utf8mb4&parseTime=True&loc=Local&multiStatements=true" + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + panic("连接失败: " + err.Error()) + } + + // 重建 root@% 用户并设置密码 + sqls := []string{ + "DROP USER IF EXISTS 'root'@'%'", + "CREATE USER 'root'@'%' IDENTIFIED BY '881116142'", + "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION", + "FLUSH PRIVILEGES", + } + + for _, sql := range sqls { + if err := db.Exec(sql).Error; err != nil { + println("执行: " + sql + " - 错误: " + err.Error()) + } else { + println("执行成功: " + sql) + } + } + + println("完成!") +} diff --git a/server/temp_reset.go b/server/temp_reset.go new file mode 100644 index 0000000..2ab68fd --- /dev/null +++ b/server/temp_reset.go @@ -0,0 +1,35 @@ +package main + +import ( + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func main() { + dsn := "root:881116142@tcp(10.10.10.189:3306)/mysql?charset=utf8mb4&parseTime=True&loc=Local&allowOldStrings=true" + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + panic("连接失败: " + err.Error()) + } + + // 删除所有 root 用户 + sqls := []string{ + "DROP USER IF EXISTS 'root'@'%'", + "DROP USER IF EXISTS 'root'@'10.10.10.122'", + "DROP USER IF EXISTS 'root'@'localhost'", + "DROP USER IF EXISTS 'root'@'127.0.0.1'", + "CREATE USER 'root'@'%' IDENTIFIED BY '881116142'", + "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION", + "FLUSH PRIVILEGES", + } + + for _, sql := range sqls { + if err := db.Exec(sql).Error; err != nil { + println("执行: " + sql + " - 错误: " + err.Error()) + } else { + println("成功: " + sql) + } + } + + println("完成!") +} diff --git a/start-local.ps1 b/start-local.ps1 new file mode 100644 index 0000000..2b7ed2d --- /dev/null +++ b/start-local.ps1 @@ -0,0 +1,41 @@ +# X-Agents 本地启动脚本(Go + 前端) +# 运行方式: .\start-local.ps1 + +$ErrorActionPreference = "Stop" + +Write-Host "======================================" -ForegroundColor Cyan +Write-Host " X-Agents 本地启动 (Go + 前端)" -ForegroundColor Cyan +Write-Host "======================================" -ForegroundColor Cyan + +# 1. 启动数据库 +Write-Host "[启动] 数据库..." -ForegroundColor Green +docker compose -f docker-compose.dev.yml up -d + +# 2. 启动 Go 服务 +Write-Host "[启动] Go API 服务..." -ForegroundColor Green +Start-Process powershell -ArgumentList "-NoExit", "-Command", @" +cd $PWD\server +go run ./cmd/api +"@ -WindowStyle Normal + +# 3. 启动前端 +Write-Host "[启动] 前端服务..." -ForegroundColor Green +if (Test-Path "web/package.json") { + Start-Process powershell -ArgumentList "-NoExit", "-Command", @" +cd $PWD\web +npm run dev +"@ -WindowStyle Normal +} + +Write-Host "" +Write-Host "======================================" -ForegroundColor Green +Write-Host " 服务已启动!" -ForegroundColor Green +Write-Host "======================================" -ForegroundColor Green +Write-Host "" +Write-Host "服务地址:" -ForegroundColor White +Write-Host " - Go API: http://localhost:8080" -ForegroundColor Cyan +Write-Host " - 前端: http://localhost:5173" -ForegroundColor Cyan +Write-Host " - MySQL: localhost:6036" -ForegroundColor Cyan +Write-Host " - Redis: localhost:6037" -ForegroundColor Cyan + +Read-Host | Out-Null diff --git a/team-require/api/README.md b/team-require/api/README.md new file mode 100644 index 0000000..8ffaea0 --- /dev/null +++ b/team-require/api/README.md @@ -0,0 +1,14 @@ +# API 接口文档 + +## 目录 + +### Database 相关 + +- [检查数据库连接并获取表结构](database-check.md) +- [创建数据库配置](database-create.md) +- [获取数据库列表](database-list.md) +- [获取子表列表](subtable-list.md) + +--- + +> 接口如有更新,请同步更新此文档 diff --git a/team-require/api/database-check.md b/team-require/api/database-check.md new file mode 100644 index 0000000..68da395 --- /dev/null +++ b/team-require/api/database-check.md @@ -0,0 +1,103 @@ +# 检查数据库连接并获取表结构 + +## 接口地址 + +``` +POST /database/check +``` + +## 请求参数 + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| db_type | string | 是 | 数据库类型:`mysql`、`postgres` | +| host | string | 是 | 数据库主机 | +| port | int | 是 | 数据库端口 | +| username | string | 是 | 用户名 | +| password | string | 否 | 密码 | +| database | string | 是 | 数据库名 | +| charset | string | 否 | 字符集,默认 `utf8mb4` | +| ssl_mode | string | 否 | SSL 模式 | +| database_id | string | 否 | 已存在的数据库ID,用于恢复字段映射 | + +## 请求示例 + +```json +{ + "db_type": "mysql", + "host": "localhost", + "port": 3306, + "username": "root", + "password": "root", + "database": "students", + "charset": "utf8mb4", + "database_id": "xxx-xxx-xxx" // 可选,用于恢复字段映射 +} +``` + +## 返回参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| success | bool | 是否连接成功 | +| message | string | 消息 | +| database | string | 数据库名 | +| tables | array | 表结构列表 | + +### tables[] 详情 + +| 参数 | 类型 | 说明 | +|------|------|------| +| table_name | string | 表名 | +| table_comment | string | 表注释 | +| ddl | string | 建表 DDL(带 COMMENT 的映射后 DDL) | +| columns | array | 列信息列表 | + +### columns[] 详情 + +| 参数 | 类型 | 说明 | +|------|------|------| +| column_name | string | 列名 | +| data_type | string | 数据类型 | +| column_type | string | 完整列类型 | +| is_nullable | string | 是否可空(YES/NO) | +| default_value | string | 默认值 | +| column_key | string | 主键标识(PRI/MUL/UNI) | +| extra | string | 额外信息(如 auto_increment) | +| column_comment | string | 列注释 | +| mapped_name | string | 字段中文映射名(已保存的映射) | + +## 返回示例 + +```json +{ + "success": true, + "message": "connection successful", + "database": "students", + "tables": [ + { + "table_name": "users", + "table_comment": "用户表", + "ddl": "CREATE TABLE `users` (\n `id` int(10) unsigned NOT NULL COMMENT '用户ID'\n ...\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", + "columns": [ + { + "column_name": "id", + "data_type": "int", + "column_type": "int(10) unsigned", + "is_nullable": "NO", + "default_value": "", + "column_key": "PRI", + "extra": "auto_increment", + "column_comment": "", + "mapped_name": "用户ID" + } + ] + } + ] +} +``` + +## 使用场景 + +1. **首次连接**:不传 `database_id`,获取实时表结构 +2. **恢复映射**:传入 `database_id`,返回已保存的 `mapped_name` 和 `ddl` diff --git a/team-require/api/database-create.md b/team-require/api/database-create.md new file mode 100644 index 0000000..d1dc72d --- /dev/null +++ b/team-require/api/database-create.md @@ -0,0 +1,104 @@ +# 创建数据库配置 + +## 接口地址 + +``` +POST /database/add +``` + +## 请求参数 + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| name | string | 是 | 数据库名称 | +| description | string | 否 | 描述 | +| db_type | string | 是 | 数据库类型 | +| host | string | 是 | 主机 | +| port | int | 是 | 端口 | +| username | string | 是 | 用户名 | +| password | string | 否 | 密码 | +| database | string | 是 | 数据库名 | +| charset | string | 否 | 字符集 | +| ssl_mode | string | 否 | SSL 模式 | +| sub_tables | array | 否 | 子表配置列表 | + +### sub_tables[] 详情 + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| parent_table | string | 是 | 原始表名 | +| sub_table_name | string | 是 | 子表别名 | +| sub_table_comment | string | 否 | 子表注释 | +| mapping_type | string | 否 | 映射类型 | +| relation_field | string | 否 | 关联字段 | +| relation_type | string | 否 | 关联类型 | +| fields | array | 否 | 字段映射列表 | + +### fields[] 详情 + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| column_name | string | 是 | 列名 | +| mapped_name | string | 是 | 中文映射名 | + +## 请求示例 + +```json +{ + "name": "学生数据库", + "description": "用于存储学生信息", + "db_type": "mysql", + "host": "localhost", + "port": 3306, + "username": "root", + "password": "root", + "database": "students", + "charset": "utf8mb4", + "sub_tables": [ + { + "parent_table": "users", + "sub_table_name": "用户表", + "sub_table_comment": "用户信息", + "fields": [ + {"column_name": "id", "mapped_name": "用户ID"}, + {"column_name": "name", "mapped_name": "用户名"} + ] + } + ] +} +``` + +## 返回参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| id | string | 数据库记录ID | +| name | string | 数据库名称 | +| db_type | string | 数据库类型 | +| host | string | 主机 | +| port | int | 端口 | +| ... | ... | 其他字段 | + +## 返回示例 + +```json +{ + "id": "xxx-xxx-xxx", + "name": "学生数据库", + "description": "用于存储学生信息", + "db_type": "mysql", + "host": "localhost", + "port": 3306, + "username": "root", + "password": "root", + "database": "students", + "table_count": 1, + "charset": "utf8mb4", + "created_at": "2026-03-06T15:00:00Z" +} +``` + +## 说明 + +- 创建时会自动连接数据库获取表结构 DDL +- 如果传入了 `fields`(字段映射),会自动生成带 COMMENT 的新 DDL 并存储 diff --git a/team-require/api/database-list.md b/team-require/api/database-list.md new file mode 100644 index 0000000..87b93b6 --- /dev/null +++ b/team-require/api/database-list.md @@ -0,0 +1,51 @@ +# 获取数据库列表 + +## 接口地址 + +``` +GET /database/list +``` + +## 请求参数 + +无 + +## 返回参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| list | array | 数据库列表 | + +### list[] 详情 + +| 参数 | 类型 | 说明 | +|------|------|------| +| id | string | 数据库ID | +| name | string | 数据库名称 | +| description | string | 描述 | +| db_type | string | 数据库类型 | +| host | string | 主机 | +| port | int | 端口 | +| database | string | 数据库名 | +| table_count | int | 子表数量 | +| created_at | string | 创建时间 | + +## 返回示例 + +```json +{ + "list": [ + { + "id": "xxx-xxx", + "name": "学生数据库", + "description": "用于存储学生信息", + "db_type": "mysql", + "host": "localhost", + "port": 3306, + "database": "students", + "table_count": 5, + "created_at": "2026-03-06T15:00:00Z" + } + ] +} +``` diff --git a/team-require/api/subtable-list.md b/team-require/api/subtable-list.md new file mode 100644 index 0000000..c916bcb --- /dev/null +++ b/team-require/api/subtable-list.md @@ -0,0 +1,75 @@ +# 获取子表列表 + +## 接口地址 + +``` +GET /sub-table/database/:database_id +``` + +## 路径参数 + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| database_id | string | 是 | 数据库ID | + +## 返回参数 + +| 参数 | 类型 | 说明 | +|------|------|------| +| list | array | 子表列表 | + +### list[] 详情 + +| 参数 | 类型 | 说明 | +|------|------|------| +| id | string | 子表ID | +| database_id | string | 关联的数据库ID | +| parent_table | string | 原始表名 | +| sub_table_name | string | 子表别名 | +| sub_table_comment | string | 子表注释 | +| mapping_type | string | 映射类型 | +| relation_field | string | 关联字段 | +| relation_type | string | 关联类型 | +| fields | array | 字段映射列表 | +| ddl | string | 建表 DDL(带 COMMENT) | +| created_at | string | 创建时间 | + +### fields[] 详情 + +| 参数 | 类型 | 说明 | +|------|------|------| +| column_name | string | 列名 | +| mapped_name | string | 中文映射名 | + +## 返回示例 + +```json +{ + "list": [ + { + "id": "xxx-xxx", + "database_id": "database-xxx", + "parent_table": "users", + "sub_table_name": "用户表", + "sub_table_comment": "用户信息", + "mapping_type": "horizontal", + "relation_field": "id", + "relation_type": "one_to_many", + "fields": [ + {"column_name": "id", "mapped_name": "用户ID"}, + {"column_name": "name", "mapped_name": "用户名"} + ], + "ddl": "CREATE TABLE `users` (\n `id` int(10) unsigned NOT NULL COMMENT '用户ID'\n ...\n)", + "created_at": "2026-03-06T15:00:00Z" + } + ] +} +``` + +## 使用场景 + +用于恢复映射状态: +1. 用户点击已存在的数据库的 "Map Tables" 按钮 +2. 调用此接口获取已保存的子表信息 +3. 根据 `parent_table` 勾选已选择的表 +4. 根据 `fields` 恢复字段映射 diff --git a/team-require/web/columns-api.md b/team-require/web/columns-api.md new file mode 100644 index 0000000..a893fb8 --- /dev/null +++ b/team-require/web/columns-api.md @@ -0,0 +1,92 @@ +# 后端需求 - 表结构返回 columns 数据 + +## 问题描述 + +前端在 Edit Mapping 页面需要展示表的列信息(字段名、类型、COMMENT等),但前端自行解析 DDL 存在困难。 + +## 需求 + +后端在获取表结构列表时,需要同时返回: + +1. **DDL 语句**(已有的需求,继续保留) +2. **结构化的 columns 数据**(新增) + +### 返回数据结构 + +```json +{ + "success": true, + "tables": [ + { + "table_name": "exam_scores", + "table_comment": "考试成绩表", + "ddl": "CREATE TABLE `exam_scores` (...)", + "columns": [ + { + "column_name": "id", + "data_type": "int", + "column_type": "int(10) unsigned", + "is_nullable": "NO", + "default_value": null, + "column_key": "PRI", + "extra": "auto_increment", + "column_comment": "" + }, + { + "column_name": "student_id", + "data_type": "int", + "column_type": "int(10) unsigned", + "is_nullable": "NO", + "default_value": null, + "column_key": "", + "extra": "", + "column_comment": "" + }, + { + "column_name": "subject", + "data_type": "varchar", + "column_type": "varchar(50)", + "is_nullable": "NO", + "default_value": null, + "column_key": "", + "extra": "", + "column_comment": "科目" + }, + { + "column_name": "score", + "data_type": "double", + "column_type": "double", + "is_nullable": "YES", + "default_value": null, + "column_key": "", + "extra": "", + "column_comment": "分数" + } + ] + } + ] +} +``` + +### 字段说明 + +| 字段 | 类型 | 说明 | +|------|------|------| +| column_name | string | 列名 | +| data_type | string | 数据类型(如 int, varchar, double) | +| column_type | string | 完整列类型(如 int(10) unsigned) | +| is_nullable | string | 是否可空(YES/NO) | +| default_value | string | 默认值 | +| column_key | string | 主键标识(PRI/MUL/UNI) | +| extra | string | 额外信息(如 auto_increment) | +| column_comment | string | 列注释 | + +## 影响范围 + +- 文件:`server/internal/service/database_service.go` +- 函数:`getMySQLTables`, `getPostgresTables` +- 数据模型:`server/internal/model/sub_table_info.go` 的 `ColumnInfo` 结构体 + +## 优先级 + +高 - 前端 Edit Mapping 页面字段映射功能依赖此数据 diff --git a/team-require/web/field-mapping.md b/team-require/web/field-mapping.md new file mode 100644 index 0000000..02afa5e --- /dev/null +++ b/team-require/web/field-mapping.md @@ -0,0 +1,89 @@ +# 后端需求 - 字段映射保存与读取 + +## 问题描述 + +前端 Edit Mapping 页面中,用户输入的字段中文映射名(mapped_name)在保存后,第二次打开时丢失了。 + +## 原因分析 + +1. **保存时**:前端只保存了表级别信息,没有保存字段的中文映射 +2. **加载时**:前端每次都从 `/database/check` 重新获取表结构,没有读取已保存的映射数据 + +## 需求 + +### 1. 保存字段映射 + +前端保存时需要传递每个字段的中文映射名,后端需要存储这些数据。 + +请求结构: +```json +{ + "name": "数据库名", + "sub_tables": [ + { + "parent_table": "users", + "sub_table_name": "用户表", + "sub_table_comment": "用户表", + "fields": [ + { + "column_name": "id", + "mapped_name": "编号" + }, + { + "column_name": "username", + "mapped_name": "用户名" + } + ] + } + ] +} +``` + +### 2. 返回字段映射 + +后端在返回表结构时,需要同时返回已保存的字段映射信息。 + +返回结构: +```json +{ + "success": true, + "tables": [ + { + "table_name": "users", + "table_comment": "用户表", + "ddl": "...", + "columns": [ + { + "column_name": "id", + "data_type": "int", + "column_type": "int(10)", + "column_comment": "", + "mapped_name": "编号" + }, + { + "column_name": "username", + "data_type": "varchar", + "column_type": "varchar(50)", + "column_comment": "用户名", + "mapped_name": "用户名" + } + ] + } + ] +} +``` + +### 3. 数据存储 + +- 可以在 `sub_table_info` 表中增加 `fields` JSON 字段存储字段映射 +- 或者创建新的关联表 `sub_table_fields` + +## 影响范围 + +- `server/internal/service/database_service.go` - Create/Update 方法 +- `server/internal/model/` - 数据模型修改 +- 子表映射的数据存储结构 + +## 优先级 + +高 - 用户输入的映射数据丢失影响使用体验 diff --git a/team-require/web/mapping-state.md b/team-require/web/mapping-state.md new file mode 100644 index 0000000..6142913 --- /dev/null +++ b/team-require/web/mapping-state.md @@ -0,0 +1,43 @@ +# 后端需求 - 保存和恢复映射状态 + +## 问题描述 + +用户第一次选择表并设置字段映射后,第二次点击 "Map Tables" 按钮进入界面时,之前选择的表和设置的字段映射都丢失了。 + +## 需求 + +前端打开已存在的数据库映射时,需要恢复以下状态: + +### 1. 已选择的表列表 + +后端需要在数据库记录中保存用户选择了哪些表(不仅仅是子表信息),或者在查询时返回该数据库关联的所有子表。 + +### 2. 字段映射 + +每个子表保存的字段映射(mapped_name)需要在前端重新加载时显示。 + +## 期望的行为 + +1. 用户点击已存在的数据库的 "Map Tables" 按钮 +2. 前端获取实时表结构 +3. 同时加载该数据库已保存的子表信息(包括选择的表和字段映射) +4. 前端合并数据,显示: + - 已选择的表(勾选状态) + - 每个字段之前设置的 mapped_name + +## 技术实现建议 + +在数据库表中增加或利用已有字段: + +- `sub_table_info` 表已包含 `Fields` JSON 字段存储字段映射 +- 需要在创建/更新数据库时保存选择的表列表 +- 或者在查询时返回该数据库下所有已创建的子表 + +## 影响范围 + +- 数据库创建/更新接口 +- 子表映射查询接口 + +## 优先级 + +高 - 影响用户体验,第二次进入无法看到之前的工作成果 diff --git a/team-require/web/todo-2026-3-6.md b/team-require/web/todo-2026-3-6.md new file mode 100644 index 0000000..cc2547f --- /dev/null +++ b/team-require/web/todo-2026-3-6.md @@ -0,0 +1,27 @@ +# Web 前端需求 TODO + +## 2026年3月 + +### 2026-03-06 + +- [x] **DDL 获取功能** - 后端需在获取表结构时返回 DDL 语句 ✔ + - 相关文件:`server/internal/service/database_service.go` + - 函数:`getMySQLTables`, `getPostgresTables` + - 详细需求:[ddl-fetch.md](./ddl-fetch.md) + +- [x] **返回结构化 columns 数据** - 后端需返回完整的列信息(column_name, data_type, column_type, is_nullable, default_value, column_key, extra, column_comment)✔ + - 相关文件:`server/internal/service/database_service.go` + - 函数:`getMySQLTables`, `getPostgresTables` + - 详细需求:[columns-api.md](./columns-api.md) + +- [x] **保存和读取字段映射** - 后端需支持保存/读取字段的中文映射名(mapped_name) ✔ + - 相关文件:`server/internal/service/database_service.go`, `server/internal/model/` + - 详细需求:[field-mapping.md](./field-mapping.md) + +- [x] **保存和恢复映射状态** - 第二次进入 Map Tables 时需恢复之前选择的表和字段映射 ✔ + - 相关文件:`server/internal/service/database_service.go`, `server/internal/model/` + - 详细需求:[mapping-state.md](./mapping-state.md) + +--- + +> 需求完成后请完成者打 ✔ \ No newline at end of file diff --git a/web/src/style.css b/web/src/style.css index b113384..1e3591e 100644 --- a/web/src/style.css +++ b/web/src/style.css @@ -46,6 +46,27 @@ html.dark .el-select .el-select__wrapper { min-height: 42px; } +html.dark .el-select:hover .el-input__wrapper, +html.dark .el-select:hover .el-select__wrapper { + background-color: #1a1c25 !important; +} + +/* 修复鼠标移出后出现白色背景的问题 */ +html.dark .el-select .el-input__wrapper:hover, +html.dark .el-select .el-select__wrapper:hover { + background-color: #1a1c25 !important; +} + +html.dark .el-select .el-input__wrapper, +html.dark .el-select .el-select__wrapper { + background-color: #1a1c25 !important; +} + +html.dark .el-form-item__content .el-input__wrapper, +html.dark .el-form-item__content .el-select .el-input__wrapper { + background-color: #1a1c25 !important; +} + html.dark .el-select.el-select--large .el-input__wrapper { padding: 5px 11px; min-height: 46px; @@ -188,6 +209,31 @@ html.dark .el-select .el-tag .el-tag__close:hover { color: #ffffff; } +/* el-checkbox 暗色主题 - 金黄色选中 */ +html.dark .el-checkbox { + --el-checkbox-checked-text-color: #ffb700; + --el-checkbox-checked-bg-color: #ffb700; + --el-checkbox-checked-border-color: #ffb700; + --el-checkbox-input-border-color-hover: #ffb700; +} + +html.dark .el-checkbox .el-checkbox__input.is-checked .el-checkbox__inner { + background-color: #ffb700; + border-color: #ffb700; +} + +html.dark .el-checkbox .el-checkbox__input.is-checked .el-checkbox__inner::after { + border-color: #1f2937; +} + +html.dark .el-checkbox .el-checkbox__label { + color: #e5e7eb; +} + +html.dark .el-checkbox:hover .el-checkbox__inner { + border-color: #ffb700; +} + /* 柱状图增长动画 */ @keyframes bar-grow { from { @@ -222,3 +268,344 @@ html.dark .el-select .el-tag .el-tag__close:hover { opacity: 0; animation: progress-grow 2.4s cubic-bezier(0.4, 0, 0.2, 1) forwards; } + +/* ===== 通用组件交互优化 ===== */ + +/* 通用弹窗动画 */ +@keyframes modal-fade-in { + 0% { + opacity: 0; + } + 100% { + opacity: 1; + } +} + +@keyframes modal-scale-in { + 0% { + opacity: 0; + transform: scale(0.95) translateY(10px); + } + 100% { + opacity: 1; + transform: scale(1) translateY(0); + } +} + +@keyframes modal-slide-up { + 0% { + opacity: 0; + transform: translateY(20px); + } + 100% { + opacity: 1; + transform: translateY(0); + } +} + +.modal-overlay { + animation: modal-fade-in 0.2s ease-out forwards; +} + +.modal-content { + animation: modal-scale-in 0.3s cubic-bezier(0.16, 1, 0.3, 1) forwards; +} + +/* 按钮交互优化 */ +.btn-primary { + @apply bg-gradient-to-r from-primary-orange to-red-500 text-white px-4 py-2 rounded-lg font-medium flex items-center gap-2 transition-all duration-200; +} + +.btn-primary:hover { + @apply from-orange-500 to-red-600; + transform: translateY(-1px); + box-shadow: 0 4px 12px rgba(255, 107, 53, 0.3); +} + +.btn-primary:active { + transform: translateY(0); + box-shadow: 0 2px 8px rgba(255, 107, 53, 0.2); +} + +.btn-primary:disabled { + @apply opacity-50 cursor-not-allowed; + transform: none; + box-shadow: none; +} + +.btn-secondary { + @apply bg-dark-600 text-gray-300 px-4 py-2 rounded-lg border border-dark-500 transition-all duration-200; +} + +.btn-secondary:hover { + @apply bg-dark-500 border-gray-500; +} + +.btn-secondary:active { + transform: scale(0.98); +} + +.btn-icon { + @apply p-2 rounded-lg transition-all duration-150; +} + +.btn-icon:hover { + @apply bg-dark-500; + transform: scale(1.05); +} + +.btn-icon:active { + transform: scale(0.95); +} + +/* 表单输入框交互 */ +.input-field { + @apply w-full bg-dark-600 border border-dark-500 rounded-lg px-4 py-2.5 text-white placeholder-gray-500 transition-all duration-200; +} + +.input-field:focus { + @apply outline-none border-primary-orange; + box-shadow: 0 0 0 3px rgba(255, 149, 0, 0.15); +} + +.input-field:hover:not(:focus) { + @apply border-gray-500; +} + +/* 表格样式优化 */ +.table-row { + @apply border-t border-dark-600 transition-all duration-200; +} + +.table-row:hover { + @apply bg-dark-600/50; +} + +.table-row:active { + @apply bg-dark-600; +} + +/* 表格头部 */ +.table-header { + @apply bg-dark-600 text-sm font-medium text-gray-400; +} + +/* 搜索框交互 */ +.search-input { + @apply bg-dark-600 border border-dark-500 rounded-lg py-2 pl-10 pr-4 text-white placeholder-gray-500 transition-all duration-200; +} + +.search-input:focus { + @apply outline-none border-primary-orange; + box-shadow: 0 0 0 3px rgba(255, 149, 0, 0.15); +} + +.search-input:hover:not(:focus) { + @apply border-gray-500; +} + +/* 空状态优化 */ +.empty-state { + @apply py-12 text-center transition-all duration-300; +} + +.empty-state-icon { + @apply text-gray-500 text-4xl mb-3 transition-transform duration-300; +} + +.empty-state:hover .empty-state-icon { + @apply text-gray-400; + transform: scale(1.1); +} + +/* 卡片悬停效果 */ +.card-hover { + @apply transition-all duration-200; +} + +.card-hover:hover { + @apply shadow-lg; + transform: translateY(-2px); +} + +/* 标签/徽章样式 */ +.badge { + @apply px-2 py-1 rounded text-sm font-medium transition-all duration-200; +} + +.badge-success { + @apply bg-green-500/20 text-green-400; +} + +.badge-warning { + @apply bg-yellow-500/20 text-yellow-400; +} + +.badge-error { + @apply bg-red-500/20 text-red-400; +} + +.badge-info { + @apply bg-blue-500/20 text-blue-400; +} + +.badge-default { + @apply bg-gray-500/20 text-gray-400; +} + +/* 状态指示点 */ +.status-dot { + @apply w-2 h-2 rounded-full transition-all duration-200; +} + +.status-dot-active { + @apply bg-primary-success; + box-shadow: 0 0 8px rgba(34, 197, 94, 0.5); +} + +.status-dot-inactive { + @apply bg-gray-500; +} + +.status-dot-error { + @apply bg-primary-danger; + box-shadow: 0 0 8px rgba(239, 68, 68, 0.5); +} + +/* 加载动画 */ +@keyframes pulse-dot { + 0%, 100% { + opacity: 1; + } + 50% { + opacity: 0.5; + } +} + +.loading-pulse { + animation: pulse-dot 1.5s ease-in-out infinite; +} + +/* 步骤指示器 */ +.step-indicator { + @apply flex items-center justify-center gap-4 py-4; +} + +.step-item { + @apply flex items-center gap-2; +} + +.step-circle { + @apply w-8 h-8 rounded-full flex items-center justify-center text-sm font-medium transition-all duration-300; +} + +.step-line { + @apply w-16 h-0.5 transition-all duration-300; +} + +/* 复选框优化 */ +.checkbox-custom { + @apply w-4 h-4 rounded border-dark-500 bg-dark-600 text-primary-cyan transition-all duration-200; +} + +.checkbox-custom:focus { + box-shadow: 0 0 0 3px rgba(34, 211, 238, 0.2); +} + +.checkbox-custom:checked { + @apply bg-primary-cyan border-primary-cyan; +} + +/* 工具提示 */ +.tooltip { + @apply absolute z-50 px-2 py-1 text-xs rounded bg-dark-700 text-white shadow-lg pointer-events-none opacity-0 transition-opacity duration-200; +} + +.tooltip-visible { + @apply opacity-100; +} + +/* 滚动条美化 */ +::-webkit-scrollbar { + @apply w-2 h-2; +} + +::-webkit-scrollbar-track { + @apply bg-dark-800 rounded-full; +} + +::-webkit-scrollbar-thumb { + @apply bg-dark-600 rounded-full; +} + +::-webkit-scrollbar-thumb:hover { + @apply bg-gray-600; +} + +/* 渐变文字 */ +.text-gradient { + @apply bg-clip-text text-transparent bg-gradient-to-r from-primary-orange to-red-500; +} + +/* 玻璃拟态效果 */ +.glass { + @apply bg-dark-700/80 backdrop-blur-md border border-dark-500/50; +} + +/* 焦点环 */ +.focus-ring { + @apply focus:outline-none focus:ring-2 focus:ring-primary-orange/50 focus:ring-offset-2 focus:ring-offset-dark-900; +} + +/* ===== 加载动画 ===== */ + +/* 简洁旋转 Loading */ +@keyframes loading-spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +.loading-spin { + animation: loading-spin 1s linear infinite; +} + +/* 三个点脉冲 */ +@keyframes loading-dots { + 0%, 80%, 100% { transform: scale(0.6); opacity: 0.5; } + 40% { transform: scale(1); opacity: 1; } +} + +.loading-dots span { + display: inline-block; + width: 8px; + height: 8px; + margin: 0 4px; + border-radius: 50%; + animation: loading-dots 1.4s ease-in-out infinite both; +} + +.loading-dots span:nth-child(1) { animation-delay: -0.32s; } +.loading-dots span:nth-child(2) { animation-delay: -0.16s; } +.loading-dots span:nth-child(3) { animation-delay: 0s; } + +/* 进度条动画 */ +@keyframes loading-progress { + 0% { width: 0%; } + 50% { width: 70%; } + 100% { width: 100%; } +} + +.loading-progress-bar { + animation: loading-progress 2s ease-in-out infinite; +} + +/* 骨架屏闪烁 */ +@keyframes loading-skeleton { + 0% { opacity: 0.4; } + 50% { opacity: 0.7; } + 100% { opacity: 0.4; } +} + +.loading-skeleton { + animation: loading-skeleton 1.5s ease-in-out infinite; +} diff --git a/web/src/views/Agents.vue b/web/src/views/Agents.vue index 91846c2..d1fefca 100644 --- a/web/src/views/Agents.vue +++ b/web/src/views/Agents.vue @@ -231,7 +231,7 @@ const statusClass = (status: string) => { Agents - @@ -245,7 +245,7 @@ const statusClass = (status: string) => { v-model="searchQuery" type="text" placeholder="Search agents..." - class="w-full bg-dark-600 border border-dark-500 rounded-lg py-2 pl-10 pr-4 text-white placeholder-gray-500 focus:outline-none focus:border-primary-orange" + class="search-input w-full" > @@ -271,7 +271,7 @@ const statusClass = (status: string) => { - +
{{ agent.name }}
{{ agent.description }}
@@ -294,21 +294,21 @@ const statusClass = (status: string) => {
@@ -400,7 +400,7 @@ const statusClass = (status: string) => { -
+
@@ -625,13 +625,13 @@ const statusClass = (status: string) => {