package service import ( "context" "database/sql" "encoding/json" "errors" "fmt" "log" "os" "strings" "time" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" "x-agents/server/internal/model" "x-agents/server/internal/repository" "github.com/google/uuid" ) var ( ErrDatabaseNotFound = errors.New("database not found") ErrDatabaseUnreachable = errors.New("database cannot be connected") ) type DatabaseService struct { repo *repository.DatabaseRepository subTableRepo *repository.SubTableRepository } func NewDatabaseService(repo *repository.DatabaseRepository, subTableRepo *repository.SubTableRepository) *DatabaseService { return &DatabaseService{ repo: repo, subTableRepo: subTableRepo, } } // TestConnection 测试数据库连通性 func (s *DatabaseService) TestConnection(info *model.DatabaseInfo) error { log.Printf("[数据库连接测试] 开始测试连接: 类型=%s, 主机=%s, 端口=%d, 数据库=%s, 用户=%s", info.DBType, info.Host, info.Port, info.Database, info.Username) // 统一转换为小写处理 dbType := strings.ToLower(info.DBType) // 构建连接字符串 dsn := s.buildDSN(info) log.Printf("[数据库连接测试] DSN构建完成: %s", dsn) // 设置超时 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // 根据数据库类型连接 var db *sql.DB var err error switch dbType { case "mysql": db, err = sql.Open("mysql", dsn) case "postgres", "postgresql": db, err = sql.Open("postgres", dsn) default: errMsg := fmt.Sprintf("unsupported database type: %s", info.DBType) log.Printf("[数据库连接测试] 错误: %s", errMsg) return fmt.Errorf(errMsg) } if err != nil { errMsg := fmt.Sprintf("failed to create connection: %v", err) log.Printf("[数据库连接测试] 错误: %s", errMsg) return fmt.Errorf(errMsg) } defer db.Close() // 测试连接 if err := db.PingContext(ctx); err != nil { errMsg := fmt.Sprintf("cannot connect to database: %v", err) log.Printf("[数据库连接测试] 连接失败: %s", errMsg) return fmt.Errorf(errMsg) } log.Printf("[数据库连接测试] 连接成功!") return nil } // buildDSN 构建数据库连接字符串 func (s *DatabaseService) buildDSN(info *model.DatabaseInfo) string { dbType := strings.ToLower(info.DBType) switch dbType { case "mysql": charset := info.Charset if charset == "" { charset = "utf8mb4" } // 如果没有指定数据库名,只测试连接 dbName := info.Database if dbName == "" { dbName = "mysql" } return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&timeout=5s&parseTime=True", info.Username, info.Password, info.Host, info.Port, dbName, charset, ) case "postgres", "postgresql": sslmode := "disable" if info.SSLMode != "" { sslmode = info.SSLMode } return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5", info.Host, info.Port, info.Username, info.Password, info.Database, sslmode, ) default: return "" } } // getConnection 获取数据库连接 func (s *DatabaseService) getConnection(info *model.DatabaseInfo) (*sql.DB, error) { dsn := s.buildDSN(info) dbType := strings.ToLower(info.DBType) var db *sql.DB var err error switch dbType { case "mysql": db, err = sql.Open("mysql", dsn) case "postgres", "postgresql": db, err = sql.Open("postgres", dsn) default: return nil, fmt.Errorf("unsupported database type: %s", dbType) } if err != nil { return nil, err } if err := db.Ping(); err != nil { return nil, err } return db, nil } // getTableDDL 获取表的 DDL func (s *DatabaseService) getTableDDL(db *sql.DB, dbType, tableName string) (string, error) { switch dbType { case "mysql": query := fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName) row := db.QueryRow(query) var tblName, createStmt string if err := row.Scan(&tblName, &createStmt); err != nil { return "", err } return createStmt, nil case "postgres", "postgresql": query := fmt.Sprintf("SELECT pg_get_create('%s')", tableName) var ddl string if err := db.QueryRow(query).Scan(&ddl); err != nil { return "", err } return ddl, nil default: return "", fmt.Errorf("unsupported database type: %s", dbType) } } // buildMappedDDL 根据字段映射生成带 COMMENT 的 DDL func (s *DatabaseService) buildMappedDDL(originalDDL string, fields []model.FieldMapping) string { // 构建列名到映射名的映射 columnMap := make(map[string]string) for _, f := range fields { if f.MappedName != "" { columnMap[f.ColumnName] = f.MappedName } } if len(columnMap) == 0 { return originalDDL } // 解析原始 DDL,为有映射的列添加 COMMENT lines := strings.Split(originalDDL, "\n") var resultLines []string for _, line := range lines { trimmed := strings.TrimSpace(line) // 检查是否是列定义行(以 ` 开头,包含数据类型) if strings.HasPrefix(trimmed, "`") { // 提取列名 parts := strings.SplitN(trimmed, " ", 2) if len(parts) >= 1 { colName := strings.Trim(parts[0], "`") // 检查是否有映射 if mappedName, ok := columnMap[colName]; ok { // 去掉结尾的逗号(如果有) trimmed = strings.TrimRight(trimmed, ",") // 检查是否已经有 COMMENT if strings.Contains(strings.ToUpper(trimmed), "COMMENT") { // 替换已有的 COMMENT trimmed = strings.TrimSuffix(trimmed, " COMMENT '...'") trimmed = fmt.Sprintf("%s COMMENT '%s'", trimmed, mappedName) } else { // 在末尾添加 COMMENT trimmed = fmt.Sprintf("%s COMMENT '%s'", trimmed, mappedName) } // 替换原始行为修改后的行 resultLines = append(resultLines, trimmed) continue } } } resultLines = append(resultLines, line) } return strings.Join(resultLines, "\n") } // Check 检查数据库连接 func (s *DatabaseService) Check(req model.CheckRequest) (*model.CheckResponse, error) { log.Printf("[Check] 开始检查连接: 类型=%s, 主机=%s, 端口=%d, 数据库=%s, 用户=%s", req.DBType, req.Host, req.Port, req.Database, req.Username) info := &model.DatabaseInfo{ DBType: req.DBType, Host: req.Host, Port: req.Port, Username: req.Username, Password: req.Password, Database: req.Database, Charset: req.Charset, SSLMode: req.SSLMode, } if info.Charset == "" { info.Charset = "utf8mb4" } // 构建连接 dsn := s.buildDSN(info) dbType := strings.ToLower(info.DBType) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() var db *sql.DB var err error // Neo4j 处理 if dbType == "neo4j" { log.Printf("[Check] 检测到 Neo4j 类型,使用图数据库连接...") neo4jService := NewNeo4jService(s.repo) graph, err := neo4jService.GetGraphOverview(req) if err != nil { log.Printf("[Check] Neo4j 连接失败: %v", err) return &model.CheckResponse{ Success: false, Message: fmt.Sprintf("neo4j connection failed: %v", err), }, nil } log.Printf("[Check] Neo4j 连接成功,获取到 %d 个标签", len(graph.Labels)) return &model.CheckResponse{ Success: true, Message: "connection successful", Graphs: graph, Database: req.Database, }, nil } switch dbType { case "mysql": db, err = sql.Open("mysql", dsn) case "postgres", "postgresql": db, err = sql.Open("postgres", dsn) default: return &model.CheckResponse{ Success: false, Message: fmt.Sprintf("unsupported database type: %s", req.DBType), }, nil } if err != nil { return &model.CheckResponse{ Success: false, Message: fmt.Sprintf("failed to create connection: %v", err), }, nil } defer db.Close() // 测试连接 if err := db.PingContext(ctx); err != nil { log.Printf("[Check] 连接失败: %v", err) return &model.CheckResponse{ Success: false, Message: fmt.Sprintf("cannot connect to database: %v", err), }, nil } log.Printf("[Check] 连接成功,开始获取表列表...") // 获取表列表 var tables []model.TableDDLInfo switch dbType { case "mysql": tables, _ = s.getMySQLTables(db, req.Database) case "postgres", "postgresql": tables, _ = s.getPostgresTables(db, req.Database) } log.Printf("[Check] 获取到 %d 个表", len(tables)) // 如果传入了 database_id,获取已保存的字段映射和 DDL 并填充到表结构中 if req.DatabaseID != "" && s.subTableRepo != nil { s.fillFieldMappings(req.DatabaseID, tables) s.fillDDL(req.DatabaseID, tables) } return &model.CheckResponse{ Success: true, Message: "connection successful", Tables: tables, Database: req.Database, }, nil } // getMySQLTables 获取MySQL表结构 func (s *DatabaseService) getMySQLTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) { rows, err := db.Query(` SELECT TABLE_NAME, TABLE_COMMENT FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_TYPE = 'BASE TABLE' `, dbName) if err != nil { return nil, err } defer rows.Close() var tables []model.TableDDLInfo for rows.Next() { var tableName, tableComment string if err := rows.Scan(&tableName, &tableComment); err != nil { continue } table := model.TableDDLInfo{ TableName: tableName, TableComment: tableComment, } // 获取列信息 table.Columns, _ = s.getMySQLColumns(db, dbName, tableName) // 获取 DDL table.DDL, _ = s.getMySQLDDL(db, tableName) tables = append(tables, table) } return tables, nil } // getMySQLDDL 获取 MySQL 表的 DDL func (s *DatabaseService) getMySQLDDL(db *sql.DB, tableName string) (string, error) { // 使用反引号包裹表名,防止关键字冲突 query := fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName) row := db.QueryRow(query) var tblName, createStmt string if err := row.Scan(&tblName, &createStmt); err != nil { log.Printf("[getMySQLDDL] 获取 DDL 失败: %v", err) return "", nil } return createStmt, nil } // getMySQLColumns 获取MySQL列信息 func (s *DatabaseService) getMySQLColumns(db *sql.DB, dbName, tableName string) ([]model.ColumnInfo, error) { rows, err := db.Query(` SELECT COLUMN_NAME, DATA_TYPE, COLUMN_TYPE, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY, EXTRA, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION `, dbName, tableName) if err != nil { log.Printf("[getMySQLColumns] 查询列信息失败: %v", err) return nil, err } defer rows.Close() columns := make([]model.ColumnInfo, 0) for rows.Next() { var col model.ColumnInfo var defaultValue, extra, columnComment sql.NullString if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType, &col.IsNullable, &defaultValue, &col.ColumnKey, &extra, &columnComment); err != nil { log.Printf("[getMySQLColumns] Scan 失败: %v", err) continue } col.DefaultValue = defaultValue.String col.Extra = extra.String col.ColumnComment = columnComment.String columns = append(columns, col) } // 检查是否有迭代错误 if err := rows.Err(); err != nil { log.Printf("[getMySQLColumns] 迭代错误: %v", err) } return columns, nil } // getPostgresTables 获取PostgreSQL表结构 func (s *DatabaseService) getPostgresTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) { rows, err := db.Query(` SELECT t.table_name, obj_description((t.table_schema || '.' || t.table_name)::regclass) FROM information_schema.tables t WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE' `, dbName) if err != nil { return nil, err } defer rows.Close() var tables []model.TableDDLInfo for rows.Next() { var tableName, tableComment string if err := rows.Scan(&tableName, &tableComment); err != nil { continue } table := model.TableDDLInfo{ TableName: tableName, TableComment: tableComment, } // 获取列信息 table.Columns, _ = s.getPostgresColumns(db, tableName) // 获取 DDL table.DDL, _ = s.getPostgresDDL(db, tableName) tables = append(tables, table) } return tables, nil } // getPostgresDDL 获取 PostgreSQL 表的 DDL func (s *DatabaseService) getPostgresDDL(db *sql.DB, tableName string) (string, error) { var ddl string query := fmt.Sprintf("SELECT pg_get_create('%s')", tableName) row := db.QueryRow(query) if err := row.Scan(&ddl); err != nil { log.Printf("[getPostgresDDL] 获取 DDL 失败: %v", err) return "", nil } return ddl, nil } // getPostgresColumns 获取PostgreSQL列信息 func (s *DatabaseService) getPostgresColumns(db *sql.DB, tableName string) ([]model.ColumnInfo, error) { rows, err := db.Query(` SELECT c.column_name, c.data_type, c.udt_name, c.is_nullable, c.column_default, c.column_name, '', c.column_comment FROM information_schema.columns c WHERE c.table_name = $1 AND c.table_schema = 'public' ORDER BY c.ordinal_position `, tableName) if err != nil { return nil, err } defer rows.Close() var columns []model.ColumnInfo for rows.Next() { var col model.ColumnInfo if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType, &col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil { continue } columns = append(columns, col) } return columns, nil } // Create 创建数据库信息(支持同时保存子表配置) func (s *DatabaseService) Create(req model.CreateDatabaseRequest) (*model.DatabaseInfo, error) { log.Printf("[Create] 收到创建请求: %+v", req) info := &model.DatabaseInfo{ ID: uuid.New().String(), Name: req.Name, Description: req.Description, DBType: strings.ToLower(req.DBType), // 统一转为小写 Host: req.Host, Port: req.Port, Username: req.Username, Password: req.Password, Database: req.Database, Charset: req.Charset, SSLMode: req.SSLMode, TableCount: len(req.SubTables), } // 默认值 if info.Charset == "" { info.Charset = "utf8mb4" } // 测试数据库连通性 log.Printf("[Create] 开始测试数据库连接...") if err := s.TestConnection(info); err != nil { log.Printf("[Create] 数据库连接测试失败: %v", err) return nil, fmt.Errorf("database connection failed: %v", err) } log.Printf("[Create] 数据库连接测试成功!") // 保存数据库信息 if err := s.repo.Create(info); err != nil { log.Printf("[Create] 保存数据库失败: %v", err) return nil, err } // 保存子表配置(如有) if len(req.SubTables) > 0 && s.subTableRepo != nil { log.Printf("[Create] 保存 %d 个子表配置", len(req.SubTables)) // 获取数据库连接用于查询 DDL db, err := s.getConnection(info) if err != nil { log.Printf("[Create] 获取数据库连接失败: %v", err) } else { defer db.Close() } for _, subReq := range req.SubTables { subTable := &model.SubTableInfo{ ID: uuid.New().String(), DatabaseID: info.ID, ParentTable: subReq.ParentTable, SubTableName: subReq.SubTableName, SubTableComment: subReq.SubTableComment, MappingType: subReq.MappingType, RelationField: subReq.RelationField, RelationType: subReq.RelationType, } // 使用 SetFields 方法保存字段映射 subTable.SetFields(subReq.Fields) // 获取并保存 DDL if db != nil { ddl, err := s.getTableDDL(db, strings.ToLower(info.DBType), subReq.ParentTable) if err != nil { log.Printf("[Create] 获取原始 DDL 失败: %v", err) } else { // 如果有字段映射,生成带 COMMENT 的新 DDL if len(subReq.Fields) > 0 { subTable.DDL = s.buildMappedDDL(ddl, subReq.Fields) log.Printf("[Create] 生成映射后的 DDL,长度: %d", len(subTable.DDL)) } else { subTable.DDL = ddl } } } if err := s.subTableRepo.Create(subTable); err != nil { log.Printf("[Create] 保存子表失败: %v", err) } } // 同步到文件 s.syncSubTablesToFile(info) } log.Printf("[Create] 创建成功, ID=%s", info.ID) return info, nil } // syncSubTablesToFile 同步子表到文件 func (s *DatabaseService) syncSubTablesToFile(info *model.DatabaseInfo) { if s.subTableRepo == nil { return } tables, err := s.subTableRepo.FindByDatabaseID(info.ID) if err != nil { log.Printf("[syncSubTablesToFile] 查询子表失败: %v", err) return } mapping := &model.SubTableMapping{ DatabaseID: info.ID, DatabaseName: info.Name, DBType: info.DBType, Tables: tables, UpdatedAt: time.Now(), } resourceDir := "resources/db_info" os.MkdirAll(resourceDir, 0755) data, err := json.MarshalIndent(mapping, "", " ") if err != nil { log.Printf("[syncSubTablesToFile] 序列化失败: %v", err) return } filePath := fmt.Sprintf("%s/%s.json", resourceDir, info.ID) if err := os.WriteFile(filePath, data, 0644); err != nil { log.Printf("[syncSubTablesToFile] 写入文件失败: %v", err) } log.Printf("[syncSubTablesToFile] 同步成功: %s", filePath) } // GetByID 获取详情 func (s *DatabaseService) GetByID(id string) (*model.DatabaseInfo, error) { log.Printf("[GetByID] 查询 ID=%s", id) info, err := s.repo.FindByID(id) if err != nil { log.Printf("[GetByID] 查询失败: %v", err) return nil, ErrDatabaseNotFound } return info, nil } // List 获取列表 func (s *DatabaseService) List() ([]model.DatabaseInfo, error) { log.Printf("[List] 查询所有数据库列表") return s.repo.FindAll() } // Update 更新 func (s *DatabaseService) Update(id string, req model.UpdateDatabaseRequest) (*model.DatabaseInfo, error) { log.Printf("[Update] 更新 ID=%s, 数据=%+v", id, req) // 检查是否存在 _, err := s.repo.FindByID(id) if err != nil { log.Printf("[Update] 不存在: %v", err) return nil, ErrDatabaseNotFound } // 构建更新数据 updates := map[string]interface{}{} if req.Name != "" { updates["name"] = req.Name } if req.Description != "" { updates["description"] = req.Description } if req.DBType != "" { updates["db_type"] = req.DBType } if req.Host != "" { updates["host"] = req.Host } if req.Port > 0 { updates["port"] = req.Port } if req.Username != "" { updates["username"] = req.Username } if req.Password != "" { updates["password"] = req.Password } if req.Database != "" { updates["database"] = req.Database } if req.TableCount > 0 { updates["table_count"] = req.TableCount } if req.Charset != "" { updates["charset"] = req.Charset } if req.SSLMode != "" { updates["ssl_mode"] = req.SSLMode } info := &model.DatabaseInfo{} if err := s.repo.Update(id, info); err != nil { log.Printf("[Update] 更新失败: %v", err) return nil, err } return s.repo.FindByID(id) } // SaveGraph 保存图谱信息 func (s *DatabaseService) SaveGraph(req model.SaveGraphRequest) (*model.SaveGraphResponse, error) { log.Printf("[SaveGraph] 保存图谱信息, databaseId=%s, databaseName=%s", req.DatabaseID, req.DatabaseName) // 检查数据库是否存在 _, err := s.repo.FindByID(req.DatabaseID) if err != nil { // 如果不存在,创建一个新的 if err == ErrDatabaseNotFound { // 创建新的数据库记录 dbType := "neo4j" // 从 URI 解析 host 和 port host := "localhost" port := 7687 if req.URI != "" { uri := strings.TrimPrefix(req.URI, "bolt://") uri = strings.TrimPrefix(uri, "neo4j://") if idx := strings.Index(uri, ":"); idx > 0 { host = uri[:idx] fmt.Sscanf(uri[idx+1:], "%d", &port) } } // 将 labels 和 relationshipTypes 转为 JSON 字符串 labelsJSON, _ := json.Marshal(req.Labels) relJSON, _ := json.Marshal(req.RelationshipTypes) info := &model.DatabaseInfo{ ID: req.DatabaseID, Name: req.DatabaseName, DBType: dbType, Host: host, Port: port, Username: req.Username, URI: req.URI, GraphLabels: string(labelsJSON), GraphRelationship: string(relJSON), SelectedLabel: req.SelectedLabel, } if err := s.repo.Create(info); err != nil { log.Printf("[SaveGraph] 创建失败: %v", err) return &model.SaveGraphResponse{ Success: false, Message: fmt.Sprintf("创建失败: %v", err), }, err } return &model.SaveGraphResponse{ Success: true, Message: "保存成功", }, nil } log.Printf("[SaveGraph] 查询失败: %v", err) return &model.SaveGraphResponse{ Success: false, Message: fmt.Sprintf("查询失败: %v", err), }, err } // 更新现有记录 labelsJSON, _ := json.Marshal(req.Labels) relJSON, _ := json.Marshal(req.RelationshipTypes) updates := map[string]interface{}{ "uri": req.URI, "username": req.Username, "graph_labels": string(labelsJSON), "graph_relationship": string(relJSON), "selected_label": req.SelectedLabel, } if err := s.repo.UpdateFields(req.DatabaseID, updates); err != nil { log.Printf("[SaveGraph] 更新失败: %v", err) return &model.SaveGraphResponse{ Success: false, Message: fmt.Sprintf("更新失败: %v", err), }, err } return &model.SaveGraphResponse{ Success: true, Message: "保存成功", }, nil } // fillFieldMappings 填充字段映射到表结构中 func (s *DatabaseService) fillFieldMappings(databaseID string, tables []model.TableDDLInfo) { // 从数据库中获取该数据库下所有子表的字段映射 subTables, err := s.subTableRepo.FindByDatabaseID(databaseID) if err != nil { log.Printf("[fillFieldMappings] 查询子表失败: %v", err) return } // 构建表名到字段映射的映射 tableFieldsMap := make(map[string][]model.FieldMapping) for _, st := range subTables { fields := st.GetFields() if len(fields) > 0 { tableFieldsMap[st.ParentTable] = fields } } // 遍历返回的表结构,填充字段映射 for i := range tables { tableName := tables[i].TableName if fields, ok := tableFieldsMap[tableName]; ok { // 构建列名到映射名的映射 columnMap := make(map[string]string) for _, f := range fields { columnMap[f.ColumnName] = f.MappedName } // 填充到每个列 for j := range tables[i].Columns { colName := tables[i].Columns[j].ColumnName if mappedName, ok := columnMap[colName]; ok { tables[i].Columns[j].MappedName = mappedName } } } } log.Printf("[fillFieldMappings] 已填充字段映射到 %d 个表", len(tables)) } // fillDDL 填充已保存的 DDL 到表结构中 func (s *DatabaseService) fillDDL(databaseID string, tables []model.TableDDLInfo) { // 从数据库中获取该数据库下所有子表的 DDL subTables, err := s.subTableRepo.FindByDatabaseID(databaseID) if err != nil { log.Printf("[fillDDL] 查询子表失败: %v", err) return } // 构建表名到 DDL 的映射 tableDDLMap := make(map[string]string) for _, st := range subTables { if st.DDL != "" { tableDDLMap[st.ParentTable] = st.DDL } } // 遍历返回的表结构,填充 DDL for i := range tables { tableName := tables[i].TableName if ddl, ok := tableDDLMap[tableName]; ok { tables[i].DDL = ddl } } log.Printf("[fillDDL] 已填充 DDL 到 %d 个表", len(tables)) } // Delete 删除 func (s *DatabaseService) Delete(id string) error { log.Printf("[Delete] 删除 ID=%s", id) _, err := s.repo.FindByID(id) if err != nil { log.Printf("[Delete] 不存在: %v", err) return ErrDatabaseNotFound } return s.repo.Delete(id) }