feat: 重构前后端架构,添加Go后端和Python Agent服务
- 新增 Go 语言后端服务(server/),包含用户认证、Agent管理、数据库连接等API - 新增 Python Agent 服务(agent/),实现Agent核心逻辑和工具集 - 前端从原生HTML迁移到Vue.js框架(web/src/) - 添加 Docker Compose 支持(docker-compose.yml) - 添加项目架构文档(docs/ARCHITECTURE.md) - 添加环境变量示例(.env.example)和本地启动脚本(start-local.ps1) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
61
server/internal/config/config.go
Normal file
61
server/internal/config/config.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port string
|
||||
JWTSecret string
|
||||
DatabaseURL string
|
||||
PythonServiceURL string
|
||||
}
|
||||
|
||||
func Load() *Config {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath("./config")
|
||||
viper.AddConfigPath("../config")
|
||||
viper.AddConfigPath("../../config")
|
||||
|
||||
// 默认值
|
||||
viper.SetDefault("port", "8080")
|
||||
viper.SetDefault("jwt_secret", "your-secret-key-change-in-production")
|
||||
viper.SetDefault("python_service_url", "http://localhost:8081")
|
||||
viper.SetDefault("database_url", "root:root@tcp(localhost:3306)/x_agents?charset=utf8mb4&parseTime=True&loc=Local")
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
log.Printf("Using default config: %v", err)
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Port: viper.GetString("port"),
|
||||
JWTSecret: viper.GetString("jwt_secret"),
|
||||
DatabaseURL: viper.GetString("database_url"),
|
||||
PythonServiceURL: viper.GetString("python_service_url"),
|
||||
}
|
||||
}
|
||||
|
||||
func InitDB(cfg *Config) (*gorm.DB, error) {
|
||||
dsn := cfg.DatabaseURL
|
||||
if dsn == "" {
|
||||
return nil, fmt.Errorf("database URL is empty")
|
||||
}
|
||||
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect database: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Database connected successfully")
|
||||
return db, nil
|
||||
}
|
||||
80
server/internal/handler/approval_handler.go
Normal file
80
server/internal/handler/approval_handler.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ApprovalHandler struct {
|
||||
approvalService *service.ApprovalService
|
||||
}
|
||||
|
||||
func NewApprovalHandler(approvalService *service.ApprovalService) *ApprovalHandler {
|
||||
return &ApprovalHandler{approvalService: approvalService}
|
||||
}
|
||||
|
||||
// Approve 处理审批请求
|
||||
func (h *ApprovalHandler) Approve(c *gin.Context) {
|
||||
var req struct {
|
||||
RequestID string `json:"request_id" binding:"required"`
|
||||
Approved bool `json:"approved"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
var result interface{}
|
||||
var err error
|
||||
|
||||
if req.Approved {
|
||||
result, err = h.approvalService.Approve(req.RequestID, userID.(string))
|
||||
} else {
|
||||
result, err = h.approvalService.Reject(req.RequestID, userID.(string))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// GetStatus 获取审批状态
|
||||
func (h *ApprovalHandler) GetStatus(c *gin.Context) {
|
||||
requestID := c.Param("id")
|
||||
|
||||
result, err := h.approvalService.GetApproval(requestID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "request not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// ListPending 获取待审批列表
|
||||
func (h *ApprovalHandler) ListPending(c *gin.Context) {
|
||||
result, err := h.approvalService.GetPendingApprovals()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = []model.ToolApprovalRequest{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"pending": result})
|
||||
}
|
||||
80
server/internal/handler/auth_handler.go
Normal file
80
server/internal/handler/auth_handler.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
authService *service.AuthService
|
||||
}
|
||||
|
||||
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
|
||||
return &AuthHandler{authService: authService}
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
User interface{} `json:"user"`
|
||||
}
|
||||
|
||||
// Login 处理登录
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.authService.Login(service.LoginRequest{
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, LoginResponse{
|
||||
Token: resp.Token,
|
||||
User: gin.H{
|
||||
"id": resp.User.ID,
|
||||
"username": resp.User.Username,
|
||||
"email": resp.User.Email,
|
||||
"role": resp.User.RoleID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Register 处理注册
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
var req struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.Register(req.Username, req.Password, req.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"email": user.Email,
|
||||
})
|
||||
}
|
||||
89
server/internal/handler/chat_handler.go
Normal file
89
server/internal/handler/chat_handler.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ChatHandler struct {
|
||||
chatService *service.ChatService
|
||||
}
|
||||
|
||||
func NewChatHandler(chatService *service.ChatService) *ChatHandler {
|
||||
return &ChatHandler{chatService: chatService}
|
||||
}
|
||||
|
||||
// Chat 处理聊天请求
|
||||
func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
var req model.AgentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取用户ID(由中间件设置)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.chatService.Chat(c.Request.Context(), userID.(string), req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// ListAgents 获取 Agent 列表
|
||||
func (h *ChatHandler) ListAgents(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
agents, err := h.chatService.ListAgents(userID.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if agents == nil {
|
||||
agents = []model.Agent{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"agents": agents})
|
||||
}
|
||||
|
||||
// CreateAgent 创建 Agent
|
||||
func (h *ChatHandler) CreateAgent(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := h.chatService.CreateAgent(userID.(string), req.Name, req.Description)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, agent)
|
||||
}
|
||||
112
server/internal/handler/database_handler.go
Normal file
112
server/internal/handler/database_handler.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DatabaseHandler struct {
|
||||
service *service.DatabaseService
|
||||
}
|
||||
|
||||
func NewDatabaseHandler(svc *service.DatabaseService) *DatabaseHandler {
|
||||
return &DatabaseHandler{service: svc}
|
||||
}
|
||||
|
||||
// Check 检查数据库连接
|
||||
func (h *DatabaseHandler) Check(c *gin.Context) {
|
||||
var req model.CheckRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.service.Check(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// Create 创建数据库信息
|
||||
func (h *DatabaseHandler) Create(c *gin.Context) {
|
||||
var req model.CreateDatabaseRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
info, err := h.service.Create(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, info)
|
||||
}
|
||||
|
||||
// GetByID 获取详情
|
||||
func (h *DatabaseHandler) GetByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
info, err := h.service.GetByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, info)
|
||||
}
|
||||
|
||||
// List 获取列表
|
||||
func (h *DatabaseHandler) List(c *gin.Context) {
|
||||
list, err := h.service.List()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if list == nil {
|
||||
list = []model.DatabaseInfo{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"list": list})
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (h *DatabaseHandler) Update(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req model.UpdateDatabaseRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
info, err := h.service.Update(id, req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, info)
|
||||
}
|
||||
|
||||
// Delete 删除
|
||||
func (h *DatabaseHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.service.Delete(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
||||
}
|
||||
132
server/internal/handler/sub_table_handler.go
Normal file
132
server/internal/handler/sub_table_handler.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type SubTableHandler struct {
|
||||
service *service.SubTableService
|
||||
}
|
||||
|
||||
func NewSubTableHandler(svc *service.SubTableService) *SubTableHandler {
|
||||
return &SubTableHandler{service: svc}
|
||||
}
|
||||
|
||||
// Create 创建子表信息
|
||||
func (h *SubTableHandler) Create(c *gin.Context) {
|
||||
var req model.CreateSubTableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
info, err := h.service.Create(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, info)
|
||||
}
|
||||
|
||||
// GetByID 获取详情
|
||||
func (h *SubTableHandler) GetByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
info, err := h.service.GetByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, info)
|
||||
}
|
||||
|
||||
// ListByDatabase 获取数据库下所有子表
|
||||
func (h *SubTableHandler) ListByDatabase(c *gin.Context) {
|
||||
databaseID := c.Param("database_id")
|
||||
|
||||
list, err := h.service.ListByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if list == nil {
|
||||
list = []model.SubTableInfo{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"list": list})
|
||||
}
|
||||
|
||||
// GetMappingFromFile 从文件获取映射
|
||||
func (h *SubTableHandler) GetMappingFromFile(c *gin.Context) {
|
||||
databaseID := c.Param("database_id")
|
||||
|
||||
mapping, err := h.service.GetMappingFromFile(databaseID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if mapping == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"mapping": nil, "message": "no mapping file found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"mapping": mapping})
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (h *SubTableHandler) Update(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req model.UpdateSubTableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
info, err := h.service.Update(id, req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, info)
|
||||
}
|
||||
|
||||
// Delete 删除
|
||||
func (h *SubTableHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.service.Delete(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
||||
}
|
||||
|
||||
// GetTablesDDL 获取数据库下所有表及DDL
|
||||
func (h *SubTableHandler) GetTablesDDL(c *gin.Context) {
|
||||
databaseID := c.Param("database_id")
|
||||
|
||||
tables, err := h.service.GetTableDDLFromDatabase(databaseID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if tables == nil {
|
||||
tables = []model.TableDDLInfo{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"tables": tables})
|
||||
}
|
||||
62
server/internal/handler/system_handler.go
Normal file
62
server/internal/handler/system_handler.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type SystemHandler struct{}
|
||||
|
||||
func NewSystemHandler() *SystemHandler {
|
||||
return &SystemHandler{}
|
||||
}
|
||||
|
||||
// GetSystemInfo 获取系统信息
|
||||
func (h *SystemHandler) GetSystemInfo(c *gin.Context) {
|
||||
info, err := getSystemInfo()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, info)
|
||||
}
|
||||
|
||||
// getSystemInfo 获取系统信息
|
||||
func getSystemInfo() (*model.SystemInfo, error) {
|
||||
// 获取CPU使用率
|
||||
cpuPercent, err := getCPUPercent()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取CPU核心数
|
||||
coreCount, err := getCPUCoreCount()
|
||||
if err != nil {
|
||||
coreCount = 0
|
||||
}
|
||||
|
||||
// 获取CPU型号
|
||||
modelName, err := getCPUModelName()
|
||||
if err != nil {
|
||||
modelName = "Unknown"
|
||||
}
|
||||
|
||||
// 获取内存信息
|
||||
memoryInfo, err := getMemoryInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.SystemInfo{
|
||||
CPU: model.CPUInfo{
|
||||
Percent: cpuPercent,
|
||||
CoreCount: coreCount,
|
||||
ModelName: modelName,
|
||||
},
|
||||
Memory: *memoryInfo,
|
||||
}, nil
|
||||
}
|
||||
60
server/internal/handler/system_helper.go
Normal file
60
server/internal/handler/system_helper.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
"x-agents/server/internal/model"
|
||||
)
|
||||
|
||||
func getCPUPercent() (float64, error) {
|
||||
percent, err := cpu.Percent(0, false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(percent) > 0 {
|
||||
return percent[0], nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func getCPUCoreCount() (int, error) {
|
||||
count, err := cpu.Counts(false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func getCPUModelName() (string, error) {
|
||||
info, err := cpu.Info()
|
||||
if err != nil {
|
||||
return "Unknown", err
|
||||
}
|
||||
if len(info) > 0 {
|
||||
return info[0].ModelName, nil
|
||||
}
|
||||
return "Unknown", nil
|
||||
}
|
||||
|
||||
func getMemoryInfo() (*model.MemoryInfo, error) {
|
||||
v, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 计算使用率
|
||||
percent := 0.0
|
||||
if v.Total > 0 {
|
||||
percent = float64(v.Used) / float64(v.Total) * 100
|
||||
}
|
||||
|
||||
return &model.MemoryInfo{
|
||||
Total: v.Total,
|
||||
Used: v.Used,
|
||||
Available: v.Available,
|
||||
Percent: percent,
|
||||
TotalGB: float64(v.Total) / 1024 / 1024 / 1024,
|
||||
UsedGB: float64(v.Used) / 1024 / 1024 / 1024,
|
||||
AvailableGB: float64(v.Available) / 1024 / 1024 / 1024,
|
||||
}, nil
|
||||
}
|
||||
71
server/internal/middleware/auth.go
Normal file
71
server/internal/middleware/auth.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Recovery 中间件 - 恢复 panic
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return gin.Recovery()
|
||||
}
|
||||
|
||||
// Auth 认证中间件
|
||||
func Auth(jwtSecret string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从 Header 获取 Token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 Bearer Token
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid authorization format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// 验证 Token
|
||||
authService := service.NewAuthService(jwtSecret, nil)
|
||||
claims, err := authService.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存入上下文
|
||||
c.Set("user_id", claims["sub"])
|
||||
c.Set("username", claims["username"])
|
||||
c.Set("role", claims["role"])
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
53
server/internal/model/agent.go
Normal file
53
server/internal/model/agent.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityLevel 安全等级
|
||||
type SecurityLevel string
|
||||
|
||||
const (
|
||||
SecurityLevelSafe SecurityLevel = "safe"
|
||||
SecurityLevelReview SecurityLevel = "review"
|
||||
SecurityLevelDanger SecurityLevel = "danger"
|
||||
)
|
||||
|
||||
// Agent 智能体
|
||||
type Agent struct {
|
||||
ID string `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name" gorm:"size:100;not null"`
|
||||
Description string `json:"description" gorm:"type:text"`
|
||||
OwnerID string `json:"owner_id" gorm:"size:50;not null;index"`
|
||||
|
||||
// Agent能力配置
|
||||
Capabilities []string `json:"capabilities" gorm:"type:text"` // JSON数组,可用工具列表
|
||||
MemoryLimit int64 `json:"memory_limit" gorm:"default:134217728"` // 128MB
|
||||
Timeout int `json:"timeout" gorm:"default:60"` // 60秒
|
||||
|
||||
// 安全配置
|
||||
SecurityLevel SecurityLevel `json:"security_level" gorm:"size:20;default:'safe'"`
|
||||
AllowDangerousTools bool `json:"allow_dangerous_tools" gorm:"default:false"`
|
||||
|
||||
// 状态
|
||||
IsActive bool `json:"is_active" gorm:"default:true"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AgentRequest 聊天请求
|
||||
type AgentRequest struct {
|
||||
AgentID string `json:"agent_id" binding:"required"`
|
||||
Message string `json:"message" binding:"required"`
|
||||
SessionID string `json:"session_id"`
|
||||
Context map[string]interface{} `json:"context"`
|
||||
}
|
||||
|
||||
// AgentResponse 聊天响应
|
||||
type AgentResponse struct {
|
||||
Reply string `json:"reply"`
|
||||
SessionID string `json:"session_id"`
|
||||
ToolsUsed []string `json:"tools_used"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
}
|
||||
76
server/internal/model/audit.go
Normal file
76
server/internal/model/audit.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditAction 审计动作
|
||||
type AuditAction string
|
||||
|
||||
const (
|
||||
AuditActionLogin AuditAction = "login"
|
||||
AuditActionLogout AuditAction = "logout"
|
||||
AuditActionChat AuditAction = "chat"
|
||||
AuditActionToolExecute AuditAction = "tool_execute"
|
||||
AuditActionToolApprove AuditAction = "tool_approve"
|
||||
AuditActionToolReject AuditAction = "tool_reject"
|
||||
AuditActionAgentCreate AuditAction = "agent_create"
|
||||
AuditActionAgentUpdate AuditAction = "agent_update"
|
||||
AuditActionAgentDelete AuditAction = "agent_delete"
|
||||
)
|
||||
|
||||
// AuditLog 审计日志
|
||||
type AuditLog struct {
|
||||
ID string `json:"id" gorm:"primaryKey"`
|
||||
UserID string `json:"user_id" gorm:"size:50;index"`
|
||||
AgentID string `json:"agent_id" gorm:"size:50;index"`
|
||||
Action AuditAction `json:"action" gorm:"size:50;index"`
|
||||
Details JSONMap `json:"details" gorm:"type:jsonb"`
|
||||
Result string `json:"result" gorm:"size:20"` // success, failed, rejected
|
||||
IPAddress string `json:"ip_address" gorm:"size:45"`
|
||||
UserAgent string `json:"user_agent" gorm:"size:255"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"index"`
|
||||
}
|
||||
|
||||
// ApprovalStatus 审批状态
|
||||
type ApprovalStatus string
|
||||
|
||||
const (
|
||||
ApprovalStatusPending ApprovalStatus = "pending"
|
||||
ApprovalStatusApproved ApprovalStatus = "approved"
|
||||
ApprovalStatusRejected ApprovalStatus = "rejected"
|
||||
)
|
||||
|
||||
// ToolApprovalRequest 工具审批请求
|
||||
type ToolApprovalRequest struct {
|
||||
ID string `json:"id" gorm:"primaryKey"`
|
||||
ToolName string `json:"tool_name" gorm:"size:100;index"`
|
||||
Params JSONMap `json:"params" gorm:"type:jsonb"`
|
||||
UserID string `json:"user_id" gorm:"size:50;index"`
|
||||
AgentID string `json:"agent_id" gorm:"size:50"`
|
||||
Reason string `json:"reason" gorm:"type:text"`
|
||||
Status ApprovalStatus `json:"status" gorm:"size:20;default:'pending';index"`
|
||||
ReviewedBy *string `json:"reviewed_by" gorm:"size:50"`
|
||||
ReviewedAt *time.Time `json:"reviewed_at"`
|
||||
Result *string `json:"result" gorm:"type:text"` // 执行结果
|
||||
CreatedAt time.Time `json:"created_at" gorm:"index"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// JSONMap JSON数据映射
|
||||
type JSONMap map[string]interface{}
|
||||
|
||||
func (j JSONMap) MarshalJSON() ([]byte, error) {
|
||||
if j == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (j *JSONMap) UnmarshalJSON(data []byte) error {
|
||||
if j == nil {
|
||||
*j = make(map[string]interface{})
|
||||
}
|
||||
return json.Unmarshal(data, j)
|
||||
}
|
||||
83
server/internal/model/database_info.go
Normal file
83
server/internal/model/database_info.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DatabaseInfo 数据库连接信息
|
||||
type DatabaseInfo struct {
|
||||
ID string `json:"id" gorm:"primaryKey;size:36"` // UUID
|
||||
Name string `json:"name" gorm:"size:100;not null"` // 数据库名称
|
||||
Description string `json:"description" gorm:"size:500"` // 描述
|
||||
DBType string `json:"db_type" gorm:"size:20;not null"` // 数据库类型: mysql, postgres, mongodb等
|
||||
Host string `json:"host" gorm:"size:255;not null"` // 主机地址
|
||||
Port int `json:"port" gorm:"not null"` // 端口
|
||||
Username string `json:"username" gorm:"size:100;not null"` // 用户名
|
||||
Password string `json:"password" gorm:"size:255"` // 密码(建议加密存储)
|
||||
Database string `json:"database" gorm:"size:100"` // 数据库名
|
||||
TableCount int `json:"table_count" gorm:"default:0"` // 子表数量
|
||||
|
||||
// 连接选项
|
||||
Charset string `json:"charset" gorm:"size:20;default:utf8mb4"` // 字符集
|
||||
SSLMode string `json:"ssl_mode" gorm:"size:20"` // SSL模式
|
||||
|
||||
// 时间
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (DatabaseInfo) TableName() string {
|
||||
return "database_info"
|
||||
}
|
||||
|
||||
// CreateRequest 创建数据库信息请求(支持同时保存子表配置)
|
||||
type CreateDatabaseRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
DBType string `json:"db_type" binding:"required"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required"`
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
Charset string `json:"charset"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
SubTables []CreateSubTableRequest `json:"sub_tables"` // 可选,子表配置
|
||||
}
|
||||
|
||||
// UpdateRequest 更新数据库信息请求
|
||||
type UpdateDatabaseRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
DBType string `json:"db_type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
TableCount int `json:"table_count"`
|
||||
Charset string `json:"charset"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
}
|
||||
|
||||
// CheckRequest 检查连接请求
|
||||
type CheckRequest struct {
|
||||
DBType string `json:"db_type" binding:"required"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required"`
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
Charset string `json:"charset"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
DatabaseID string `json:"database_id"` // 可选,用于获取已保存的字段映射
|
||||
}
|
||||
|
||||
// CheckResponse 检查连接响应
|
||||
type CheckResponse struct {
|
||||
Success bool `json:"success"` // 是否连接成功
|
||||
Message string `json:"message"` // 消息
|
||||
Tables []TableDDLInfo `json:"tables,omitempty"` // 表列表(连接成功时返回)
|
||||
Database string `json:"database"` // 数据库名
|
||||
}
|
||||
117
server/internal/model/sub_table_info.go
Normal file
117
server/internal/model/sub_table_info.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TableDDLInfo 表结构信息
|
||||
type TableDDLInfo struct {
|
||||
TableName string `json:"table_name"` // 表名
|
||||
TableComment string `json:"table_comment"` // 表注释
|
||||
Columns []ColumnInfo `json:"columns"` // 列信息
|
||||
DDL string `json:"ddl"` // 建表DDL
|
||||
Indexes []IndexInfo `json:"indexes"` // 索引信息
|
||||
}
|
||||
|
||||
// ColumnInfo 列信息
|
||||
type ColumnInfo struct {
|
||||
ColumnName string `json:"column_name"` // 列名
|
||||
DataType string `json:"data_type"` // 数据类型
|
||||
ColumnType string `json:"column_type"` // 列类型(含长度)
|
||||
IsNullable string `json:"is_nullable"` // 是否可空
|
||||
DefaultValue string `json:"default_value"` // 默认值
|
||||
ColumnKey string `json:"column_key"` // 主键/索引
|
||||
Extra string `json:"extra"` // 自增等
|
||||
ColumnComment string `json:"column_comment"` // 列注释
|
||||
MappedName string `json:"mapped_name"` // 字段中文映射名
|
||||
}
|
||||
|
||||
// IndexInfo 索引信息
|
||||
type IndexInfo struct {
|
||||
IndexName string `json:"index_name"` // 索引名
|
||||
ColumnName string `json:"column_name"` // 列名
|
||||
NonUnique int `json:"non_unique"` // 是否唯一
|
||||
IndexType string `json:"index_type"` // 索引类型
|
||||
}
|
||||
|
||||
// SubTableInfo 子表信息
|
||||
type SubTableInfo struct {
|
||||
ID string `json:"id"` // UUID
|
||||
DatabaseID string `json:"database_id"` // 关联的数据库ID
|
||||
ParentTable string `json:"parent_table"` // 父表名
|
||||
SubTableName string `json:"sub_table_name"` // 子表名
|
||||
SubTableComment string `json:"sub_table_comment"` // 子表注释
|
||||
MappingType string `json:"mapping_type" gorm:"type:varchar(20)"` // 映射类型
|
||||
RelationField string `json:"relation_field" gorm:"type:varchar(100)"` // 关联字段
|
||||
RelationType string `json:"relation_type" gorm:"type:varchar(20)"` // 关联类型
|
||||
Fields string `json:"-" gorm:"type:longtext"` // 字段映射列表(JSON 格式,内部存储)
|
||||
FieldsList []FieldMapping `json:"fields" gorm:"-"` // 字段映射列表(返回给前端)
|
||||
DDL string `json:"ddl" gorm:"type:longtext"` // 建表 DDL
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// FieldMapping 字段映射
|
||||
type FieldMapping struct {
|
||||
ColumnName string `json:"column_name"` // 列名
|
||||
MappedName string `json:"mapped_name"` // 中文映射名
|
||||
}
|
||||
|
||||
// GetFields 获取字段映射列表
|
||||
func (s *SubTableInfo) GetFields() []FieldMapping {
|
||||
if s.Fields == "" {
|
||||
return nil
|
||||
}
|
||||
var fields []FieldMapping
|
||||
if err := json.Unmarshal([]byte(s.Fields), &fields); err != nil {
|
||||
return nil
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// SetFields 设置字段映射列表
|
||||
func (s *SubTableInfo) SetFields(fields []FieldMapping) {
|
||||
if len(fields) == 0 {
|
||||
s.Fields = ""
|
||||
return
|
||||
}
|
||||
data, _ := json.Marshal(fields)
|
||||
s.Fields = string(data)
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (SubTableInfo) TableName() string {
|
||||
return "sub_table_info"
|
||||
}
|
||||
|
||||
// CreateSubTableRequest 创建子表请求
|
||||
type CreateSubTableRequest struct {
|
||||
DatabaseID string `json:"database_id" binding:"required"`
|
||||
ParentTable string `json:"parent_table" binding:"required"`
|
||||
SubTableName string `json:"sub_table_name" binding:"required"`
|
||||
SubTableComment string `json:"sub_table_comment"`
|
||||
MappingType string `json:"mapping_type"`
|
||||
RelationField string `json:"relation_field"`
|
||||
RelationType string `json:"relation_type"`
|
||||
Fields []FieldMapping `json:"fields"` // 字段映射列表
|
||||
}
|
||||
|
||||
// UpdateSubTableRequest 更新子表请求
|
||||
type UpdateSubTableRequest struct {
|
||||
ParentTable string `json:"parent_table"`
|
||||
SubTableName string `json:"sub_table_name"`
|
||||
SubTableComment string `json:"sub_table_comment"`
|
||||
MappingType string `json:"mapping_type"`
|
||||
RelationField string `json:"relation_field"`
|
||||
RelationType string `json:"relation_type"`
|
||||
}
|
||||
|
||||
// SubTableMapping 完整的子表映射配置(存储到文件的格式)
|
||||
type SubTableMapping struct {
|
||||
DatabaseID string `json:"database_id"`
|
||||
DatabaseName string `json:"database_name"`
|
||||
DBType string `json:"db_type"`
|
||||
Tables []SubTableInfo `json:"tables"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
25
server/internal/model/system_info.go
Normal file
25
server/internal/model/system_info.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package model
|
||||
|
||||
// SystemInfo 系统信息
|
||||
type SystemInfo struct {
|
||||
CPU CPUInfo `json:"cpu"`
|
||||
Memory MemoryInfo `json:"memory"`
|
||||
}
|
||||
|
||||
// CPUInfo CPU信息
|
||||
type CPUInfo struct {
|
||||
Percent float64 `json:"percent"` // CPU使用率
|
||||
CoreCount int `json:"core_count"` // 核心数
|
||||
ModelName string `json:"model_name"` // CPU型号
|
||||
}
|
||||
|
||||
// MemoryInfo 内存信息
|
||||
type MemoryInfo struct {
|
||||
Total uint64 `json:"total"` // 总内存(字节)
|
||||
Used uint64 `json:"used"` // 已使用(字节)
|
||||
Available uint64 `json:"available"` // 可用(字节)
|
||||
Percent float64 `json:"percent"` // 使用率
|
||||
TotalGB float64 `json:"total_gb"` // 总内存(GB)
|
||||
UsedGB float64 `json:"used_gb"` // 已使用(GB)
|
||||
AvailableGB float64 `json:"available_gb"` // 可用(GB)
|
||||
}
|
||||
50
server/internal/model/user.go
Normal file
50
server/internal/model/user.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// PermissionLevel 权限级别
|
||||
type PermissionLevel int
|
||||
|
||||
const (
|
||||
PermissionRead PermissionLevel = iota + 1
|
||||
PermissionWrite
|
||||
PermissionExecute
|
||||
PermissionAdmin
|
||||
)
|
||||
|
||||
// Role 角色
|
||||
type Role struct {
|
||||
ID string `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name" gorm:"uniqueIndex"`
|
||||
Permissions []PermissionLevel `json:"permissions" gorm:"type:int[]"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// User 用户
|
||||
type User struct {
|
||||
ID string `json:"id" gorm:"primaryKey"`
|
||||
Username string `json:"username" gorm:"uniqueIndex;size:50;not null"`
|
||||
Password string `json:"-" gorm:"not null"`
|
||||
Email string `json:"email" gorm:"index"`
|
||||
RoleID string `json:"role_id" gorm:"size:50;not null"`
|
||||
Role *Role `json:"role,omitempty" gorm:"foreignKey:RoleID"`
|
||||
IsActive bool `json:"is_active" gorm:"default:true"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// HasPermission 检查是否有权限
|
||||
func (u *User) HasPermission(level PermissionLevel) bool {
|
||||
if u.Role == nil {
|
||||
return false
|
||||
}
|
||||
for _, p := range u.Role.Permissions {
|
||||
if p >= level {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
48
server/internal/repository/agent_repo.go
Normal file
48
server/internal/repository/agent_repo.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"x-agents/server/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AgentRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAgentRepository(db *gorm.DB) *AgentRepository {
|
||||
return &AgentRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *AgentRepository) Create(agent *model.Agent) error {
|
||||
return r.db.Create(agent).Error
|
||||
}
|
||||
|
||||
func (r *AgentRepository) FindByID(id string) (*model.Agent, error) {
|
||||
var agent model.Agent
|
||||
err := r.db.First(&agent, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &agent, nil
|
||||
}
|
||||
|
||||
func (r *AgentRepository) FindByOwnerID(ownerID string) ([]model.Agent, error) {
|
||||
var agents []model.Agent
|
||||
err := r.db.Where("owner_id = ?", ownerID).Find(&agents).Error
|
||||
return agents, err
|
||||
}
|
||||
|
||||
func (r *AgentRepository) FindAll() ([]model.Agent, error) {
|
||||
var agents []model.Agent
|
||||
err := r.db.Where("is_active = ?", true).Find(&agents).Error
|
||||
return agents, err
|
||||
}
|
||||
|
||||
func (r *AgentRepository) Update(agent *model.Agent) error {
|
||||
return r.db.Save(agent).Error
|
||||
}
|
||||
|
||||
func (r *AgentRepository) Delete(id string) error {
|
||||
return r.db.Delete(&model.Agent{}, "id = ?", id).Error
|
||||
}
|
||||
56
server/internal/repository/audit_repo.go
Normal file
56
server/internal/repository/audit_repo.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"x-agents/server/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AuditRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAuditRepository(db *gorm.DB) *AuditRepository {
|
||||
return &AuditRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *AuditRepository) Create(log *model.AuditLog) error {
|
||||
return r.db.Create(log).Error
|
||||
}
|
||||
|
||||
func (r *AuditRepository) FindByUserID(userID string, limit int) ([]model.AuditLog, error) {
|
||||
var logs []model.AuditLog
|
||||
err := r.db.Where("user_id = ?", userID).Order("created_at DESC").Limit(limit).Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
func (r *AuditRepository) FindByAgentID(agentID string, limit int) ([]model.AuditLog, error) {
|
||||
var logs []model.AuditLog
|
||||
err := r.db.Where("agent_id = ?", agentID).Order("created_at DESC").Limit(limit).Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
// ToolApproval 工具审批仓储
|
||||
|
||||
func (r *AuditRepository) CreateApproval(req *model.ToolApprovalRequest) error {
|
||||
return r.db.Create(req).Error
|
||||
}
|
||||
|
||||
func (r *AuditRepository) FindApprovalByID(id string) (*model.ToolApprovalRequest, error) {
|
||||
var req model.ToolApprovalRequest
|
||||
err := r.db.First(&req, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func (r *AuditRepository) FindPendingApprovals() ([]model.ToolApprovalRequest, error) {
|
||||
var reqs []model.ToolApprovalRequest
|
||||
err := r.db.Where("status = ?", model.ApprovalStatusPending).Order("created_at ASC").Find(&reqs).Error
|
||||
return reqs, err
|
||||
}
|
||||
|
||||
func (r *AuditRepository) UpdateApproval(req *model.ToolApprovalRequest) error {
|
||||
return r.db.Save(req).Error
|
||||
}
|
||||
47
server/internal/repository/database_repo.go
Normal file
47
server/internal/repository/database_repo.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"x-agents/server/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DatabaseRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewDatabaseRepository(db *gorm.DB) *DatabaseRepository {
|
||||
return &DatabaseRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建数据库信息
|
||||
func (r *DatabaseRepository) Create(info *model.DatabaseInfo) error {
|
||||
return r.db.Create(info).Error
|
||||
}
|
||||
|
||||
// FindByID 根据ID查询
|
||||
func (r *DatabaseRepository) FindByID(id string) (*model.DatabaseInfo, error) {
|
||||
var info model.DatabaseInfo
|
||||
err := r.db.First(&info, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// FindAll 查询所有
|
||||
func (r *DatabaseRepository) FindAll() ([]model.DatabaseInfo, error) {
|
||||
var list []model.DatabaseInfo
|
||||
err := r.db.Order("created_at DESC").Find(&list).Error
|
||||
return list, err
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (r *DatabaseRepository) Update(id string, info *model.DatabaseInfo) error {
|
||||
return r.db.Model(&model.DatabaseInfo{}).Where("id = ?", id).Updates(info).Error
|
||||
}
|
||||
|
||||
// Delete 删除
|
||||
func (r *DatabaseRepository) Delete(id string) error {
|
||||
return r.db.Delete(&model.DatabaseInfo{}, "id = ?", id).Error
|
||||
}
|
||||
53
server/internal/repository/sub_table_repo.go
Normal file
53
server/internal/repository/sub_table_repo.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"x-agents/server/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SubTableRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewSubTableRepository(db *gorm.DB) *SubTableRepository {
|
||||
return &SubTableRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建子表信息
|
||||
func (r *SubTableRepository) Create(info *model.SubTableInfo) error {
|
||||
return r.db.Create(info).Error
|
||||
}
|
||||
|
||||
// FindByID 根据ID查询
|
||||
func (r *SubTableRepository) FindByID(id string) (*model.SubTableInfo, error) {
|
||||
var info model.SubTableInfo
|
||||
if err := r.db.Where("id = ?", id).First(&info).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// FindByDatabaseID 根据数据库ID查询所有子表
|
||||
func (r *SubTableRepository) FindByDatabaseID(databaseID string) ([]model.SubTableInfo, error) {
|
||||
var list []model.SubTableInfo
|
||||
if err := r.db.Where("database_id = ?", databaseID).Find(&list).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// Update 更新子表信息
|
||||
func (r *SubTableRepository) Update(id string, info *model.SubTableInfo) error {
|
||||
return r.db.Model(info).Where("id = ?", id).Updates(info).Error
|
||||
}
|
||||
|
||||
// Delete 删除子表信息
|
||||
func (r *SubTableRepository) Delete(id string) error {
|
||||
return r.db.Where("id = ?", id).Delete(&model.SubTableInfo{}).Error
|
||||
}
|
||||
|
||||
// DeleteByDatabaseID 删除数据库下所有子表信息
|
||||
func (r *SubTableRepository) DeleteByDatabaseID(databaseID string) error {
|
||||
return r.db.Where("database_id = ?", databaseID).Delete(&model.SubTableInfo{}).Error
|
||||
}
|
||||
66
server/internal/repository/user_repo.go
Normal file
66
server/internal/repository/user_repo.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"x-agents/server/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *UserRepository) Create(user *model.User) error {
|
||||
return r.db.Create(user).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) FindByID(id string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Preload("Role").First(&user, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) FindByUsername(username string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Preload("Role").First(&user, "username = ?", username).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) FindAll() ([]model.User, error) {
|
||||
var users []model.User
|
||||
err := r.db.Preload("Role").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *UserRepository) Update(user *model.User) error {
|
||||
return r.db.Save(user).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) Delete(id string) error {
|
||||
return r.db.Delete(&model.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// FindRoleByID 根据ID查找角色
|
||||
func (r *UserRepository) FindRoleByID(id string) (*model.Role, error) {
|
||||
var role model.Role
|
||||
err := r.db.First(&role, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// CreateRole 创建角色
|
||||
func (r *UserRepository) CreateRole(role *model.Role) error {
|
||||
return r.db.Create(role).Error
|
||||
}
|
||||
101
server/internal/service/approval_service.go
Normal file
101
server/internal/service/approval_service.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ApprovalService struct {
|
||||
auditRepo *repository.AuditRepository
|
||||
}
|
||||
|
||||
func NewApprovalService(auditRepo *repository.AuditRepository) *ApprovalService {
|
||||
return &ApprovalService{auditRepo: auditRepo}
|
||||
}
|
||||
|
||||
// CreateApprovalRequest 创建审批请求
|
||||
func (s *ApprovalService) CreateApprovalRequest(
|
||||
toolName string,
|
||||
params map[string]interface{},
|
||||
userID string,
|
||||
agentID string,
|
||||
reason string,
|
||||
) (*model.ToolApprovalRequest, error) {
|
||||
|
||||
req := &model.ToolApprovalRequest{
|
||||
ID: uuid.New().String(),
|
||||
ToolName: toolName,
|
||||
Params: params,
|
||||
UserID: userID,
|
||||
AgentID: agentID,
|
||||
Reason: reason,
|
||||
Status: model.ApprovalStatusPending,
|
||||
}
|
||||
|
||||
if err := s.auditRepo.CreateApproval(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Approve 批准请求
|
||||
func (s *ApprovalService) Approve(requestID, reviewedBy string) (*model.ToolApprovalRequest, error) {
|
||||
req, err := s.auditRepo.FindApprovalByID(requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request not found: %w", err)
|
||||
}
|
||||
|
||||
if req.Status != model.ApprovalStatusPending {
|
||||
return nil, fmt.Errorf("request already processed")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
req.Status = model.ApprovalStatusApproved
|
||||
req.ReviewedBy = &reviewedBy
|
||||
req.ReviewedAt = &now
|
||||
|
||||
if err := s.auditRepo.UpdateApproval(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Reject 拒绝请求
|
||||
func (s *ApprovalService) Reject(requestID, reviewedBy string) (*model.ToolApprovalRequest, error) {
|
||||
req, err := s.auditRepo.FindApprovalByID(requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request not found: %w", err)
|
||||
}
|
||||
|
||||
if req.Status != model.ApprovalStatusPending {
|
||||
return nil, fmt.Errorf("request already processed")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
req.Status = model.ApprovalStatusRejected
|
||||
req.ReviewedBy = &reviewedBy
|
||||
req.ReviewedAt = &now
|
||||
|
||||
if err := s.auditRepo.UpdateApproval(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// GetApproval 获取审批状态
|
||||
func (s *ApprovalService) GetApproval(requestID string) (*model.ToolApprovalRequest, error) {
|
||||
return s.auditRepo.FindApprovalByID(requestID)
|
||||
}
|
||||
|
||||
// GetPendingApprovals 获取待审批列表
|
||||
func (s *ApprovalService) GetPendingApprovals() ([]model.ToolApprovalRequest, error) {
|
||||
return s.auditRepo.FindPendingApprovals()
|
||||
}
|
||||
145
server/internal/service/auth_service.go
Normal file
145
server/internal/service/auth_service.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/repository"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
type AuthService struct {
|
||||
jwtSecret string
|
||||
userRepo *repository.UserRepository
|
||||
}
|
||||
|
||||
func NewAuthService(jwtSecret string, userRepo *repository.UserRepository) *AuthService {
|
||||
return &AuthService{
|
||||
jwtSecret: jwtSecret,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
User *model.User `json:"user"`
|
||||
}
|
||||
|
||||
func (s *AuthService) Login(req LoginRequest) (*LoginResponse, error) {
|
||||
// 查找用户
|
||||
user, err := s.userRepo.FindByUsername(req.Username)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// 生成Token
|
||||
token, err := s.generateToken(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LoginResponse{
|
||||
Token: token,
|
||||
User: user,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) generateToken(user *model.User) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"username": user.Username,
|
||||
"role": user.RoleID,
|
||||
"exp": time.Now().Add(time.Hour * 24 * 7).Unix(), // 7天有效期
|
||||
"iat": time.Now().Unix(),
|
||||
"expires_at": time.Now().Add(time.Hour * 24 * 7).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.jwtSecret))
|
||||
}
|
||||
|
||||
func (s *AuthService) ValidateToken(tokenString string) (jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
return []byte(s.jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
func (s *AuthService) Register(username, password, email string) (*model.User, error) {
|
||||
// 检查用户是否已存在
|
||||
_, err := s.userRepo.FindByUsername(username)
|
||||
if err == nil {
|
||||
return nil, errors.New("user already exists")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
ID: uuid.New().String(),
|
||||
Username: username,
|
||||
Password: string(hashedPassword),
|
||||
Email: email,
|
||||
RoleID: "user",
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
// 如果没有用户,创建默认管理员角色
|
||||
role, err := s.userRepo.FindRoleByID(user.RoleID)
|
||||
if err != nil {
|
||||
// 创建默认角色
|
||||
role = &model.Role{
|
||||
ID: "user",
|
||||
Name: "user",
|
||||
Permissions: []model.PermissionLevel{model.PermissionRead, model.PermissionWrite},
|
||||
}
|
||||
s.userRepo.CreateRole(role)
|
||||
user.Role = role
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
func (s *AuthService) GetUserByID(id string) (*model.User, error) {
|
||||
return s.userRepo.FindByID(id)
|
||||
}
|
||||
146
server/internal/service/chat_service.go
Normal file
146
server/internal/service/chat_service.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ChatService struct {
|
||||
pythonURL string
|
||||
agentRepo *repository.AgentRepository
|
||||
}
|
||||
|
||||
func NewChatService(pythonURL string, agentRepo *repository.AgentRepository) *ChatService {
|
||||
return &ChatService{
|
||||
pythonURL: pythonURL,
|
||||
agentRepo: agentRepo,
|
||||
}
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Message string `json:"message"`
|
||||
SessionID string `json:"session_id"`
|
||||
Context map[string]interface{} `json:"context"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Reply string `json:"reply"`
|
||||
SessionID string `json:"session_id"`
|
||||
ToolsUsed []string `json:"tools_used"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
}
|
||||
|
||||
// Chat 处理聊天请求
|
||||
func (s *ChatService) Chat(ctx context.Context, userID string, req model.AgentRequest) (*model.AgentResponse, error) {
|
||||
// 1. 检查 Agent 是否存在
|
||||
agent, err := s.agentRepo.FindByID(req.AgentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("agent not found: %w", err)
|
||||
}
|
||||
|
||||
// 2. 检查用户权限
|
||||
if !agent.IsActive {
|
||||
return nil, fmt.Errorf("agent is not active")
|
||||
}
|
||||
|
||||
// 3. 生成会话ID
|
||||
sessionID := req.SessionID
|
||||
if sessionID == "" {
|
||||
sessionID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 4. 调用 Python 服务
|
||||
pythonReq := ChatRequest{
|
||||
AgentID: req.AgentID,
|
||||
Message: req.Message,
|
||||
SessionID: sessionID,
|
||||
Context: req.Context,
|
||||
}
|
||||
|
||||
pythonResp, err := s.callPythonChat(ctx, pythonReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call python service: %w", err)
|
||||
}
|
||||
|
||||
return &model.AgentResponse{
|
||||
Reply: pythonResp.Reply,
|
||||
SessionID: pythonResp.SessionID,
|
||||
ToolsUsed: pythonResp.ToolsUsed,
|
||||
Metadata: pythonResp.Metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ChatService) callPythonChat(ctx context.Context, req ChatRequest) (*ChatResponse, error) {
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
"POST",
|
||||
s.pythonURL+"/agent/chat",
|
||||
bytes.NewBuffer(jsonData),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second, // Agent 可能需要较长时间
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("python service returned status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var chatResp ChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &chatResp, nil
|
||||
}
|
||||
|
||||
// ListAgents 获取用户可用的 Agent 列表
|
||||
func (s *ChatService) ListAgents(userID string) ([]model.Agent, error) {
|
||||
return s.agentRepo.FindByOwnerID(userID)
|
||||
}
|
||||
|
||||
// CreateAgent 创建新的 Agent
|
||||
func (s *ChatService) CreateAgent(userID string, name, description string) (*model.Agent, error) {
|
||||
agent := &model.Agent{
|
||||
ID: uuid.New().String(),
|
||||
Name: name,
|
||||
Description: description,
|
||||
OwnerID: userID,
|
||||
SecurityLevel: model.SecurityLevelSafe,
|
||||
IsActive: true,
|
||||
Timeout: 60,
|
||||
MemoryLimit: 134217728, // 128MB
|
||||
}
|
||||
|
||||
if err := s.agentRepo.Create(agent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
}
|
||||
765
server/internal/service/database_service.go
Normal file
765
server/internal/service/database_service.go
Normal file
@@ -0,0 +1,765 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatabaseNotFound = errors.New("database not found")
|
||||
ErrDatabaseUnreachable = errors.New("database cannot be connected")
|
||||
)
|
||||
|
||||
type DatabaseService struct {
|
||||
repo *repository.DatabaseRepository
|
||||
subTableRepo *repository.SubTableRepository
|
||||
}
|
||||
|
||||
func NewDatabaseService(repo *repository.DatabaseRepository, subTableRepo *repository.SubTableRepository) *DatabaseService {
|
||||
return &DatabaseService{
|
||||
repo: repo,
|
||||
subTableRepo: subTableRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnection 测试数据库连通性
|
||||
func (s *DatabaseService) TestConnection(info *model.DatabaseInfo) error {
|
||||
log.Printf("[数据库连接测试] 开始测试连接: 类型=%s, 主机=%s, 端口=%d, 数据库=%s, 用户=%s",
|
||||
info.DBType, info.Host, info.Port, info.Database, info.Username)
|
||||
|
||||
// 统一转换为小写处理
|
||||
dbType := strings.ToLower(info.DBType)
|
||||
|
||||
// 构建连接字符串
|
||||
dsn := s.buildDSN(info)
|
||||
log.Printf("[数据库连接测试] DSN构建完成: %s", dsn)
|
||||
|
||||
// 设置超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 根据数据库类型连接
|
||||
var db *sql.DB
|
||||
var err error
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
db, err = sql.Open("mysql", dsn)
|
||||
case "postgres", "postgresql":
|
||||
db, err = sql.Open("postgres", dsn)
|
||||
default:
|
||||
errMsg := fmt.Sprintf("unsupported database type: %s", info.DBType)
|
||||
log.Printf("[数据库连接测试] 错误: %s", errMsg)
|
||||
return fmt.Errorf(errMsg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("failed to create connection: %v", err)
|
||||
log.Printf("[数据库连接测试] 错误: %s", errMsg)
|
||||
return fmt.Errorf(errMsg)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 测试连接
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
errMsg := fmt.Sprintf("cannot connect to database: %v", err)
|
||||
log.Printf("[数据库连接测试] 连接失败: %s", errMsg)
|
||||
return fmt.Errorf(errMsg)
|
||||
}
|
||||
|
||||
log.Printf("[数据库连接测试] 连接成功!")
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDSN 构建数据库连接字符串
|
||||
func (s *DatabaseService) buildDSN(info *model.DatabaseInfo) string {
|
||||
dbType := strings.ToLower(info.DBType)
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
charset := info.Charset
|
||||
if charset == "" {
|
||||
charset = "utf8mb4"
|
||||
}
|
||||
// 如果没有指定数据库名,只测试连接
|
||||
dbName := info.Database
|
||||
if dbName == "" {
|
||||
dbName = "mysql"
|
||||
}
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&timeout=5s&parseTime=True",
|
||||
info.Username,
|
||||
info.Password,
|
||||
info.Host,
|
||||
info.Port,
|
||||
dbName,
|
||||
charset,
|
||||
)
|
||||
case "postgres", "postgresql":
|
||||
sslmode := "disable"
|
||||
if info.SSLMode != "" {
|
||||
sslmode = info.SSLMode
|
||||
}
|
||||
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5",
|
||||
info.Host,
|
||||
info.Port,
|
||||
info.Username,
|
||||
info.Password,
|
||||
info.Database,
|
||||
sslmode,
|
||||
)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// getConnection 获取数据库连接
|
||||
func (s *DatabaseService) getConnection(info *model.DatabaseInfo) (*sql.DB, error) {
|
||||
dsn := s.buildDSN(info)
|
||||
dbType := strings.ToLower(info.DBType)
|
||||
|
||||
var db *sql.DB
|
||||
var err error
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
db, err = sql.Open("mysql", dsn)
|
||||
case "postgres", "postgresql":
|
||||
db, err = sql.Open("postgres", dsn)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// getTableDDL 获取表的 DDL
|
||||
func (s *DatabaseService) getTableDDL(db *sql.DB, dbType, tableName string) (string, error) {
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
query := fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName)
|
||||
row := db.QueryRow(query)
|
||||
var tblName, createStmt string
|
||||
if err := row.Scan(&tblName, &createStmt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return createStmt, nil
|
||||
case "postgres", "postgresql":
|
||||
query := fmt.Sprintf("SELECT pg_get_create('%s')", tableName)
|
||||
var ddl string
|
||||
if err := db.QueryRow(query).Scan(&ddl); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ddl, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported database type: %s", dbType)
|
||||
}
|
||||
}
|
||||
|
||||
// buildMappedDDL 根据字段映射生成带 COMMENT 的 DDL
|
||||
func (s *DatabaseService) buildMappedDDL(originalDDL string, fields []model.FieldMapping) string {
|
||||
// 构建列名到映射名的映射
|
||||
columnMap := make(map[string]string)
|
||||
for _, f := range fields {
|
||||
if f.MappedName != "" {
|
||||
columnMap[f.ColumnName] = f.MappedName
|
||||
}
|
||||
}
|
||||
|
||||
if len(columnMap) == 0 {
|
||||
return originalDDL
|
||||
}
|
||||
|
||||
// 解析原始 DDL,为有映射的列添加 COMMENT
|
||||
lines := strings.Split(originalDDL, "\n")
|
||||
var resultLines []string
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// 检查是否是列定义行(以 ` 开头,包含数据类型)
|
||||
if strings.HasPrefix(trimmed, "`") {
|
||||
// 提取列名
|
||||
parts := strings.SplitN(trimmed, " ", 2)
|
||||
if len(parts) >= 1 {
|
||||
colName := strings.Trim(parts[0], "`")
|
||||
|
||||
// 检查是否有映射
|
||||
if mappedName, ok := columnMap[colName]; ok {
|
||||
// 去掉结尾的逗号(如果有)
|
||||
trimmed = strings.TrimRight(trimmed, ",")
|
||||
// 检查是否已经有 COMMENT
|
||||
if strings.Contains(strings.ToUpper(trimmed), "COMMENT") {
|
||||
// 替换已有的 COMMENT
|
||||
trimmed = strings.TrimSuffix(trimmed, " COMMENT '...'")
|
||||
trimmed = fmt.Sprintf("%s COMMENT '%s'", trimmed, mappedName)
|
||||
} else {
|
||||
// 在末尾添加 COMMENT
|
||||
trimmed = fmt.Sprintf("%s COMMENT '%s'", trimmed, mappedName)
|
||||
}
|
||||
// 替换原始行为修改后的行
|
||||
resultLines = append(resultLines, trimmed)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
resultLines = append(resultLines, line)
|
||||
}
|
||||
|
||||
return strings.Join(resultLines, "\n")
|
||||
}
|
||||
|
||||
// Check 检查数据库连接
|
||||
func (s *DatabaseService) Check(req model.CheckRequest) (*model.CheckResponse, error) {
|
||||
log.Printf("[Check] 开始检查连接: 类型=%s, 主机=%s, 端口=%d, 数据库=%s, 用户=%s",
|
||||
req.DBType, req.Host, req.Port, req.Database, req.Username)
|
||||
|
||||
info := &model.DatabaseInfo{
|
||||
DBType: req.DBType,
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Database: req.Database,
|
||||
Charset: req.Charset,
|
||||
SSLMode: req.SSLMode,
|
||||
}
|
||||
|
||||
if info.Charset == "" {
|
||||
info.Charset = "utf8mb4"
|
||||
}
|
||||
|
||||
// 构建连接
|
||||
dsn := s.buildDSN(info)
|
||||
dbType := strings.ToLower(info.DBType)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var db *sql.DB
|
||||
var err error
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
db, err = sql.Open("mysql", dsn)
|
||||
case "postgres", "postgresql":
|
||||
db, err = sql.Open("postgres", dsn)
|
||||
default:
|
||||
return &model.CheckResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("unsupported database type: %s", req.DBType),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return &model.CheckResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("failed to create connection: %v", err),
|
||||
}, nil
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 测试连接
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
log.Printf("[Check] 连接失败: %v", err)
|
||||
return &model.CheckResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("cannot connect to database: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Printf("[Check] 连接成功,开始获取表列表...")
|
||||
|
||||
// 获取表列表
|
||||
var tables []model.TableDDLInfo
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
tables, _ = s.getMySQLTables(db, req.Database)
|
||||
case "postgres", "postgresql":
|
||||
tables, _ = s.getPostgresTables(db, req.Database)
|
||||
}
|
||||
|
||||
log.Printf("[Check] 获取到 %d 个表", len(tables))
|
||||
|
||||
// 如果传入了 database_id,获取已保存的字段映射和 DDL 并填充到表结构中
|
||||
if req.DatabaseID != "" && s.subTableRepo != nil {
|
||||
s.fillFieldMappings(req.DatabaseID, tables)
|
||||
s.fillDDL(req.DatabaseID, tables)
|
||||
}
|
||||
|
||||
return &model.CheckResponse{
|
||||
Success: true,
|
||||
Message: "connection successful",
|
||||
Tables: tables,
|
||||
Database: req.Database,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getMySQLTables 获取MySQL表结构
|
||||
func (s *DatabaseService) getMySQLTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT TABLE_NAME, TABLE_COMMENT
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA = ?
|
||||
AND TABLE_TYPE = 'BASE TABLE'
|
||||
`, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []model.TableDDLInfo
|
||||
for rows.Next() {
|
||||
var tableName, tableComment string
|
||||
if err := rows.Scan(&tableName, &tableComment); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
table := model.TableDDLInfo{
|
||||
TableName: tableName,
|
||||
TableComment: tableComment,
|
||||
}
|
||||
|
||||
// 获取列信息
|
||||
table.Columns, _ = s.getMySQLColumns(db, dbName, tableName)
|
||||
|
||||
// 获取 DDL
|
||||
table.DDL, _ = s.getMySQLDDL(db, tableName)
|
||||
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// getMySQLDDL 获取 MySQL 表的 DDL
|
||||
func (s *DatabaseService) getMySQLDDL(db *sql.DB, tableName string) (string, error) {
|
||||
// 使用反引号包裹表名,防止关键字冲突
|
||||
query := fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName)
|
||||
row := db.QueryRow(query)
|
||||
var tblName, createStmt string
|
||||
if err := row.Scan(&tblName, &createStmt); err != nil {
|
||||
log.Printf("[getMySQLDDL] 获取 DDL 失败: %v", err)
|
||||
return "", nil
|
||||
}
|
||||
return createStmt, nil
|
||||
}
|
||||
|
||||
// getMySQLColumns 获取MySQL列信息
|
||||
func (s *DatabaseService) getMySQLColumns(db *sql.DB, dbName, tableName string) ([]model.ColumnInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
COLUMN_NAME, DATA_TYPE, COLUMN_TYPE,
|
||||
IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY,
|
||||
EXTRA, COLUMN_COMMENT
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
ORDER BY ORDINAL_POSITION
|
||||
`, dbName, tableName)
|
||||
if err != nil {
|
||||
log.Printf("[getMySQLColumns] 查询列信息失败: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns := make([]model.ColumnInfo, 0)
|
||||
for rows.Next() {
|
||||
var col model.ColumnInfo
|
||||
var defaultValue, extra, columnComment sql.NullString
|
||||
if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType,
|
||||
&col.IsNullable, &defaultValue, &col.ColumnKey, &extra, &columnComment); err != nil {
|
||||
log.Printf("[getMySQLColumns] Scan 失败: %v", err)
|
||||
continue
|
||||
}
|
||||
col.DefaultValue = defaultValue.String
|
||||
col.Extra = extra.String
|
||||
col.ColumnComment = columnComment.String
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
// 检查是否有迭代错误
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("[getMySQLColumns] 迭代错误: %v", err)
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// getPostgresTables 获取PostgreSQL表结构
|
||||
func (s *DatabaseService) getPostgresTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT t.table_name, obj_description((t.table_schema || '.' || t.table_name)::regclass)
|
||||
FROM information_schema.tables t
|
||||
WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE'
|
||||
`, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []model.TableDDLInfo
|
||||
for rows.Next() {
|
||||
var tableName, tableComment string
|
||||
if err := rows.Scan(&tableName, &tableComment); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
table := model.TableDDLInfo{
|
||||
TableName: tableName,
|
||||
TableComment: tableComment,
|
||||
}
|
||||
|
||||
// 获取列信息
|
||||
table.Columns, _ = s.getPostgresColumns(db, tableName)
|
||||
|
||||
// 获取 DDL
|
||||
table.DDL, _ = s.getPostgresDDL(db, tableName)
|
||||
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// getPostgresDDL 获取 PostgreSQL 表的 DDL
|
||||
func (s *DatabaseService) getPostgresDDL(db *sql.DB, tableName string) (string, error) {
|
||||
var ddl string
|
||||
query := fmt.Sprintf("SELECT pg_get_create('%s')", tableName)
|
||||
row := db.QueryRow(query)
|
||||
if err := row.Scan(&ddl); err != nil {
|
||||
log.Printf("[getPostgresDDL] 获取 DDL 失败: %v", err)
|
||||
return "", nil
|
||||
}
|
||||
return ddl, nil
|
||||
}
|
||||
|
||||
// getPostgresColumns 获取PostgreSQL列信息
|
||||
func (s *DatabaseService) getPostgresColumns(db *sql.DB, tableName string) ([]model.ColumnInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
c.column_name, c.data_type, c.udt_name,
|
||||
c.is_nullable, c.column_default, c.column_name,
|
||||
'', c.column_comment
|
||||
FROM information_schema.columns c
|
||||
WHERE c.table_name = $1 AND c.table_schema = 'public'
|
||||
ORDER BY c.ordinal_position
|
||||
`, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var columns []model.ColumnInfo
|
||||
for rows.Next() {
|
||||
var col model.ColumnInfo
|
||||
if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType,
|
||||
&col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// Create 创建数据库信息(支持同时保存子表配置)
|
||||
func (s *DatabaseService) Create(req model.CreateDatabaseRequest) (*model.DatabaseInfo, error) {
|
||||
log.Printf("[Create] 收到创建请求: %+v", req)
|
||||
|
||||
info := &model.DatabaseInfo{
|
||||
ID: uuid.New().String(),
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
DBType: strings.ToLower(req.DBType), // 统一转为小写
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Database: req.Database,
|
||||
Charset: req.Charset,
|
||||
SSLMode: req.SSLMode,
|
||||
TableCount: len(req.SubTables),
|
||||
}
|
||||
|
||||
// 默认值
|
||||
if info.Charset == "" {
|
||||
info.Charset = "utf8mb4"
|
||||
}
|
||||
|
||||
// 测试数据库连通性
|
||||
log.Printf("[Create] 开始测试数据库连接...")
|
||||
if err := s.TestConnection(info); err != nil {
|
||||
log.Printf("[Create] 数据库连接测试失败: %v", err)
|
||||
return nil, fmt.Errorf("database connection failed: %v", err)
|
||||
}
|
||||
log.Printf("[Create] 数据库连接测试成功!")
|
||||
|
||||
// 保存数据库信息
|
||||
if err := s.repo.Create(info); err != nil {
|
||||
log.Printf("[Create] 保存数据库失败: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保存子表配置(如有)
|
||||
if len(req.SubTables) > 0 && s.subTableRepo != nil {
|
||||
log.Printf("[Create] 保存 %d 个子表配置", len(req.SubTables))
|
||||
|
||||
// 获取数据库连接用于查询 DDL
|
||||
db, err := s.getConnection(info)
|
||||
if err != nil {
|
||||
log.Printf("[Create] 获取数据库连接失败: %v", err)
|
||||
} else {
|
||||
defer db.Close()
|
||||
}
|
||||
|
||||
for _, subReq := range req.SubTables {
|
||||
subTable := &model.SubTableInfo{
|
||||
ID: uuid.New().String(),
|
||||
DatabaseID: info.ID,
|
||||
ParentTable: subReq.ParentTable,
|
||||
SubTableName: subReq.SubTableName,
|
||||
SubTableComment: subReq.SubTableComment,
|
||||
MappingType: subReq.MappingType,
|
||||
RelationField: subReq.RelationField,
|
||||
RelationType: subReq.RelationType,
|
||||
}
|
||||
// 使用 SetFields 方法保存字段映射
|
||||
subTable.SetFields(subReq.Fields)
|
||||
|
||||
// 获取并保存 DDL
|
||||
if db != nil {
|
||||
ddl, err := s.getTableDDL(db, strings.ToLower(info.DBType), subReq.ParentTable)
|
||||
if err != nil {
|
||||
log.Printf("[Create] 获取原始 DDL 失败: %v", err)
|
||||
} else {
|
||||
// 如果有字段映射,生成带 COMMENT 的新 DDL
|
||||
if len(subReq.Fields) > 0 {
|
||||
subTable.DDL = s.buildMappedDDL(ddl, subReq.Fields)
|
||||
log.Printf("[Create] 生成映射后的 DDL,长度: %d", len(subTable.DDL))
|
||||
} else {
|
||||
subTable.DDL = ddl
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.subTableRepo.Create(subTable); err != nil {
|
||||
log.Printf("[Create] 保存子表失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 同步到文件
|
||||
s.syncSubTablesToFile(info)
|
||||
}
|
||||
|
||||
log.Printf("[Create] 创建成功, ID=%s", info.ID)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// syncSubTablesToFile 同步子表到文件
|
||||
func (s *DatabaseService) syncSubTablesToFile(info *model.DatabaseInfo) {
|
||||
if s.subTableRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables, err := s.subTableRepo.FindByDatabaseID(info.ID)
|
||||
if err != nil {
|
||||
log.Printf("[syncSubTablesToFile] 查询子表失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
mapping := &model.SubTableMapping{
|
||||
DatabaseID: info.ID,
|
||||
DatabaseName: info.Name,
|
||||
DBType: info.DBType,
|
||||
Tables: tables,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
resourceDir := "resources/db_info"
|
||||
os.MkdirAll(resourceDir, 0755)
|
||||
|
||||
data, err := json.MarshalIndent(mapping, "", " ")
|
||||
if err != nil {
|
||||
log.Printf("[syncSubTablesToFile] 序列化失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
filePath := fmt.Sprintf("%s/%s.json", resourceDir, info.ID)
|
||||
if err := os.WriteFile(filePath, data, 0644); err != nil {
|
||||
log.Printf("[syncSubTablesToFile] 写入文件失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("[syncSubTablesToFile] 同步成功: %s", filePath)
|
||||
}
|
||||
|
||||
// GetByID 获取详情
|
||||
func (s *DatabaseService) GetByID(id string) (*model.DatabaseInfo, error) {
|
||||
log.Printf("[GetByID] 查询 ID=%s", id)
|
||||
info, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
log.Printf("[GetByID] 查询失败: %v", err)
|
||||
return nil, ErrDatabaseNotFound
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// List 获取列表
|
||||
func (s *DatabaseService) List() ([]model.DatabaseInfo, error) {
|
||||
log.Printf("[List] 查询所有数据库列表")
|
||||
return s.repo.FindAll()
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (s *DatabaseService) Update(id string, req model.UpdateDatabaseRequest) (*model.DatabaseInfo, error) {
|
||||
log.Printf("[Update] 更新 ID=%s, 数据=%+v", id, req)
|
||||
// 检查是否存在
|
||||
_, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
log.Printf("[Update] 不存在: %v", err)
|
||||
return nil, ErrDatabaseNotFound
|
||||
}
|
||||
|
||||
// 构建更新数据
|
||||
updates := map[string]interface{}{}
|
||||
if req.Name != "" {
|
||||
updates["name"] = req.Name
|
||||
}
|
||||
if req.Description != "" {
|
||||
updates["description"] = req.Description
|
||||
}
|
||||
if req.DBType != "" {
|
||||
updates["db_type"] = req.DBType
|
||||
}
|
||||
if req.Host != "" {
|
||||
updates["host"] = req.Host
|
||||
}
|
||||
if req.Port > 0 {
|
||||
updates["port"] = req.Port
|
||||
}
|
||||
if req.Username != "" {
|
||||
updates["username"] = req.Username
|
||||
}
|
||||
if req.Password != "" {
|
||||
updates["password"] = req.Password
|
||||
}
|
||||
if req.Database != "" {
|
||||
updates["database"] = req.Database
|
||||
}
|
||||
if req.TableCount > 0 {
|
||||
updates["table_count"] = req.TableCount
|
||||
}
|
||||
if req.Charset != "" {
|
||||
updates["charset"] = req.Charset
|
||||
}
|
||||
if req.SSLMode != "" {
|
||||
updates["ssl_mode"] = req.SSLMode
|
||||
}
|
||||
|
||||
info := &model.DatabaseInfo{}
|
||||
if err := s.repo.Update(id, info); err != nil {
|
||||
log.Printf("[Update] 更新失败: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.repo.FindByID(id)
|
||||
}
|
||||
|
||||
// fillFieldMappings 填充字段映射到表结构中
|
||||
func (s *DatabaseService) fillFieldMappings(databaseID string, tables []model.TableDDLInfo) {
|
||||
// 从数据库中获取该数据库下所有子表的字段映射
|
||||
subTables, err := s.subTableRepo.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
log.Printf("[fillFieldMappings] 查询子表失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建表名到字段映射的映射
|
||||
tableFieldsMap := make(map[string][]model.FieldMapping)
|
||||
for _, st := range subTables {
|
||||
fields := st.GetFields()
|
||||
if len(fields) > 0 {
|
||||
tableFieldsMap[st.ParentTable] = fields
|
||||
}
|
||||
}
|
||||
|
||||
// 遍历返回的表结构,填充字段映射
|
||||
for i := range tables {
|
||||
tableName := tables[i].TableName
|
||||
if fields, ok := tableFieldsMap[tableName]; ok {
|
||||
// 构建列名到映射名的映射
|
||||
columnMap := make(map[string]string)
|
||||
for _, f := range fields {
|
||||
columnMap[f.ColumnName] = f.MappedName
|
||||
}
|
||||
|
||||
// 填充到每个列
|
||||
for j := range tables[i].Columns {
|
||||
colName := tables[i].Columns[j].ColumnName
|
||||
if mappedName, ok := columnMap[colName]; ok {
|
||||
tables[i].Columns[j].MappedName = mappedName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[fillFieldMappings] 已填充字段映射到 %d 个表", len(tables))
|
||||
}
|
||||
|
||||
// fillDDL 填充已保存的 DDL 到表结构中
|
||||
func (s *DatabaseService) fillDDL(databaseID string, tables []model.TableDDLInfo) {
|
||||
// 从数据库中获取该数据库下所有子表的 DDL
|
||||
subTables, err := s.subTableRepo.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
log.Printf("[fillDDL] 查询子表失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建表名到 DDL 的映射
|
||||
tableDDLMap := make(map[string]string)
|
||||
for _, st := range subTables {
|
||||
if st.DDL != "" {
|
||||
tableDDLMap[st.ParentTable] = st.DDL
|
||||
}
|
||||
}
|
||||
|
||||
// 遍历返回的表结构,填充 DDL
|
||||
for i := range tables {
|
||||
tableName := tables[i].TableName
|
||||
if ddl, ok := tableDDLMap[tableName]; ok {
|
||||
tables[i].DDL = ddl
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[fillDDL] 已填充 DDL 到 %d 个表", len(tables))
|
||||
}
|
||||
|
||||
// Delete 删除
|
||||
func (s *DatabaseService) Delete(id string) error {
|
||||
log.Printf("[Delete] 删除 ID=%s", id)
|
||||
_, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
log.Printf("[Delete] 不存在: %v", err)
|
||||
return ErrDatabaseNotFound
|
||||
}
|
||||
return s.repo.Delete(id)
|
||||
}
|
||||
602
server/internal/service/sub_table_service.go
Normal file
602
server/internal/service/sub_table_service.go
Normal file
@@ -0,0 +1,602 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/repository"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type SubTableService struct {
|
||||
repo *repository.SubTableRepository
|
||||
dbRepo *repository.DatabaseRepository
|
||||
resourceDir string
|
||||
}
|
||||
|
||||
func NewSubTableService(repo *repository.SubTableRepository, dbRepo *repository.DatabaseRepository) *SubTableService {
|
||||
return &SubTableService{
|
||||
repo: repo,
|
||||
dbRepo: dbRepo,
|
||||
resourceDir: "resources/db_info",
|
||||
}
|
||||
}
|
||||
|
||||
// ensureDir 确保目录存在
|
||||
func (s *SubTableService) ensureDir() error {
|
||||
return os.MkdirAll(s.resourceDir, 0755)
|
||||
}
|
||||
|
||||
// getFilePath 获取文件路径
|
||||
func (s *SubTableService) getFilePath(databaseID string) string {
|
||||
return filepath.Join(s.resourceDir, fmt.Sprintf("%s.json", databaseID))
|
||||
}
|
||||
|
||||
// saveToFile 保存到文件
|
||||
func (s *SubTableService) saveToFile(databaseID string, mapping *model.SubTableMapping) error {
|
||||
if err := s.ensureDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(mapping, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(s.getFilePath(databaseID), data, 0644)
|
||||
}
|
||||
|
||||
// loadFromFile 从文件加载
|
||||
func (s *SubTableService) loadFromFile(databaseID string) (*model.SubTableMapping, error) {
|
||||
filePath := s.getFilePath(databaseID)
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mapping model.SubTableMapping
|
||||
if err := json.Unmarshal(data, &mapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mapping, nil
|
||||
}
|
||||
|
||||
// syncToFile 同步到文件
|
||||
func (s *SubTableService) syncToFile(databaseID string) error {
|
||||
// 获取数据库信息
|
||||
dbInfo, err := s.dbRepo.FindByID(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取所有子表信息
|
||||
tables, err := s.repo.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mapping := &model.SubTableMapping{
|
||||
DatabaseID: databaseID,
|
||||
DatabaseName: dbInfo.Name,
|
||||
DBType: dbInfo.DBType,
|
||||
Tables: tables,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return s.saveToFile(databaseID, mapping)
|
||||
}
|
||||
|
||||
// Create 创建子表信息
|
||||
func (s *SubTableService) Create(req model.CreateSubTableRequest) (*model.SubTableInfo, error) {
|
||||
log.Printf("[SubTable Create] 收到请求: %+v", req)
|
||||
|
||||
// 验证数据库是否存在
|
||||
_, err := s.dbRepo.FindByID(req.DatabaseID)
|
||||
if err != nil {
|
||||
log.Printf("[SubTable Create] 数据库不存在: %v", err)
|
||||
return nil, fmt.Errorf("database not found")
|
||||
}
|
||||
|
||||
info := &model.SubTableInfo{
|
||||
ID: uuid.New().String(),
|
||||
DatabaseID: req.DatabaseID,
|
||||
ParentTable: req.ParentTable,
|
||||
SubTableName: req.SubTableName,
|
||||
SubTableComment: req.SubTableComment,
|
||||
MappingType: req.MappingType,
|
||||
RelationField: req.RelationField,
|
||||
RelationType: req.RelationType,
|
||||
}
|
||||
|
||||
if err := s.repo.Create(info); err != nil {
|
||||
log.Printf("[SubTable Create] 创建失败: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 同步到文件
|
||||
if err := s.syncToFile(req.DatabaseID); err != nil {
|
||||
log.Printf("[SubTable Create] 同步文件失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("[SubTable Create] 创建成功, ID=%s", info.ID)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetByID 获取详情
|
||||
func (s *SubTableService) GetByID(id string) (*model.SubTableInfo, error) {
|
||||
log.Printf("[SubTable GetByID] 查询 ID=%s", id)
|
||||
info, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
log.Printf("[SubTable GetByID] 查询失败: %v", err)
|
||||
return nil, fmt.Errorf("sub table not found")
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// ListByDatabaseID 获取数据库下所有子表
|
||||
func (s *SubTableService) ListByDatabaseID(databaseID string) ([]model.SubTableInfo, error) {
|
||||
log.Printf("[SubTable ListByDatabaseID] 查询数据库ID=%s", databaseID)
|
||||
tables, err := s.repo.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 填充 FieldsList 字段
|
||||
for i := range tables {
|
||||
tables[i].FieldsList = tables[i].GetFields()
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// GetMappingFromFile 从文件获取映射信息
|
||||
func (s *SubTableService) GetMappingFromFile(databaseID string) (*model.SubTableMapping, error) {
|
||||
log.Printf("[SubTable GetMappingFromFile] 读取文件, databaseID=%s", databaseID)
|
||||
return s.loadFromFile(databaseID)
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (s *SubTableService) Update(id string, req model.UpdateSubTableRequest) (*model.SubTableInfo, error) {
|
||||
log.Printf("[SubTable Update] 更新 ID=%s, 数据=%+v", id, req)
|
||||
|
||||
info, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
log.Printf("[SubTable Update] 不存在: %v", err)
|
||||
return nil, fmt.Errorf("sub table not found")
|
||||
}
|
||||
|
||||
if req.ParentTable != "" {
|
||||
info.ParentTable = req.ParentTable
|
||||
}
|
||||
if req.SubTableName != "" {
|
||||
info.SubTableName = req.SubTableName
|
||||
}
|
||||
if req.SubTableComment != "" {
|
||||
info.SubTableComment = req.SubTableComment
|
||||
}
|
||||
if req.MappingType != "" {
|
||||
info.MappingType = req.MappingType
|
||||
}
|
||||
if req.RelationField != "" {
|
||||
info.RelationField = req.RelationField
|
||||
}
|
||||
if req.RelationType != "" {
|
||||
info.RelationType = req.RelationType
|
||||
}
|
||||
|
||||
if err := s.repo.Update(id, info); err != nil {
|
||||
log.Printf("[SubTable Update] 更新失败: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 同步到文件
|
||||
if err := s.syncToFile(info.DatabaseID); err != nil {
|
||||
log.Printf("[SubTable Update] 同步文件失败: %v", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// Delete 删除
|
||||
func (s *SubTableService) Delete(id string) error {
|
||||
log.Printf("[SubTable Delete] 删除 ID=%s", id)
|
||||
|
||||
info, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
log.Printf("[SubTable Delete] 不存在: %v", err)
|
||||
return fmt.Errorf("sub table not found")
|
||||
}
|
||||
|
||||
databaseID := info.DatabaseID
|
||||
|
||||
if err := s.repo.Delete(id); err != nil {
|
||||
log.Printf("[SubTable Delete] 删除失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 同步到文件
|
||||
if err := s.syncToFile(databaseID); err != nil {
|
||||
log.Printf("[SubTable Delete] 同步文件失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTableDDLFromDatabase 从实际数据库获取表结构和DDL
|
||||
func (s *SubTableService) GetTableDDLFromDatabase(databaseID string) ([]model.TableDDLInfo, error) {
|
||||
log.Printf("[GetTableDDLFromDatabase] 获取数据库ID=%s 的表结构", databaseID)
|
||||
|
||||
// 获取数据库连接信息
|
||||
dbInfo, err := s.dbRepo.FindByID(databaseID)
|
||||
if err != nil {
|
||||
log.Printf("[GetTableDDLFromDatabase] 数据库不存在: %v", err)
|
||||
return nil, fmt.Errorf("database not found")
|
||||
}
|
||||
|
||||
// 构建连接
|
||||
dsn := s.buildDSN(dbInfo)
|
||||
dbType := strings.ToLower(dbInfo.DBType)
|
||||
|
||||
var db *sql.DB
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
db, err = sql.Open("mysql", dsn)
|
||||
case "postgres", "postgresql":
|
||||
db, err = sql.Open("postgres", dsn)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", dbInfo.DBType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 获取所有表
|
||||
var tables []model.TableDDLInfo
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
tables, err = s.getMySQLTables(db, dbInfo.Database)
|
||||
case "postgres", "postgresql":
|
||||
tables, err = s.getPostgresTables(db, dbInfo.Database)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("[GetTableDDLFromDatabase] 获取到 %d 个表", len(tables))
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// buildDSN 构建数据库连接字符串
|
||||
func (s *SubTableService) buildDSN(info *model.DatabaseInfo) string {
|
||||
dbType := strings.ToLower(info.DBType)
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
charset := info.Charset
|
||||
if charset == "" {
|
||||
charset = "utf8mb4"
|
||||
}
|
||||
dbName := info.Database
|
||||
if dbName == "" {
|
||||
dbName = "mysql"
|
||||
}
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&timeout=5s&parseTime=True",
|
||||
info.Username, info.Password, info.Host, info.Port, dbName, charset)
|
||||
case "postgres", "postgresql":
|
||||
sslmode := "disable"
|
||||
if info.SSLMode != "" {
|
||||
sslmode = info.SSLMode
|
||||
}
|
||||
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5",
|
||||
info.Host, info.Port, info.Username, info.Password, info.Database, sslmode)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getMySQLTables 获取MySQL表结构
|
||||
func (s *SubTableService) getMySQLTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) {
|
||||
// 获取所有表名和注释
|
||||
rows, err := db.Query(`
|
||||
SELECT TABLE_NAME, TABLE_COMMENT
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA = ?
|
||||
AND TABLE_TYPE = 'BASE TABLE'
|
||||
`, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []model.TableDDLInfo
|
||||
for rows.Next() {
|
||||
var tableName, tableComment string
|
||||
if err := rows.Scan(&tableName, &tableComment); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
table := model.TableDDLInfo{
|
||||
TableName: tableName,
|
||||
TableComment: tableComment,
|
||||
}
|
||||
|
||||
// 获取列信息
|
||||
table.Columns, _ = s.getMySQLColumns(db, dbName, tableName)
|
||||
|
||||
// 获取索引信息
|
||||
table.Indexes, _ = s.getMySQLIndexes(db, dbName, tableName)
|
||||
|
||||
// 生成DDL
|
||||
table.DDL = s.generateMySQLDDL(table)
|
||||
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// getMySQLColumns 获取MySQL列信息
|
||||
func (s *SubTableService) getMySQLColumns(db *sql.DB, dbName, tableName string) ([]model.ColumnInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
COLUMN_NAME, DATA_TYPE, COLUMN_TYPE,
|
||||
IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY,
|
||||
EXTRA, COLUMN_COMMENT
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
ORDER BY ORDINAL_POSITION
|
||||
`, dbName, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var columns []model.ColumnInfo
|
||||
for rows.Next() {
|
||||
var col model.ColumnInfo
|
||||
if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType,
|
||||
&col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// getMySQLIndexes 获取MySQL索引信息
|
||||
func (s *SubTableService) getMySQLIndexes(db *sql.DB, dbName, tableName string) ([]model.IndexInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE, INDEX_TYPE
|
||||
FROM information_schema.STATISTICS
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
ORDER BY SEQ_IN_INDEX
|
||||
`, dbName, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var indexes []model.IndexInfo
|
||||
for rows.Next() {
|
||||
var idx model.IndexInfo
|
||||
if err := rows.Scan(&idx.IndexName, &idx.ColumnName, &idx.NonUnique, &idx.IndexType); err != nil {
|
||||
continue
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
// generateMySQLDDL 生成MySQL DDL
|
||||
func (s *SubTableService) generateMySQLDDL(table model.TableDDLInfo) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("CREATE TABLE `%s` (\n", table.TableName))
|
||||
|
||||
for i, col := range table.Columns {
|
||||
sb.WriteString(fmt.Sprintf(" `%s` %s", col.ColumnName, col.ColumnType))
|
||||
|
||||
if col.IsNullable == "NO" {
|
||||
sb.WriteString(" NOT NULL")
|
||||
}
|
||||
if col.DefaultValue != "" {
|
||||
sb.WriteString(fmt.Sprintf(" DEFAULT %s", col.DefaultValue))
|
||||
}
|
||||
if col.Extra == "auto_increment" {
|
||||
sb.WriteString(" AUTO_INCREMENT")
|
||||
}
|
||||
if col.ColumnComment != "" {
|
||||
sb.WriteString(fmt.Sprintf(" COMMENT '%s'", col.ColumnComment))
|
||||
}
|
||||
|
||||
if i < len(table.Columns)-1 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// 添加主键
|
||||
var primaryKeys []string
|
||||
for _, idx := range table.Indexes {
|
||||
if idx.IndexName == "PRIMARY" {
|
||||
primaryKeys = append(primaryKeys, fmt.Sprintf("`%s`", idx.ColumnName))
|
||||
}
|
||||
}
|
||||
if len(primaryKeys) > 0 {
|
||||
sb.WriteString(fmt.Sprintf(" PRIMARY KEY (%s)\n", strings.Join(primaryKeys, ", ")))
|
||||
}
|
||||
|
||||
// 添加索引
|
||||
var addedIndexes []string
|
||||
for _, idx := range table.Indexes {
|
||||
if idx.IndexName != "PRIMARY" {
|
||||
unique := ""
|
||||
if idx.NonUnique == 0 {
|
||||
unique = "UNIQUE "
|
||||
}
|
||||
if !contains(addedIndexes, idx.IndexName) {
|
||||
sb.WriteString(fmt.Sprintf(" %sKEY `%s` (`%s`),\n", unique, idx.IndexName, idx.ColumnName))
|
||||
addedIndexes = append(addedIndexes, idx.IndexName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ddl := sb.String()
|
||||
ddl = strings.TrimSuffix(ddl, ",\n")
|
||||
ddl += "\n)"
|
||||
|
||||
if table.TableComment != "" {
|
||||
ddl += fmt.Sprintf(" COMMENT='%s'", table.TableComment)
|
||||
}
|
||||
ddl += ";\n"
|
||||
|
||||
return ddl
|
||||
}
|
||||
|
||||
// contains 检查切片是否包含元素
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getPostgresTables 获取PostgreSQL表结构
|
||||
func (s *SubTableService) getPostgresTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT t.table_name, obj_description((t.table_schema || '.' || t.table_name)::regclass)
|
||||
FROM information_schema.tables t
|
||||
WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE'
|
||||
`, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []model.TableDDLInfo
|
||||
for rows.Next() {
|
||||
var tableName, tableComment string
|
||||
if err := rows.Scan(&tableName, &tableComment); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
table := model.TableDDLInfo{
|
||||
TableName: tableName,
|
||||
TableComment: tableComment,
|
||||
}
|
||||
|
||||
// 获取列信息
|
||||
table.Columns, _ = s.getPostgresColumns(db, tableName)
|
||||
|
||||
// 获取索引信息
|
||||
table.Indexes, _ = s.getPostgresIndexes(db, tableName)
|
||||
|
||||
// 生成DDL
|
||||
table.DDL = s.generatePostgresDDL(table)
|
||||
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// getPostgresColumns 获取PostgreSQL列信息
|
||||
func (s *SubTableService) getPostgresColumns(db *sql.DB, tableName string) ([]model.ColumnInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
c.column_name, c.data_type, c.udt_name,
|
||||
c.is_nullable, c.column_default, c.column_name,
|
||||
'', c.column_comment
|
||||
FROM information_schema.columns c
|
||||
WHERE c.table_name = $1 AND c.table_schema = 'public'
|
||||
ORDER BY c.ordinal_position
|
||||
`, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var columns []model.ColumnInfo
|
||||
for rows.Next() {
|
||||
var col model.ColumnInfo
|
||||
if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType,
|
||||
&col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// getPostgresIndexes 获取PostgreSQL索引信息
|
||||
func (s *SubTableService) getPostgresIndexes(db *sql.DB, tableName string) ([]model.IndexInfo, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT indexname, indexdef
|
||||
FROM pg_indexes
|
||||
WHERE tablename = $1 AND schemaname = 'public'
|
||||
`, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var indexes []model.IndexInfo
|
||||
for rows.Next() {
|
||||
var idx model.IndexInfo
|
||||
var indexDef string
|
||||
if err := rows.Scan(&idx.IndexName, &indexDef); err != nil {
|
||||
continue
|
||||
}
|
||||
idx.NonUnique = 1
|
||||
if strings.Contains(indexDef, "UNIQUE") {
|
||||
idx.NonUnique = 0
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
// generatePostgresDDL 生成PostgreSQL DDL
|
||||
func (s *SubTableService) generatePostgresDDL(table model.TableDDLInfo) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("CREATE TABLE %s (\n", table.TableName))
|
||||
|
||||
for i, col := range table.Columns {
|
||||
sb.WriteString(fmt.Sprintf(" %s %s", col.ColumnName, col.ColumnType))
|
||||
|
||||
if col.IsNullable == "NO" {
|
||||
sb.WriteString(" NOT NULL")
|
||||
}
|
||||
if col.DefaultValue != "" {
|
||||
sb.WriteString(fmt.Sprintf(" DEFAULT %s", col.DefaultValue))
|
||||
}
|
||||
|
||||
if i < len(table.Columns)-1 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
ddl := sb.String()
|
||||
ddl = strings.TrimSuffix(ddl, ",\n")
|
||||
ddl += "\n);"
|
||||
|
||||
return ddl
|
||||
}
|
||||
Reference in New Issue
Block a user