package service import ( "database/sql" "encoding/json" "fmt" "log" "os" "path/filepath" "strings" "time" "x-agents/server/internal/model" "x-agents/server/internal/repository" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" "github.com/google/uuid" ) type SubTableService struct { repo *repository.SubTableRepository dbRepo *repository.DatabaseRepository resourceDir string } func NewSubTableService(repo *repository.SubTableRepository, dbRepo *repository.DatabaseRepository) *SubTableService { return &SubTableService{ repo: repo, dbRepo: dbRepo, resourceDir: "resources/db_info", } } // ensureDir 确保目录存在 func (s *SubTableService) ensureDir() error { return os.MkdirAll(s.resourceDir, 0755) } // getFilePath 获取文件路径 func (s *SubTableService) getFilePath(databaseID string) string { return filepath.Join(s.resourceDir, fmt.Sprintf("%s.json", databaseID)) } // saveToFile 保存到文件 func (s *SubTableService) saveToFile(databaseID string, mapping *model.SubTableMapping) error { if err := s.ensureDir(); err != nil { return err } data, err := json.MarshalIndent(mapping, "", " ") if err != nil { return err } return os.WriteFile(s.getFilePath(databaseID), data, 0644) } // loadFromFile 从文件加载 func (s *SubTableService) loadFromFile(databaseID string) (*model.SubTableMapping, error) { filePath := s.getFilePath(databaseID) data, err := os.ReadFile(filePath) if err != nil { if os.IsNotExist(err) { return nil, nil } return nil, err } var mapping model.SubTableMapping if err := json.Unmarshal(data, &mapping); err != nil { return nil, err } return &mapping, nil } // syncToFile 同步到文件 func (s *SubTableService) syncToFile(databaseID string) error { // 获取数据库信息 dbInfo, err := s.dbRepo.FindByID(databaseID) if err != nil { return err } // 获取所有子表信息 tables, err := s.repo.FindByDatabaseID(databaseID) if err != nil { return err } mapping := &model.SubTableMapping{ DatabaseID: databaseID, DatabaseName: dbInfo.Name, DBType: dbInfo.DBType, Tables: tables, UpdatedAt: time.Now(), } return s.saveToFile(databaseID, mapping) } // Create 创建子表信息 func (s *SubTableService) Create(req model.CreateSubTableRequest) (*model.SubTableInfo, error) { log.Printf("[SubTable Create] 收到请求: %+v", req) // 验证数据库是否存在 _, err := s.dbRepo.FindByID(req.DatabaseID) if err != nil { log.Printf("[SubTable Create] 数据库不存在: %v", err) return nil, fmt.Errorf("database not found") } info := &model.SubTableInfo{ ID: uuid.New().String(), DatabaseID: req.DatabaseID, ParentTable: req.ParentTable, SubTableName: req.SubTableName, SubTableComment: req.SubTableComment, MappingType: req.MappingType, RelationField: req.RelationField, RelationType: req.RelationType, } if err := s.repo.Create(info); err != nil { log.Printf("[SubTable Create] 创建失败: %v", err) return nil, err } // 同步到文件 if err := s.syncToFile(req.DatabaseID); err != nil { log.Printf("[SubTable Create] 同步文件失败: %v", err) } log.Printf("[SubTable Create] 创建成功, ID=%s", info.ID) return info, nil } // GetByID 获取详情 func (s *SubTableService) GetByID(id string) (*model.SubTableInfo, error) { log.Printf("[SubTable GetByID] 查询 ID=%s", id) info, err := s.repo.FindByID(id) if err != nil { log.Printf("[SubTable GetByID] 查询失败: %v", err) return nil, fmt.Errorf("sub table not found") } return info, nil } // ListByDatabaseID 获取数据库下所有子表 func (s *SubTableService) ListByDatabaseID(databaseID string) ([]model.SubTableInfo, error) { log.Printf("[SubTable ListByDatabaseID] 查询数据库ID=%s", databaseID) tables, err := s.repo.FindByDatabaseID(databaseID) if err != nil { return nil, err } // 填充 FieldsList 字段 for i := range tables { tables[i].FieldsList = tables[i].GetFields() } return tables, nil } // GetMappingFromFile 从文件获取映射信息 func (s *SubTableService) GetMappingFromFile(databaseID string) (*model.SubTableMapping, error) { log.Printf("[SubTable GetMappingFromFile] 读取文件, databaseID=%s", databaseID) return s.loadFromFile(databaseID) } // Update 更新 func (s *SubTableService) Update(id string, req model.UpdateSubTableRequest) (*model.SubTableInfo, error) { log.Printf("[SubTable Update] 更新 ID=%s, 数据=%+v", id, req) info, err := s.repo.FindByID(id) if err != nil { log.Printf("[SubTable Update] 不存在: %v", err) return nil, fmt.Errorf("sub table not found") } if req.ParentTable != "" { info.ParentTable = req.ParentTable } if req.SubTableName != "" { info.SubTableName = req.SubTableName } if req.SubTableComment != "" { info.SubTableComment = req.SubTableComment } if req.MappingType != "" { info.MappingType = req.MappingType } if req.RelationField != "" { info.RelationField = req.RelationField } if req.RelationType != "" { info.RelationType = req.RelationType } if err := s.repo.Update(id, info); err != nil { log.Printf("[SubTable Update] 更新失败: %v", err) return nil, err } // 同步到文件 if err := s.syncToFile(info.DatabaseID); err != nil { log.Printf("[SubTable Update] 同步文件失败: %v", err) } return info, nil } // Delete 删除 func (s *SubTableService) Delete(id string) error { log.Printf("[SubTable Delete] 删除 ID=%s", id) info, err := s.repo.FindByID(id) if err != nil { log.Printf("[SubTable Delete] 不存在: %v", err) return fmt.Errorf("sub table not found") } databaseID := info.DatabaseID if err := s.repo.Delete(id); err != nil { log.Printf("[SubTable Delete] 删除失败: %v", err) return err } // 同步到文件 if err := s.syncToFile(databaseID); err != nil { log.Printf("[SubTable Delete] 同步文件失败: %v", err) } return nil } // GetTableDDLFromDatabase 从实际数据库获取表结构和DDL func (s *SubTableService) GetTableDDLFromDatabase(databaseID string) ([]model.TableDDLInfo, error) { log.Printf("[GetTableDDLFromDatabase] 获取数据库ID=%s 的表结构", databaseID) // 获取数据库连接信息 dbInfo, err := s.dbRepo.FindByID(databaseID) if err != nil { log.Printf("[GetTableDDLFromDatabase] 数据库不存在: %v", err) return nil, fmt.Errorf("database not found") } // 构建连接 dsn := s.buildDSN(dbInfo) dbType := strings.ToLower(dbInfo.DBType) var db *sql.DB switch dbType { case "mysql": db, err = sql.Open("mysql", dsn) case "postgres", "postgresql": db, err = sql.Open("postgres", dsn) default: return nil, fmt.Errorf("unsupported database type: %s", dbInfo.DBType) } if err != nil { return nil, fmt.Errorf("failed to connect: %v", err) } defer db.Close() // 获取所有表 var tables []model.TableDDLInfo switch dbType { case "mysql": tables, err = s.getMySQLTables(db, dbInfo.Database) case "postgres", "postgresql": tables, err = s.getPostgresTables(db, dbInfo.Database) } if err != nil { return nil, err } log.Printf("[GetTableDDLFromDatabase] 获取到 %d 个表", len(tables)) return tables, nil } // buildDSN 构建数据库连接字符串 func (s *SubTableService) buildDSN(info *model.DatabaseInfo) string { dbType := strings.ToLower(info.DBType) switch dbType { case "mysql": charset := info.Charset if charset == "" { charset = "utf8mb4" } dbName := info.Database if dbName == "" { dbName = "mysql" } return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&timeout=5s&parseTime=True", info.Username, info.Password, info.Host, info.Port, dbName, charset) case "postgres", "postgresql": sslmode := "disable" if info.SSLMode != "" { sslmode = info.SSLMode } return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5", info.Host, info.Port, info.Username, info.Password, info.Database, sslmode) } return "" } // getMySQLTables 获取MySQL表结构 func (s *SubTableService) getMySQLTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) { // 获取所有表名和注释 rows, err := db.Query(` SELECT TABLE_NAME, TABLE_COMMENT FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_TYPE = 'BASE TABLE' `, dbName) if err != nil { return nil, err } defer rows.Close() var tables []model.TableDDLInfo for rows.Next() { var tableName, tableComment string if err := rows.Scan(&tableName, &tableComment); err != nil { continue } table := model.TableDDLInfo{ TableName: tableName, TableComment: tableComment, } // 获取列信息 table.Columns, _ = s.getMySQLColumns(db, dbName, tableName) // 获取索引信息 table.Indexes, _ = s.getMySQLIndexes(db, dbName, tableName) // 生成DDL table.DDL = s.generateMySQLDDL(table) tables = append(tables, table) } return tables, nil } // getMySQLColumns 获取MySQL列信息 func (s *SubTableService) getMySQLColumns(db *sql.DB, dbName, tableName string) ([]model.ColumnInfo, error) { rows, err := db.Query(` SELECT COLUMN_NAME, DATA_TYPE, COLUMN_TYPE, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY, EXTRA, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION `, dbName, tableName) if err != nil { return nil, err } defer rows.Close() var columns []model.ColumnInfo for rows.Next() { var col model.ColumnInfo if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType, &col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil { continue } columns = append(columns, col) } return columns, nil } // getMySQLIndexes 获取MySQL索引信息 func (s *SubTableService) getMySQLIndexes(db *sql.DB, dbName, tableName string) ([]model.IndexInfo, error) { rows, err := db.Query(` SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE, INDEX_TYPE FROM information_schema.STATISTICS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY SEQ_IN_INDEX `, dbName, tableName) if err != nil { return nil, err } defer rows.Close() var indexes []model.IndexInfo for rows.Next() { var idx model.IndexInfo if err := rows.Scan(&idx.IndexName, &idx.ColumnName, &idx.NonUnique, &idx.IndexType); err != nil { continue } indexes = append(indexes, idx) } return indexes, nil } // generateMySQLDDL 生成MySQL DDL func (s *SubTableService) generateMySQLDDL(table model.TableDDLInfo) string { var sb strings.Builder sb.WriteString(fmt.Sprintf("CREATE TABLE `%s` (\n", table.TableName)) for i, col := range table.Columns { sb.WriteString(fmt.Sprintf(" `%s` %s", col.ColumnName, col.ColumnType)) if col.IsNullable == "NO" { sb.WriteString(" NOT NULL") } if col.DefaultValue != "" { sb.WriteString(fmt.Sprintf(" DEFAULT %s", col.DefaultValue)) } if col.Extra == "auto_increment" { sb.WriteString(" AUTO_INCREMENT") } if col.ColumnComment != "" { sb.WriteString(fmt.Sprintf(" COMMENT '%s'", col.ColumnComment)) } if i < len(table.Columns)-1 { sb.WriteString(",") } sb.WriteString("\n") } // 添加主键 var primaryKeys []string for _, idx := range table.Indexes { if idx.IndexName == "PRIMARY" { primaryKeys = append(primaryKeys, fmt.Sprintf("`%s`", idx.ColumnName)) } } if len(primaryKeys) > 0 { sb.WriteString(fmt.Sprintf(" PRIMARY KEY (%s)\n", strings.Join(primaryKeys, ", "))) } // 添加索引 var addedIndexes []string for _, idx := range table.Indexes { if idx.IndexName != "PRIMARY" { unique := "" if idx.NonUnique == 0 { unique = "UNIQUE " } if !contains(addedIndexes, idx.IndexName) { sb.WriteString(fmt.Sprintf(" %sKEY `%s` (`%s`),\n", unique, idx.IndexName, idx.ColumnName)) addedIndexes = append(addedIndexes, idx.IndexName) } } } ddl := sb.String() ddl = strings.TrimSuffix(ddl, ",\n") ddl += "\n)" if table.TableComment != "" { ddl += fmt.Sprintf(" COMMENT='%s'", table.TableComment) } ddl += ";\n" return ddl } // contains 检查切片是否包含元素 func contains(slice []string, item string) bool { for _, s := range slice { if s == item { return true } } return false } // getPostgresTables 获取PostgreSQL表结构 func (s *SubTableService) getPostgresTables(db *sql.DB, dbName string) ([]model.TableDDLInfo, error) { rows, err := db.Query(` SELECT t.table_name, obj_description((t.table_schema || '.' || t.table_name)::regclass) FROM information_schema.tables t WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE' `, dbName) if err != nil { return nil, err } defer rows.Close() var tables []model.TableDDLInfo for rows.Next() { var tableName, tableComment string if err := rows.Scan(&tableName, &tableComment); err != nil { continue } table := model.TableDDLInfo{ TableName: tableName, TableComment: tableComment, } // 获取列信息 table.Columns, _ = s.getPostgresColumns(db, tableName) // 获取索引信息 table.Indexes, _ = s.getPostgresIndexes(db, tableName) // 生成DDL table.DDL = s.generatePostgresDDL(table) tables = append(tables, table) } return tables, nil } // getPostgresColumns 获取PostgreSQL列信息 func (s *SubTableService) getPostgresColumns(db *sql.DB, tableName string) ([]model.ColumnInfo, error) { rows, err := db.Query(` SELECT c.column_name, c.data_type, c.udt_name, c.is_nullable, c.column_default, c.column_name, '', c.column_comment FROM information_schema.columns c WHERE c.table_name = $1 AND c.table_schema = 'public' ORDER BY c.ordinal_position `, tableName) if err != nil { return nil, err } defer rows.Close() var columns []model.ColumnInfo for rows.Next() { var col model.ColumnInfo if err := rows.Scan(&col.ColumnName, &col.DataType, &col.ColumnType, &col.IsNullable, &col.DefaultValue, &col.ColumnKey, &col.Extra, &col.ColumnComment); err != nil { continue } columns = append(columns, col) } return columns, nil } // getPostgresIndexes 获取PostgreSQL索引信息 func (s *SubTableService) getPostgresIndexes(db *sql.DB, tableName string) ([]model.IndexInfo, error) { rows, err := db.Query(` SELECT indexname, indexdef FROM pg_indexes WHERE tablename = $1 AND schemaname = 'public' `, tableName) if err != nil { return nil, err } defer rows.Close() var indexes []model.IndexInfo for rows.Next() { var idx model.IndexInfo var indexDef string if err := rows.Scan(&idx.IndexName, &indexDef); err != nil { continue } idx.NonUnique = 1 if strings.Contains(indexDef, "UNIQUE") { idx.NonUnique = 0 } indexes = append(indexes, idx) } return indexes, nil } // generatePostgresDDL 生成PostgreSQL DDL func (s *SubTableService) generatePostgresDDL(table model.TableDDLInfo) string { var sb strings.Builder sb.WriteString(fmt.Sprintf("CREATE TABLE %s (\n", table.TableName)) for i, col := range table.Columns { sb.WriteString(fmt.Sprintf(" %s %s", col.ColumnName, col.ColumnType)) if col.IsNullable == "NO" { sb.WriteString(" NOT NULL") } if col.DefaultValue != "" { sb.WriteString(fmt.Sprintf(" DEFAULT %s", col.DefaultValue)) } if i < len(table.Columns)-1 { sb.WriteString(",") } sb.WriteString("\n") } ddl := sb.String() ddl = strings.TrimSuffix(ddl, ",\n") ddl += "\n);" return ddl }