Files
X-Agents/server/internal/service/database_service.go
DESKTOP-72TV0V4\caoxiaozhu 22be617905 feat: 完善子表删除逻辑和table_count同步更新
- 数据库更新时自动删除不在新列表中的子表
- 同步更新 table_count 为当前子表数量
- 删除数据库时级联删除关联的子表记录
- 添加相关需求文档

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 11:38:27 +08:00

971 lines
26 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"x-agents/server/internal/model"
"x-agents/server/internal/repository"
"github.com/google/uuid"
)
// Sentinel errors exposed by DatabaseService; compare them with errors.Is.
var (
	// ErrDatabaseNotFound reports that no stored database record matches the
	// requested ID. GetByID, Update and Delete return it whenever the
	// repository lookup fails.
	ErrDatabaseNotFound = errors.New("database not found")

	// ErrDatabaseUnreachable describes a target database that cannot be
	// connected to. NOTE(review): no visible code in this file returns it —
	// confirm it is still needed.
	ErrDatabaseUnreachable = errors.New("database cannot be connected")
)
// DatabaseService implements operations on stored database connection
// records and their per-table sub-table configurations.
type DatabaseService struct {
	// repo persists model.DatabaseInfo records (this file calls Create,
	// FindByID, FindAll, Update, UpdateFields and Delete on it).
	repo *repository.DatabaseRepository
	// subTableRepo persists model.SubTableInfo records. Create, Check and
	// syncSubTablesToFile check it for nil before use; other methods call it
	// unconditionally.
	subTableRepo *repository.SubTableRepository
}
// NewDatabaseService builds a DatabaseService backed by repo for database
// records and subTableRepo for sub-table configurations. Both pointers are
// stored as given; no validation is performed.
func NewDatabaseService(repo *repository.DatabaseRepository, subTableRepo *repository.SubTableRepository) *DatabaseService {
	svc := new(DatabaseService)
	svc.repo = repo
	svc.subTableRepo = subTableRepo
	return svc
}
// TestConnection opens a short-lived connection to the database described by
// info and pings it with a 5-second timeout.
//
// Supported DBType values (case-insensitive) are "mysql", "postgres" and
// "postgresql"; anything else yields an "unsupported database type" error.
// Driver errors are wrapped with %w so callers can inspect them with
// errors.Is / errors.As.
func (s *DatabaseService) TestConnection(info *model.DatabaseInfo) error {
	log.Printf("[数据库连接测试] 开始测试连接: 类型=%s, 主机=%s, 端口=%d, 数据库=%s, 用户=%s",
		info.DBType, info.Host, info.Port, info.Database, info.Username)

	var driver string
	switch strings.ToLower(info.DBType) {
	case "mysql":
		driver = "mysql"
	case "postgres", "postgresql":
		driver = "postgres"
	default:
		err := fmt.Errorf("unsupported database type: %s", info.DBType)
		log.Printf("[数据库连接测试] 错误: %v", err)
		return err
	}

	// The DSN embeds the plaintext password, so it is deliberately not logged.
	dsn := s.buildDSN(info)
	log.Printf("[数据库连接测试] DSN构建完成")

	db, err := sql.Open(driver, dsn)
	if err != nil {
		err = fmt.Errorf("failed to create connection: %w", err)
		log.Printf("[数据库连接测试] 错误: %v", err)
		return err
	}
	defer db.Close()

	// sql.Open does not dial; the ping is the real connectivity check.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		err = fmt.Errorf("cannot connect to database: %w", err)
		log.Printf("[数据库连接测试] 连接失败: %v", err)
		return err
	}

	log.Printf("[数据库连接测试] 连接成功!")
	return nil
}
// buildDSN returns the driver-specific data source name for info, or "" when
// info.DBType (case-insensitive) is not a supported relational type.
//
//   - mysql: go-sql-driver/mysql DSN. An empty info.Database falls back to the
//     built-in "mysql" schema so a bare connectivity check still works, and
//     an empty charset defaults to utf8mb4.
//   - postgres/postgresql: lib/pq keyword/value DSN with sslmode defaulting to
//     "disable". Every value is single-quoted and escaped, because lib/pq
//     skips whitespace after '=' and would otherwise read an empty value
//     (e.g. no password) as the following "key=value" pair.
func (s *DatabaseService) buildDSN(info *model.DatabaseInfo) string {
	switch strings.ToLower(info.DBType) {
	case "mysql":
		charset := info.Charset
		if charset == "" {
			charset = "utf8mb4"
		}
		dbName := info.Database
		if dbName == "" {
			dbName = "mysql"
		}
		return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&timeout=5s&parseTime=True",
			info.Username,
			info.Password,
			info.Host,
			info.Port,
			dbName,
			charset,
		)
	case "postgres", "postgresql":
		sslmode := "disable"
		if info.SSLMode != "" {
			sslmode = info.SSLMode
		}
		return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5",
			quotePostgresDSNValue(info.Host),
			info.Port,
			quotePostgresDSNValue(info.Username),
			quotePostgresDSNValue(info.Password),
			quotePostgresDSNValue(info.Database),
			quotePostgresDSNValue(sslmode),
		)
	default:
		return ""
	}
}

