From e494c4ce50c42478e53ae9f20c84b3f6b84bd806 Mon Sep 17 00:00:00 2001 From: "WIN-JHFT4D3SIVT\\caoxiaozhu" Date: Thu, 29 Jan 2026 15:51:45 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E4=BF=AE=E6=94=B9=E4=BA=86=E4=B8=80?= =?UTF-8?q?=E4=BA=9Bbug=202.=20=E5=81=9A=E4=BA=86=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E8=B0=83=E6=95=B4=EF=BC=8C=E6=AF=94=E5=A6=82=E5=90=AF=E5=8A=A8?= =?UTF-8?q?=E8=84=9A=E6=9C=AC=EF=BC=8C=E6=94=AF=E6=8C=81=E4=BA=86tenmsorbo?= =?UTF-8?q?ard?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/fine_tune.py | 27 +- src/api/model_manage.py | 91 +++-- src/main.py | 1 + start_all.sh | 49 ++- web/pages/components/sidebar.html | 2 +- web/pages/custom-tool-create.html | 31 +- web/pages/dataset-create.html | 30 +- web/pages/dataset-preview.html | 28 +- web/pages/fine-tune-create.html | 39 +- web/pages/login.html | 28 +- web/pages/main.html | 30 +- web/pages/model-compare-chat.html | 30 +- web/pages/model-compare-create.html | 28 +- web/pages/model-compare-result.html | 30 +- web/pages/model-dimension-create.html | 31 +- web/pages/model-eval-create.html | 30 +- web/pages/model-eval.html | 28 +- web/pages/model-manage-create.html | 30 +- web/pages/model-manage.html | 226 ++++++++++-- web/pages/training-log.html | 493 ++++++++++++++++++++------ 20 files changed, 995 insertions(+), 287 deletions(-) diff --git a/src/api/fine_tune.py b/src/api/fine_tune.py index dc88c29..465051b 100644 --- a/src/api/fine_tune.py +++ b/src/api/fine_tune.py @@ -167,7 +167,9 @@ def start_training(): train_method = data.get('train_method', 'lora') # 创建输出目录(如果不存在) - output_model_name = data.get('output_model_name', f"{template}/{train_method}") + # 路径格式: /app/base/saves/{train_method}/{output_model_name} + output_model_name = data.get('output_model_name', template) + output_model_name = f"{train_method}/{output_model_name}" if not output_model_name.startswith('/'): output_model_name = f"/app/base/saves/{output_model_name}" output_dir = output_model_name @@ -278,7 +280,9 @@ def build_train_command(data, model_path, dataset_name=None): cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')]) # 输出目录(确保是绝对路径) - output_model_name = data.get('output_model_name', f"{template}/{train_method}") + # 路径格式: /app/base/saves/{train_method}/{output_model_name} + output_model_name = data.get('output_model_name', template) + output_model_name = f"{train_method}/{output_model_name}" if not output_model_name.startswith('/'): output_model_name = f"/app/base/saves/{output_model_name}" output_dir = output_model_name @@ -417,7 +421,12 @@ def delete_training_task(task_id): # 获取任务信息(用于删除日志文件) conn = get_db_connection() cursor = conn.cursor() - cursor.execute("SELECT name, process_id FROM fine_tune WHERE id = %s", (task_id,)) + # 尝试获取所有字段,如果tensorboard_log_dir不存在会报错 + try: + cursor.execute("SELECT name, process_id, tensorboard_log_dir FROM fine_tune WHERE id = %s", (task_id,)) + except: + # 如果列不存在,只获取基本字段 + cursor.execute("SELECT name, process_id FROM fine_tune WHERE id = %s", (task_id,)) task_result = cursor.fetchone() conn.close() @@ -425,6 +434,7 @@ def delete_training_task(task_id): return jsonify({'code': 1, 'message': '任务不存在'}) task_name = task_result.get('name', 'unknown') + tensorboard_log_dir = task_result.get('tensorboard_log_dir', '') if 'tensorboard_log_dir' in task_result else '' # 删除日志文件 (logs/{日期}/{task_id}_{task_name}.log) try: @@ -444,6 +454,17 @@ def delete_training_task(task_id): except Exception as log_err: logger.warning(f"删除日志文件失败: {log_err}") + # 删除 TensorBoard 进程(如果存在) + global tensorboard_process + if tensorboard_process and tensorboard_process.poll() is None: + try: + import signal + os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM) + tensorboard_process = None + logger.info(f"已停止 TensorBoard 进程") + except Exception as tb_err: + logger.warning(f"停止 TensorBoard 失败: {tb_err}") + # 删除数据库中的任务记录 conn = get_db_connection() cursor = conn.cursor() diff --git a/src/api/model_manage.py b/src/api/model_manage.py index ac3fe62..3d7fb1e 100644 --- a/src/api/model_manage.py +++ b/src/api/model_manage.py @@ -204,44 +204,67 @@ def get_trained_models(): logger = logging.getLogger(__name__) try: - # 使用 /app/base/saves 目录(容器内路径) - saves_base_path = '/app/base/saves' - # 本地开发时的备用路径 - local_saves_path = os.path.join(PROJECT_ROOT, 'saves') + # 多个可能的路径 + potential_paths = [ + '/app/base/saves', # 容器内路径 + os.path.join(PROJECT_ROOT, 'saves'), # 本地开发路径 + os.path.join(os.path.dirname(os.path.dirname(PROJECT_ROOT)), 'YG_FT_Base', 'saves'), # 上级目录 + ] - # 选择存在的路径 - base_path = saves_base_path if os.path.exists(saves_base_path) else local_saves_path + base_path = None + for path in potential_paths: + logger.info(f"[DEBUG] 检查路径: {path}, exists: {os.path.exists(path)}") + if os.path.exists(path): + base_path = path + break - logger.info(f"[DEBUG] 已训练模型目录: {base_path}, exists: {os.path.exists(base_path)}") + logger.info(f"[DEBUG] 最终使用的路径: {base_path}") models = [] - if os.path.exists(base_path): - for item in os.listdir(base_path): - item_path = os.path.join(base_path, item) - if os.path.isdir(item_path): - # 检查是否是模板目录(包含训练方法的子目录) - sub_items = [] - if os.path.exists(item_path): - for sub_item in os.listdir(item_path): - sub_path = os.path.join(item_path, sub_item) - if os.path.isdir(sub_path): - # 检查是否包含模型文件(adapter_model.bin 或 pytorch_model.bin 等) - has_model = False - for f in os.listdir(sub_path): - if f.endswith('.bin') or f.endswith('.safetensors'): - has_model = True - break - if has_model: - sub_items.append({ - 'name': sub_item, - 'path': sub_path - }) + if base_path and os.path.exists(base_path): + logger.info(f"[DEBUG] 遍历目录: {base_path}") + try: + # 路径结构: /app/base/saves/{train_method}/{model_name}/ + # train_method: lora, full, qlora, dpo, cpt 等 - models.append({ - 'name': item, - 'path': item_path, - 'train_methods': sub_items - }) + for train_method in os.listdir(base_path): + train_method_path = os.path.join(base_path, train_method) + if not os.path.isdir(train_method_path): + continue + + logger.info(f"[DEBUG] 检查训练方法目录: {train_method}") + model_count = 0 + + # 遍历模型文件夹 + for model_name in os.listdir(train_method_path): + model_path = os.path.join(train_method_path, model_name) + if not os.path.isdir(model_path): + continue + + # 检查是否有模型文件 + try: + files = os.listdir(model_path) + logger.info(f"[DEBUG] {train_method}/{model_name} 文件: {files[:5]}...") + has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files) + + if has_model: + logger.info(f"[DEBUG] 找到模型: {train_method}/{model_name}") + models.append({ + 'name': model_name, + 'path': model_path, + 'train_methods': [{ + 'name': train_method, + 'path': model_path + }] + }) + model_count += 1 + except Exception as file_err: + logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}") + + logger.info(f"[DEBUG] {train_method} 找到 {model_count} 个模型") + + except Exception as list_err: + logger.error(f"[DEBUG] 遍历目录失败: {list_err}") logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型") @@ -249,7 +272,7 @@ def get_trained_models(): 'code': 0, 'data': { 'models': models, - 'base_path': base_path + 'base_path': base_path or '' } }) except Exception as e: diff --git a/src/main.py b/src/main.py index b916c68..c4abbd9 100644 --- a/src/main.py +++ b/src/main.py @@ -165,6 +165,7 @@ def init_database(): valid_ratio INT DEFAULT 10, output_model_name VARCHAR(255), process_id INT COMMENT '训练进程ID', + tensorboard_log_dir VARCHAR(255) COMMENT 'TensorBoard日志目录', status VARCHAR(50) DEFAULT 'pending', progress INT DEFAULT 0, create_time DATETIME DEFAULT CURRENT_TIMESTAMP, diff --git a/start_all.sh b/start_all.sh index 58f54af..9577177 100644 --- a/start_all.sh +++ b/start_all.sh @@ -1,6 +1,6 @@ #!/bin/bash # YG_FT_Base 统一启动脚本 -# 同时启动后端 API 服务和 Web 静态服务器 +# 同时启动后端 API 服务、Web 静态服务器和 TensorBoard # 使用方法: bash start_all.sh # 自动修复脚本换行符 @@ -34,6 +34,7 @@ fi echo "📦 端口配置:" echo " - 后端 API: $API_PORT" echo " - Web 服务器: $WEB_PORT" +echo " - TensorBoard: 6006" echo "" # 检查端口是否已被占用 @@ -70,6 +71,36 @@ start_api() { echo "$API_PID" > /tmp/ygft_api.pid } +# 启动 TensorBoard 服务 +start_tensorboard() { + echo "" + echo "🚀 启动 TensorBoard 服务..." + + # 检查端口 + if ! check_port 6006; then + echo "⚠️ 端口 6006 已被占用,TensorBoard 可能已在运行" + return 0 + fi + + # 确保日志目录存在 + LOG_DIR="/app/base/saves" + if [ ! -d "$LOG_DIR" ]; then + LOG_DIR="$SCRIPT_DIR/saves" + fi + + if [ ! -d "$LOG_DIR" ]; then + echo "⚠️ 日志目录不存在,跳过 TensorBoard 启动" + return 0 + fi + + # 启动 TensorBoard(后台运行) + nohup tensorboard --logdir "$LOG_DIR" --port 6006 --bind_all > "$LOG_DIR/tensorboard.log" 2>&1 & + TB_PID=$! + echo "✅ TensorBoard 服务已启动 (PID: $TB_PID, 端口: 6006)" + echo "$TB_PID" > /tmp/ygft_tensorboard.pid + echo "📊 TensorBoard 访问地址: http://localhost:6006" +} + # 启动 Web 静态服务器 start_web() { echo "" @@ -105,9 +136,16 @@ stop_all() { echo "✅ Web 服务已停止" fi + if [ -f /tmp/ygft_tensorboard.pid ]; then + kill $(cat /tmp/ygft_tensorboard.pid) 2>/dev/null + rm /tmp/ygft_tensorboard.pid + echo "✅ TensorBoard 服务已停止" + fi + # 清理可能残留的进程 pkill -f "src/main.py" 2>/dev/null pkill -f "http.server $WEB_PORT" 2>/dev/null + pkill -f "tensorboard.*6006" 2>/dev/null } # 显示状态 @@ -128,16 +166,24 @@ status() { echo "❌ Web 服务: 未运行" fi + if [ -f /tmp/ygft_tensorboard.pid ] && kill -0 $(cat /tmp/ygft_tensorboard.pid) 2>/dev/null; then + echo "✅ TensorBoard: 运行中 (PID: $(cat /tmp/ygft_tensorboard.pid), 端口: 6006)" + else + echo "❌ TensorBoard: 未运行" + fi + echo "" echo "🌐 访问地址:" echo " - 后端 API: http://localhost:$API_PORT" echo " - Web 页面: http://localhost:$WEB_PORT/pages/main.html" + echo " - TensorBoard: http://localhost:6006" } # 主菜单 case "$1" in start) start_api + start_tensorboard start_web echo "" echo "====================================" @@ -153,6 +199,7 @@ case "$1" in stop_all sleep 1 start_api + start_tensorboard start_web echo "" echo "====================================" diff --git a/web/pages/components/sidebar.html b/web/pages/components/sidebar.html index f1a78bc..62bc82a 100644 --- a/web/pages/components/sidebar.html +++ b/web/pages/components/sidebar.html @@ -23,7 +23,7 @@