// quotePostgresDSNValue renders v as a single-quoted libpq connection-string
// value, backslash-escaping backslashes and single quotes.
func quotePostgresDSNValue(v string) string {
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	return "'" + escaper.Replace(v) + "'"
}
// getConnection opens a *sql.DB for info and verifies it with a ping bounded
// by a 5-second timeout.
//
// On success the caller owns the handle and must Close it. On any failure
// (unsupported type, open error, ping error) it returns a nil handle and
// closes anything it opened, so nothing leaks.
func (s *DatabaseService) getConnection(info *model.DatabaseInfo) (*sql.DB, error) {
	dbType := strings.ToLower(info.DBType)

	var driver string
	switch dbType {
	case "mysql":
		driver = "mysql"
	case "postgres", "postgresql":
		driver = "postgres"
	default:
		return nil, fmt.Errorf("unsupported database type: %s", dbType)
	}

	db, err := sql.Open(driver, s.buildDSN(info))
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		db.Close() // the caller never receives this handle, so release it here
		return nil, err
	}
	return db, nil
}
// getTableDDL returns a CREATE TABLE statement for tableName.
//
//   - mysql: the verbatim second column of SHOW CREATE TABLE.
//   - postgres/postgresql: PostgreSQL has no built-in "show create table", so
//     a column-level statement (name, type, NOT NULL, DEFAULT) is rebuilt
//     from pg_catalog for the table in the "public" schema. Constraints,
//     indexes and comments are not included.
//
// tableName is quoted as an identifier in both cases, so a name containing
// quote characters cannot alter the query. Any other dbType is an error.
func (s *DatabaseService) getTableDDL(db *sql.DB, dbType, tableName string) (string, error) {
	switch dbType {
	case "mysql":
		// Identifiers cannot be bound as parameters; double embedded backticks
		// so the name stays inside its quoting.
		query := "SHOW CREATE TABLE `" + strings.ReplaceAll(tableName, "`", "``") + "`"
		var tblName, createStmt string
		if err := db.QueryRow(query).Scan(&tblName, &createStmt); err != nil {
			return "", err
		}
		return createStmt, nil
	case "postgres", "postgresql":
		// format('%I') quotes the identifier; the regclass cast fails with
		// "relation does not exist" for unknown tables.
		const query = `
			SELECT format(E'CREATE TABLE %I (\n%s\n)', $1::text,
			       string_agg(format('    %I %s%s%s',
			                         a.attname,
			                         format_type(a.atttypid, a.atttypmod),
			                         CASE WHEN a.attnotnull THEN ' NOT NULL' ELSE '' END,
			                         COALESCE(' DEFAULT ' || pg_get_expr(d.adbin, d.adrelid), '')),
			                  E',\n' ORDER BY a.attnum))
			FROM pg_catalog.pg_attribute a
			LEFT JOIN pg_catalog.pg_attrdef d ON d.adrelid = a.attrelid AND d.adnum = a.attnum
			WHERE a.attrelid = format('public.%I', $1::text)::regclass
			  AND a.attnum > 0
			  AND NOT a.attisdropped`
		var ddl string
		if err := db.QueryRow(query, tableName).Scan(&ddl); err != nil {
			return "", err
		}
		return ddl, nil
	default:
		return "", fmt.Errorf("unsupported database type: %s", dbType)
	}
}
// buildMappedDDL returns originalDDL with a COMMENT '<mapped name>' clause set
// on every column that has a non-empty MappedName in fields.
//
// The DDL is expected to be MySQL SHOW CREATE TABLE output, where each column
// definition sits on its own line and starts with a backtick-quoted name.
// For each mapped column line:
//   - an existing COMMENT '...' clause is removed first, so the column ends
//     up with exactly one comment;
//   - the mapped name is escaped as a MySQL string literal;
//   - the line's indentation and trailing comma are kept, so the statement
//     stays syntactically valid.
//
// All other lines are returned unchanged. originalDDL is returned as-is when
// no field carries a mapped name.
func (s *DatabaseService) buildMappedDDL(originalDDL string, fields []model.FieldMapping) string {
	columnMap := make(map[string]string, len(fields))
	for _, f := range fields {
		if f.MappedName != "" {
			columnMap[f.ColumnName] = f.MappedName
		}
	}
	if len(columnMap) == 0 {
		return originalDDL
	}

	escape := strings.NewReplacer(`\`, `\\`, `'`, `''`)
	lines := strings.Split(originalDDL, "\n")
	for i, line := range lines {
		trimmed := strings.TrimSpace(line)
		if !strings.HasPrefix(trimmed, "`") {
			continue // CREATE TABLE header, key definitions, table options
		}
		colName := strings.Trim(strings.SplitN(trimmed, " ", 2)[0], "`")
		mappedName, ok := columnMap[colName]
		if !ok {
			continue
		}

		indent := line[:len(line)-len(strings.TrimLeft(line, " \t"))]
		def := strings.TrimSuffix(trimmed, ",")
		hasComma := len(def) != len(trimmed)
		def = stripMySQLCommentClause(def) + " COMMENT '" + escape.Replace(mappedName) + "'"
		if hasComma {
			def += ","
		}
		lines[i] = indent + def
	}
	return strings.Join(lines, "\n")
}

// stripMySQLCommentClause removes the first ` COMMENT '...'` clause that
// appears outside quoted literals/identifiers in a column definition, keeping
// any text before and after it. def is returned unchanged if it has none.
func stripMySQLCommentClause(def string) string {
	const kw = " COMMENT '"
	var quote byte // opening quote character while inside a literal, else 0
	for i := 0; i < len(def); i++ {
		c := def[i]
		switch {
		case quote != 0:
			if c == '\\' && quote == '\'' {
				i++ // skip the backslash-escaped character
			} else if c == quote {
				quote = 0 // a doubled quote simply re-opens on the next byte
			}
		case c == '\'' || c == '"' || c == '`':
			quote = c
		case len(def)-i >= len(kw) && strings.EqualFold(def[i:i+len(kw)], kw):
			// Find the closing quote of the comment literal.
			j := i + len(kw)
			for j < len(def) {
				if def[j] == '\\' {
					j += 2
					continue
				}
				if def[j] == '\'' {
					if j+1 < len(def) && def[j+1] == '\'' {
						j += 2 // '' is an escaped quote
						continue
					}
					break
				}
				j++
			}
			if j >= len(def) {
				return def[:i] // unterminated literal: drop the rest
			}
			return def[:i] + def[j+1:]
		}
	}
	return def
}
// Check probes the connection described by req without persisting anything.
//
// For "neo4j" it returns the graph overview from Neo4jService. For "mysql",
// "postgres" and "postgresql" it pings the server (5-second timeout) and then
// lists its tables with columns and DDL. When req.DatabaseID is set and a
// sub-table repository is configured, the field mappings and DDL saved for
// that database are merged into the listing.
//
// Connection failures are reported via CheckResponse.Success/Message, so the
// returned error is always nil. A failure while listing tables is logged and
// the response still reports a successful connection.
func (s *DatabaseService) Check(req model.CheckRequest) (*model.CheckResponse, error) {
	log.Printf("[Check] 开始检查连接: 类型=%s, 主机=%s, 端口=%d, 数据库=%s, 用户=%s",
		req.DBType, req.Host, req.Port, req.Database, req.Username)

	dbType := strings.ToLower(req.DBType)

	if dbType == "neo4j" {
		log.Printf("[Check] 检测到 Neo4j 类型,使用图数据库连接...")
		neo4jService := NewNeo4jService(s.repo)
		graph, err := neo4jService.GetGraphOverview(req)
		if err != nil {
			log.Printf("[Check] Neo4j 连接失败: %v", err)
			return &model.CheckResponse{
				Success: false,
				Message: fmt.Sprintf("neo4j connection failed: %v", err),
			}, nil
		}
		log.Printf("[Check] Neo4j 连接成功,获取到 %d 个标签", len(graph.Labels))
		return &model.CheckResponse{
			Success:  true,
			Message:  "connection successful",
			Graphs:   graph,
			Database: req.Database,
		}, nil
	}

	info := &model.DatabaseInfo{
		DBType:   req.DBType,
		Host:     req.Host,
		Port:     req.Port,
		Username: req.Username,
		Password: req.Password,
		Database: req.Database,
		Charset:  req.Charset,
		SSLMode:  req.SSLMode,
	}
	if info.Charset == "" {
		info.Charset = "utf8mb4"
	}

	var driver string
	switch dbType {
	case "mysql":
		driver = "mysql"
	case "postgres", "postgresql":
		driver = "postgres"
	default:
		return &model.CheckResponse{
			Success: false,
			Message: fmt.Sprintf("unsupported database type: %s", req.DBType),
		}, nil
	}

	db, err := sql.Open(driver, s.buildDSN(info))
	if err != nil {
		return &model.CheckResponse{
			Success: false,
			Message: fmt.Sprintf("failed to create connection: %v", err),
		}, nil
	}
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		log.Printf("[Check] 连接失败: %v", err)
		return &model.CheckResponse{
			Success: false,
			Message: fmt.Sprintf("cannot connect to database: %v", err),
		}, nil
	}

	log.Printf("[Check] 连接成功,开始获取表列表...")
	var tables []model.TableDDLInfo
	if driver == "mysql" {
		tables, err = s.getMySQLTables(db, req.Database)
	} else {
		tables, err = s.getPostgresTables(db, req.Database)
	}
	if err != nil {
		// Connectivity is proven; surface the listing failure instead of hiding it.
		log.Printf("[Check] 获取表列表失败: %v", err)
	}
	log.Printf("[Check] 获取到 %d 个表", len(tables))

	if req.DatabaseID != "" && s.subTableRepo != nil {
		s.fillFieldMappings(req.DatabaseID, tables)
		s.fillDDL(req.DatabaseID, tables)
	}

	return &model.CheckResponse{
		Success:  true,
		Message:  "connection successful",
		Tables:   tables,
		Database: req.Database,
	}, nil
}
// getMySQLTables lists the base tables of schema dbName with their comment,
// columns and CREATE TABLE statement.
//
// Table names are collected and the listing cursor is closed before the
// per-table column/DDL queries run, so the listing does not keep a pooled
// connection busy while those queries need their own. A NULL table comment
// becomes "". Column/DDL lookups are best-effort: a table whose details cannot
// be read is still listed. Errors from the listing query or its iteration are
// returned.
func (s *DatabaseService) getMySQLTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) {
	rows, err := db.Query(`
		SELECT TABLE_NAME, TABLE_COMMENT
		FROM information_schema.TABLES
		WHERE TABLE_SCHEMA = ?
		  AND TABLE_TYPE = 'BASE TABLE'`, dbName)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tables []model.TableDDLInfo
	for rows.Next() {
		var tableName string
		var tableComment sql.NullString
		if err := rows.Scan(&tableName, &tableComment); err != nil {
			log.Printf("[getMySQLTables] Scan 失败: %v", err)
			continue
		}
		tables = append(tables, model.TableDDLInfo{
			TableName:    tableName,
			TableComment: tableComment.String,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	rows.Close() // release the connection before issuing per-table queries

	for i := range tables {
		// Best-effort: both helpers log their own failures.
		tables[i].Columns, _ = s.getMySQLColumns(db, dbName, tables[i].TableName)
		tables[i].DDL, _ = s.getMySQLDDL(db, tables[i].TableName)
	}
	return tables, nil
}
// getMySQLDDL returns the SHOW CREATE TABLE statement for tableName.
//
// It is best-effort: any failure is logged and reported as an empty DDL with
// a nil error, so callers can still list the table.
func (s *DatabaseService) getMySQLDDL(db *sql.DB, tableName string) (string, error) {
	// Identifiers cannot be bound as parameters. Wrap the name in backticks and
	// double any embedded backtick so it cannot escape the quoting.
	query := "SHOW CREATE TABLE `" + strings.ReplaceAll(tableName, "`", "``") + "`"
	var tblName, createStmt string
	if err := db.QueryRow(query).Scan(&tblName, &createStmt); err != nil {
		log.Printf("[getMySQLDDL] 获取 DDL 失败: %v", err)
		return "", nil
	}
	return createStmt, nil
}
// getMySQLColumns reads the column metadata of dbName.tableName from
// information_schema.COLUMNS, in ordinal order.
//
// Only a failure of the query itself is returned as an error. Rows that fail
// to scan are logged and skipped. An iteration error is logged, and the
// columns read so far are still returned. On success the slice is non-nil,
// even when empty.
func (s *DatabaseService) getMySQLColumns(db *sql.DB, dbName, tableName string) ([]model.ColumnInfo, error) {
	rows, err := db.Query(`
SELECT
COLUMN_NAME, DATA_TYPE, COLUMN_TYPE,
IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY,
EXTRA, COLUMN_COMMENT
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
ORDER BY ORDINAL_POSITION
`, dbName, tableName)
	if err != nil {
		log.Printf("[getMySQLColumns] 查询列信息失败: %v", err)
		return nil, err
	}
	defer rows.Close()

	result := []model.ColumnInfo{}
	for rows.Next() {
		var (
			c                   model.ColumnInfo
			def, extra, comment sql.NullString // nullable in information_schema
		)
		scanErr := rows.Scan(&c.ColumnName, &c.DataType, &c.ColumnType,
			&c.IsNullable, &def, &c.ColumnKey, &extra, &comment)
		if scanErr != nil {
			log.Printf("[getMySQLColumns] Scan 失败: %v", scanErr)
			continue
		}
		c.DefaultValue, c.Extra, c.ColumnComment = def.String, extra.String, comment.String
		result = append(result, c)
	}

	if iterErr := rows.Err(); iterErr != nil {
		log.Printf("[getMySQLColumns] 迭代错误: %v", iterErr)
	}
	return result, nil
}
// getPostgresTables lists the base tables in the "public" schema of the
// connected PostgreSQL database, with their comment, columns and DDL.
//
// dbName is accepted for symmetry with getMySQLTables but is unused: a
// PostgreSQL connection is already bound to one database. Tables without a
// comment get "". Names are collected and the cursor is closed before the
// per-table queries run. Column/DDL lookups are best-effort: a table whose
// details cannot be read is still listed. Errors from the listing query or
// its iteration are returned.
func (s *DatabaseService) getPostgresTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) {
	// No bind parameters here: passing any argument to a placeholder-free
	// statement makes lib/pq fail. format('%I.%I') quotes mixed-case names
	// before the regclass cast.
	rows, err := db.Query(`
		SELECT t.table_name,
		       COALESCE(obj_description(format('%I.%I', t.table_schema, t.table_name)::regclass, 'pg_class'), '')
		FROM information_schema.tables t
		WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE'`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tables []model.TableDDLInfo
	for rows.Next() {
		var tableName, tableComment string
		if err := rows.Scan(&tableName, &tableComment); err != nil {
			log.Printf("[getPostgresTables] Scan 失败: %v", err)
			continue
		}
		tables = append(tables, model.TableDDLInfo{
			TableName:    tableName,
			TableComment: tableComment,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	rows.Close() // release the connection before issuing per-table queries

	for i := range tables {
		// Best-effort: a table with unreadable details is still listed.
		tables[i].Columns, _ = s.getPostgresColumns(db, tables[i].TableName)
		tables[i].DDL, _ = s.getPostgresDDL(db, tables[i].TableName)
	}
	return tables, nil
}
// getPostgresDDL returns a column-level CREATE TABLE statement (name, type,
// NOT NULL, DEFAULT) for tableName in the "public" schema. It is rebuilt from
// pg_catalog because PostgreSQL has no built-in "show create table".
// Constraints, indexes and comments are not included.
//
// It is best-effort: any failure is logged and reported as an empty DDL with
// a nil error.
func (s *DatabaseService) getPostgresDDL(db *sql.DB, tableName string) (string, error) {
	// The table name is bound as $1 and quoted via format('%I'), never
	// interpolated into the SQL text.
	const query = `
		SELECT format(E'CREATE TABLE %I (\n%s\n)', $1::text,
		       string_agg(format('    %I %s%s%s',
		                         a.attname,
		                         format_type(a.atttypid, a.atttypmod),
		                         CASE WHEN a.attnotnull THEN ' NOT NULL' ELSE '' END,
		                         COALESCE(' DEFAULT ' || pg_get_expr(d.adbin, d.adrelid), '')),
		                  E',\n' ORDER BY a.attnum))
		FROM pg_catalog.pg_attribute a
		LEFT JOIN pg_catalog.pg_attrdef d ON d.adrelid = a.attrelid AND d.adnum = a.attnum
		WHERE a.attrelid = format('public.%I', $1::text)::regclass
		  AND a.attnum > 0
		  AND NOT a.attisdropped`
	var ddl string
	if err := db.QueryRow(query, tableName).Scan(&ddl); err != nil {
		log.Printf("[getPostgresDDL] 获取 DDL 失败: %v", err)
		return "", nil
	}
	return ddl, nil
}
// getPostgresColumns reads the column metadata of public.tableName, in
// ordinal order, into the same shape getMySQLColumns produces:
//   - ColumnType holds udt_name (e.g. "varchar", "int4");
//   - DefaultValue is "" when the column has no default;
//   - ColumnKey is "PRI" for primary-key columns, else "";
//   - Extra is always "";
//   - ColumnComment comes from col_description ("" when unset).
//
// Rows that fail to scan are logged and skipped. Query and iteration errors
// are returned. On success the slice is non-nil, even when empty.
func (s *DatabaseService) getPostgresColumns(db *sql.DB, tableName string) ([]model.ColumnInfo, error) {
	rows, err := db.Query(`
		SELECT
			c.column_name,
			c.data_type,
			c.udt_name,
			c.is_nullable,
			COALESCE(c.column_default, ''),
			CASE WHEN EXISTS (
				SELECT 1
				FROM information_schema.table_constraints tc
				JOIN information_schema.key_column_usage k
				  ON k.constraint_schema = tc.constraint_schema
				 AND k.constraint_name = tc.constraint_name
				WHERE tc.constraint_type = 'PRIMARY KEY'
				  AND tc.table_schema = c.table_schema
				  AND tc.table_name = c.table_name
				  AND k.column_name = c.column_name
			) THEN 'PRI' ELSE '' END,
			''::text,
			COALESCE(col_description(format('%I.%I', c.table_schema, c.table_name)::regclass,
			                         c.ordinal_position::int), '')
		FROM information_schema.columns c
		WHERE c.table_name = $1 AND c.table_schema = 'public'
		ORDER BY c.ordinal_position`, tableName)
	if err != nil {
		log.Printf("[getPostgresColumns] 查询列信息失败: %v", err)
		return nil, err
	}
	defer rows.Close()

	columns := make([]model.ColumnInfo, 0)
	for rows.Next() {
		var col model.ColumnInfo
		if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType,
			&col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil {
			log.Printf("[getPostgresColumns] Scan 失败: %v", err)
			continue
		}
		columns = append(columns, col)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return columns, nil
}
// Create validates connectivity for the requested database, persists the
// record, and then persists any sub-table configurations in req.SubTables.
//
// The connectivity check runs first. If it fails, an error wrapping the
// TestConnection error is returned and nothing is saved. DBType is stored
// lower-cased, and Charset defaults to utf8mb4.
//
// Sub-table persistence is best-effort. Failures to open a connection for
// DDL lookup, to read a table's DDL, or to save an individual sub-table are
// logged and skipped. When sub-tables were supplied, they are also mirrored
// to a JSON file via syncSubTablesToFile.
func (s *DatabaseService) Create(req model.CreateDatabaseRequest) (*model.DatabaseInfo, error) {
	// The request carries the plaintext password; mask it before logging.
	logged := req
	if logged.Password != "" {
		logged.Password = "******"
	}
	log.Printf("[Create] 收到创建请求: %+v", logged)

	info := &model.DatabaseInfo{
		ID:          uuid.New().String(),
		Name:        req.Name,
		Description: req.Description,
		DBType:      strings.ToLower(req.DBType),
		Host:        req.Host,
		Port:        req.Port,
		Username:    req.Username,
		Password:    req.Password,
		Database:    req.Database,
		Charset:     req.Charset,
		SSLMode:     req.SSLMode,
		TableCount:  len(req.SubTables),
	}
	if info.Charset == "" {
		info.Charset = "utf8mb4"
	}

	log.Printf("[Create] 开始测试数据库连接...")
	if err := s.TestConnection(info); err != nil {
		log.Printf("[Create] 数据库连接测试失败: %v", err)
		return nil, fmt.Errorf("database connection failed: %w", err)
	}
	log.Printf("[Create] 数据库连接测试成功!")

	if err := s.repo.Create(info); err != nil {
		log.Printf("[Create] 保存数据库失败: %v", err)
		return nil, err
	}

	if len(req.SubTables) > 0 && s.subTableRepo != nil {
		log.Printf("[Create] 保存 %d 个子表配置", len(req.SubTables))

		// The connection is only used to read each parent table's DDL; without
		// it the sub-tables are still saved, just with an empty DDL.
		db, err := s.getConnection(info)
		if err != nil {
			log.Printf("[Create] 获取数据库连接失败: %v", err)
			db = nil
		} else {
			defer db.Close()
		}

		for _, subReq := range req.SubTables {
			subTable := &model.SubTableInfo{
				ID:              uuid.New().String(),
				DatabaseID:      info.ID,
				ParentTable:     subReq.ParentTable,
				SubTableName:    subReq.SubTableName,
				SubTableComment: subReq.SubTableComment,
				MappingType:     subReq.MappingType,
				RelationField:   subReq.RelationField,
				RelationType:    subReq.RelationType,
			}
			subTable.SetFields(subReq.Fields)

			if db != nil {
				ddl, err := s.getTableDDL(db, info.DBType, subReq.ParentTable)
				switch {
				case err != nil:
					log.Printf("[Create] 获取原始 DDL 失败: %v", err)
				case len(subReq.Fields) > 0:
					// Annotate mapped columns with COMMENT clauses.
					subTable.DDL = s.buildMappedDDL(ddl, subReq.Fields)
					log.Printf("[Create] 生成映射后的 DDL长度: %d", len(subTable.DDL))
				default:
					subTable.DDL = ddl
				}
			}

			if err := s.subTableRepo.Create(subTable); err != nil {
				log.Printf("[Create] 保存子表失败: %v", err)
			}
		}

		s.syncSubTablesToFile(info)
	}

	log.Printf("[Create] 创建成功, ID=%s", info.ID)
	return info, nil
}
// syncSubTablesToFile writes the stored sub-table configuration of info to
// resources/db_info/<info.ID>.json (relative to the working directory),
// creating the directory when needed.
//
// It is best-effort and returns nothing: every failure is logged and stops
// the sync. The success message is only logged after the file was written.
// It does nothing when no sub-table repository is configured.
func (s *DatabaseService) syncSubTablesToFile(info *model.DatabaseInfo) {
	if s.subTableRepo == nil {
		return
	}

	tables, err := s.subTableRepo.FindByDatabaseID(info.ID)
	if err != nil {
		log.Printf("[syncSubTablesToFile] 查询子表失败: %v", err)
		return
	}

	mapping := &model.SubTableMapping{
		DatabaseID:   info.ID,
		DatabaseName: info.Name,
		DBType:       info.DBType,
		Tables:       tables,
		UpdatedAt:    time.Now(),
	}
	data, err := json.MarshalIndent(mapping, "", " ")
	if err != nil {
		log.Printf("[syncSubTablesToFile] 序列化失败: %v", err)
		return
	}

	const resourceDir = "resources/db_info"
	if err := os.MkdirAll(resourceDir, 0755); err != nil {
		log.Printf("[syncSubTablesToFile] 创建目录失败: %v", err)
		return
	}

	filePath := fmt.Sprintf("%s/%s.json", resourceDir, info.ID)
	if err := os.WriteFile(filePath, data, 0644); err != nil {
		log.Printf("[syncSubTablesToFile] 写入文件失败: %v", err)
		return
	}
	log.Printf("[syncSubTablesToFile] 同步成功: %s", filePath)
}
// GetByID loads the stored database record with the given id.
// Any repository failure is logged and reported as ErrDatabaseNotFound.
func (s *DatabaseService) GetByID(id string) (*model.DatabaseInfo, error) {
	log.Printf("[GetByID] 查询 ID=%s", id)

	found, lookupErr := s.repo.FindByID(id)
	if lookupErr == nil {
		return found, nil
	}

	log.Printf("[GetByID] 查询失败: %v", lookupErr)
	return nil, ErrDatabaseNotFound
}
// List returns every stored database record, passing the repository's
// result and error through unchanged.
func (s *DatabaseService) List() ([]model.DatabaseInfo, error) {
	log.Printf("[List] 查询所有数据库列表")
	all, err := s.repo.FindAll()
	return all, err
}
// Update applies the non-empty fields of req to the stored database record id
// and reconciles that database's sub-tables with req.SubTables.
//
//   - Returns ErrDatabaseNotFound if no record exists for id.
//   - Only non-empty string fields and positive numeric fields are written
//     (see databaseUpdateFields); db_type is lower-cased as Create does.
//   - Sub-tables are matched by ParentTable. Matches are updated, unmatched
//     requested entries are created, and stored sub-tables missing from
//     req.SubTables are deleted, so an empty list removes them all.
//   - table_count is then set to the number of stored sub-tables, and the
//     sub-table JSON mirror is rewritten when sub-tables were involved.
//
// Sub-table failures are logged and do not fail the update. All sub-table
// work is skipped when no sub-table repository is configured.
func (s *DatabaseService) Update(id string, req model.UpdateDatabaseRequest) (*model.DatabaseInfo, error) {
	// The request may carry a plaintext password; mask it before logging.
	logged := req
	if logged.Password != "" {
		logged.Password = "******"
	}
	log.Printf("[Update] 更新 ID=%s, 数据=%+v", id, logged)

	if _, err := s.repo.FindByID(id); err != nil {
		log.Printf("[Update] 不存在: %v", err)
		return nil, ErrDatabaseNotFound
	}

	if updates := databaseUpdateFields(req); len(updates) > 0 {
		if err := s.repo.UpdateFields(id, updates); err != nil {
			log.Printf("[Update] 更新失败: %v", err)
			return nil, err
		}
	}

	touchedSubTables := false
	if s.subTableRepo != nil {
		touchedSubTables = s.reconcileSubTables(id, req)

		// table_count always reflects what is actually stored.
		if allSubTables, err := s.subTableRepo.FindByDatabaseID(id); err != nil {
			log.Printf("[Update] 查询子表数量失败: %v", err)
		} else {
			tableCount := len(allSubTables)
			log.Printf("[Update] 更新 table_count 为: %d", tableCount)
			if err := s.repo.UpdateFields(id, map[string]interface{}{"table_count": tableCount}); err != nil {
				log.Printf("[Update] 更新 table_count 失败: %v", err)
			}
		}
	}

	updated, err := s.repo.FindByID(id)
	if err != nil {
		return nil, err
	}
	if touchedSubTables {
		s.syncSubTablesToFile(updated)
	}
	return updated, nil
}

// databaseUpdateFields maps the non-empty (strings) or positive (numbers)
// fields of req to their column names for a partial update.
func databaseUpdateFields(req model.UpdateDatabaseRequest) map[string]interface{} {
	updates := map[string]interface{}{}
	if req.Name != "" {
		updates["name"] = req.Name
	}
	if req.Description != "" {
		updates["description"] = req.Description
	}
	if req.DBType != "" {
		updates["db_type"] = strings.ToLower(req.DBType) // same normalization as Create
	}
	if req.Host != "" {
		updates["host"] = req.Host
	}
	if req.Port > 0 {
		updates["port"] = req.Port
	}
	if req.Username != "" {
		updates["username"] = req.Username
	}
	if req.Password != "" {
		updates["password"] = req.Password
	}
	if req.Database != "" {
		updates["database"] = req.Database
	}
	if req.TableCount > 0 {
		updates["table_count"] = req.TableCount
	}
	if req.Charset != "" {
		updates["charset"] = req.Charset
	}
	if req.SSLMode != "" {
		updates["ssl_mode"] = req.SSLMode
	}
	return updates
}

// reconcileSubTables makes the stored sub-tables of database id match
// req.SubTables, keyed by ParentTable. It reports whether any sub-tables were
// involved (stored or requested). If the stored list cannot be loaded, it
// changes nothing and returns false, because creating entries blindly would
// duplicate existing rows.
func (s *DatabaseService) reconcileSubTables(id string, req model.UpdateDatabaseRequest) bool {
	existingTables, err := s.subTableRepo.FindByDatabaseID(id)
	if err != nil {
		log.Printf("[Update] 查询子表失败: %v", err)
		return false
	}

	// parent table -> stored sub-table ID (the first match wins on duplicates)
	idByParent := make(map[string]string, len(existingTables))
	for _, existing := range existingTables {
		if _, seen := idByParent[existing.ParentTable]; !seen {
			idByParent[existing.ParentTable] = existing.ID
		}
	}

	if len(req.SubTables) > 0 {
		log.Printf("[Update] 处理 %d 个子表配置", len(req.SubTables))
	}
	requested := make(map[string]bool, len(req.SubTables))
	for _, subTableReq := range req.SubTables {
		requested[subTableReq.ParentTable] = true

		if existingID, ok := idByParent[subTableReq.ParentTable]; ok {
			log.Printf("[Update] 更新子表: %s", existingID)
			if err := s.subTableRepo.Update(existingID, &model.SubTableInfo{
				ParentTable:     subTableReq.ParentTable,
				SubTableName:    subTableReq.SubTableName,
				SubTableComment: subTableReq.SubTableComment,
				MappingType:     subTableReq.MappingType,
				RelationField:   subTableReq.RelationField,
				RelationType:    subTableReq.RelationType,
				DDL:             subTableReq.DDL,
			}); err != nil {
				log.Printf("[Update] 更新子表失败: %v", err)
			}
			continue
		}

		log.Printf("[Update] 创建子表: %s", subTableReq.ParentTable)
		newID := uuid.New().String()
		if err := s.subTableRepo.Create(&model.SubTableInfo{
			ID:              newID,
			DatabaseID:      id,
			ParentTable:     subTableReq.ParentTable,
			SubTableName:    subTableReq.SubTableName,
			SubTableComment: subTableReq.SubTableComment,
			MappingType:     subTableReq.MappingType,
			RelationField:   subTableReq.RelationField,
			RelationType:    subTableReq.RelationType,
			DDL:             subTableReq.DDL,
		}); err != nil {
			log.Printf("[Update] 创建子表失败: %v", err)
			continue
		}
		// A repeated ParentTable later in the request updates this row
		// instead of creating a duplicate.
		idByParent[subTableReq.ParentTable] = newID
	}

	// Delete stored sub-tables whose parent table is no longer requested.
	for _, existing := range existingTables {
		if requested[existing.ParentTable] {
			continue
		}
		log.Printf("[Update] 删除子表: %s (不在新列表中)", existing.ID)
		if err := s.subTableRepo.Delete(existing.ID); err != nil {
			log.Printf("[Update] 删除子表失败: %v", err)
		}
	}

	return len(existingTables) > 0 || len(req.SubTables) > 0
}
// SaveGraph stores the Neo4j graph metadata in req under req.DatabaseID.
//
// If no record exists for req.DatabaseID, a new "neo4j" record is created,
// with host/port parsed from req.URI (defaults: localhost:7687; any scheme
// such as bolt://, neo4j:// or neo4j+s:// is accepted). Otherwise the URI,
// username and graph fields of the existing record are updated. On failure,
// both a response with Success=false and the error are returned.
func (s *DatabaseService) SaveGraph(req model.SaveGraphRequest) (*model.SaveGraphResponse, error) {
	log.Printf("[SaveGraph] 保存图谱信息, databaseId=%s, databaseName=%s", req.DatabaseID, req.DatabaseName)

	// Labels and relationship types are stored as JSON strings.
	labelsJSON, err := json.Marshal(req.Labels)
	if err != nil {
		return &model.SaveGraphResponse{Success: false, Message: fmt.Sprintf("序列化失败: %v", err)}, err
	}
	relJSON, err := json.Marshal(req.RelationshipTypes)
	if err != nil {
		return &model.SaveGraphResponse{Success: false, Message: fmt.Sprintf("序列化失败: %v", err)}, err
	}

	// repo.FindByID cannot return this package's ErrDatabaseNotFound (the
	// repository package is imported by this one), so look the record up via
	// GetByID, which maps lookup failures to that sentinel.
	_, err = s.GetByID(req.DatabaseID)
	switch {
	case errors.Is(err, ErrDatabaseNotFound):
		host, port := "localhost", 7687
		if req.URI != "" {
			addr := req.URI
			if i := strings.Index(addr, "://"); i >= 0 {
				addr = addr[i+len("://"):] // drop the scheme
			}
			if i := strings.IndexAny(addr, "/?"); i >= 0 {
				addr = addr[:i] // keep only host[:port]
			}
			if i := strings.LastIndex(addr, ":"); i >= 0 {
				if i > 0 {
					host = addr[:i]
				}
				// A malformed port leaves the default in place.
				fmt.Sscanf(addr[i+1:], "%d", &port)
			} else if addr != "" {
				host = addr
			}
		}

		info := &model.DatabaseInfo{
			ID:                req.DatabaseID,
			Name:              req.DatabaseName,
			DBType:            "neo4j",
			Host:              host,
			Port:              port,
			Username:          req.Username,
			URI:               req.URI,
			GraphLabels:       string(labelsJSON),
			GraphRelationship: string(relJSON),
			SelectedLabel:     req.SelectedLabel,
		}
		if err := s.repo.Create(info); err != nil {
			log.Printf("[SaveGraph] 创建失败: %v", err)
			return &model.SaveGraphResponse{
				Success: false,
				Message: fmt.Sprintf("创建失败: %v", err),
			}, err
		}
		return &model.SaveGraphResponse{
			Success: true,
			Message: "保存成功",
		}, nil
	case err != nil:
		log.Printf("[SaveGraph] 查询失败: %v", err)
		return &model.SaveGraphResponse{
			Success: false,
			Message: fmt.Sprintf("查询失败: %v", err),
		}, err
	}

	updates := map[string]interface{}{
		"uri":                req.URI,
		"username":           req.Username,
		"graph_labels":       string(labelsJSON),
		"graph_relationship": string(relJSON),
		"selected_label":     req.SelectedLabel,
	}
	if err := s.repo.UpdateFields(req.DatabaseID, updates); err != nil {
		log.Printf("[SaveGraph] 更新失败: %v", err)
		return &model.SaveGraphResponse{
			Success: false,
			Message: fmt.Sprintf("更新失败: %v", err),
		}, err
	}
	return &model.SaveGraphResponse{
		Success: true,
		Message: "保存成功",
	}, nil
}
// fillFieldMappings copies the saved column-to-mapped-name pairs of
// databaseID onto the matching columns of tables, matching by table name and
// then column name. tables is modified in place. When several saved
// sub-tables with fields share a parent table, the last one returned by the
// repository wins. A repository failure is logged and leaves tables untouched.
func (s *DatabaseService) fillFieldMappings(databaseID string, tables []model.TableDDLInfo) {
	subTables, err := s.subTableRepo.FindByDatabaseID(databaseID)
	if err != nil {
		log.Printf("[fillFieldMappings] 查询子表失败: %v", err)
		return
	}

	// parent table -> (column name -> mapped name)
	mappedByTable := map[string]map[string]string{}
	for _, st := range subTables {
		fields := st.GetFields()
		if len(fields) == 0 {
			continue
		}
		cols := make(map[string]string, len(fields))
		for _, f := range fields {
			cols[f.ColumnName] = f.MappedName
		}
		mappedByTable[st.ParentTable] = cols
	}

	for i := range tables {
		cols, ok := mappedByTable[tables[i].TableName]
		if !ok {
			continue
		}
		for j := range tables[i].Columns {
			if name, found := cols[tables[i].Columns[j].ColumnName]; found {
				tables[i].Columns[j].MappedName = name
			}
		}
	}
	log.Printf("[fillFieldMappings] 已填充字段映射到 %d 个表", len(tables))
}
// fillDDL replaces the DDL of each table in tables with the non-empty DDL
// saved for a sub-table of the same parent table under databaseID. tables is
// modified in place. When several saved sub-tables share a parent table, the
// last one returned by the repository wins. A repository failure is logged
// and leaves tables untouched.
func (s *DatabaseService) fillDDL(databaseID string, tables []model.TableDDLInfo) {
	subTables, err := s.subTableRepo.FindByDatabaseID(databaseID)
	if err != nil {
		log.Printf("[fillDDL] 查询子表失败: %v", err)
		return
	}

	savedDDL := make(map[string]string, len(subTables))
	for _, st := range subTables {
		if st.DDL == "" {
			continue
		}
		savedDDL[st.ParentTable] = st.DDL
	}

	for i := range tables {
		if ddl, ok := savedDDL[tables[i].TableName]; ok {
			tables[i].DDL = ddl
		}
	}
	log.Printf("[fillDDL] 已填充 DDL 到 %d 个表", len(tables))
}
// Delete removes the database record id, its sub-table records, and the
// sub-table JSON mirror file written by syncSubTablesToFile.
//
// Returns ErrDatabaseNotFound when the record does not exist. A failure to
// delete the sub-tables is logged and the record is still deleted. Removing
// the mirror file is best-effort and only logged; a missing file is not an
// error.
func (s *DatabaseService) Delete(id string) error {
	log.Printf("[Delete] 删除 ID=%s", id)
	if _, err := s.repo.FindByID(id); err != nil {
		log.Printf("[Delete] 不存在: %v", err)
		return ErrDatabaseNotFound
	}

	if s.subTableRepo != nil {
		// Orphaned sub-tables must not block deleting the record itself.
		if err := s.subTableRepo.DeleteByDatabaseID(id); err != nil {
			log.Printf("[Delete] 删除子表失败: %v", err)
		}
	}

	if err := s.repo.Delete(id); err != nil {
		return err
	}

	// IDs containing path separators are skipped so that a crafted ID cannot
	// point outside resources/db_info.
	if !strings.ContainsAny(id, `/\`) {
		filePath := fmt.Sprintf("%s/%s.json", "resources/db_info", id)
		if err := os.Remove(filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
			log.Printf("[Delete] 删除子表文件失败: %v", err)
		}
	}
	return nil
}