From e9e0e21e47ceacb7679301d9d8885fd9bed577c5 Mon Sep 17 00:00:00 2001 From: "WIN-JHFT4D3SIVT\\caoxiaozhu" Date: Thu, 29 Jan 2026 10:36:59 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=BC=80=E5=A7=8B=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E7=95=8C=E9=9D=A2=E4=BB=A5=E5=8F=8A=E6=9F=A5=E7=9C=8B?= =?UTF-8?q?=E6=97=A5=E5=BF=97=E5=8A=9F=E8=83=BD=E5=AE=8C=E5=96=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 3 + requirements.txt | 1 + src/api/fine_tune.py | 606 ++++++++++++++++++++++---- src/api/logs.py | 228 ++++++++++ src/api/model_manage.py | 62 +++ src/main.py | 215 +++++++++- start_all.sh | 4 +- web/pages/fine-tune-create.html | 130 ++++-- web/pages/main.html | 541 ++++++++++++++++++++--- web/pages/model-manage.html | 134 +++++- web/pages/training-log.html | 740 ++++++++++++++++++++++++++++++++ 11 files changed, 2485 insertions(+), 179 deletions(-) create mode 100644 web/pages/training-log.html diff --git a/config.yaml b/config.yaml index d67e24a..bc097c4 100644 --- a/config.yaml +++ b/config.yaml @@ -16,3 +16,6 @@ app: # 密钥配置 secret_key: "yg-ft-platform-secret-key-2024" + +# 训练日志路径 +training_logs_path: "/app/base/training_logs" diff --git a/requirements.txt b/requirements.txt index c91f3e3..7f195bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ requests==2.31.0 psutil==5.9.8 werkzeug==3.0.1 pynvml==11.5.0 +tensorboard>=2.13.0 diff --git a/src/api/fine_tune.py b/src/api/fine_tune.py index 1195d33..dc88c29 100644 --- a/src/api/fine_tune.py +++ b/src/api/fine_tune.py @@ -3,16 +3,31 @@ 调用 llamafactory-cli 执行训练任务 """ import os +import sys import subprocess import json import threading import time +import signal +import yaml from flask import Blueprint, request, jsonify import logging +# 添加项目根目录到路径 +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.insert(0, PROJECT_ROOT) + +# 加载配置 +CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') +with open(CONFIG_PATH, 'r', encoding='utf-8') as f: + CONFIG = yaml.safe_load(f) + logger = logging.getLogger(__name__) train_logger = logging.getLogger('train') # 专门的训练日志 logger,输出到 train.log +# 从配置获取训练日志路径 +TRAINING_LOGS_DIR = CONFIG.get('training_logs_path', '/app/base/training_logs') + # 创建蓝图 fine_tune_bp = Blueprint('fine_tune', __name__, url_prefix='/api/fine-tune') @@ -72,21 +87,21 @@ def start_training(): train_logger.info(f"[TRAIN] 模型路径: {model_path}") - # 设置工作目录为 llamafactory 目录 - llamafactory_dir = '/app/src/llamafactory' + # 设置工作目录和 llamafactory 目录 + work_dir = '/app/base' + llamafactory_dir = '/app/base' - # 处理数据集文件:将数据集复制到 llamafactory 的 datasets 目录 + # 数据集目录直接使用 /app/base/datasets(不再复制) + datasets_dir = '/app/base/datasets' + + # 获取数据集名称(用于 --dataset 参数) + dataset_key = None dataset_id = data.get('train_dataset_id') try: dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None except (ValueError, TypeError): dataset_id_int = None - llamafactory_datasets_dir = os.path.join(llamafactory_dir, 'datasets') - os.makedirs(llamafactory_datasets_dir, exist_ok=True) - - # 获取数据集名称(用于 --dataset 参数) - dataset_key = None if dataset_id_int: from .datasets import get_db_connection as get_dataset_conn conn = get_dataset_conn() @@ -94,43 +109,8 @@ def start_training(): cursor.execute("SELECT dm.name FROM dataset_manage dm WHERE dm.id = %s", (dataset_id_int,)) dataset_result = cursor.fetchone() conn.close() - dataset_key = dataset_result['name'] if dataset_result else None - - if dataset_key: - # 从 dataset_info.json 读取实际文件名 - src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json') - actual_file_name = None - if os.path.exists(src_info_json): - import json as json_lib - with open(src_info_json, 'r', encoding='utf-8') as f: - dataset_info = json_lib.load(f) - if dataset_key in dataset_info: - actual_file_name = dataset_info[dataset_key].get('file_name') - train_logger.info(f"[TRAIN] 从 dataset_info.json 获取文件名: {dataset_key} -> {actual_file_name}") - - # 复制数据集文件到 llamafactory 目录 - if actual_file_name: - src_file = os.path.join('/app/base', 'datasets', actual_file_name) - dst_file = os.path.join(llamafactory_datasets_dir, actual_file_name) - if os.path.exists(src_file): - import shutil - shutil.copy2(src_file, dst_file) - train_logger.info(f"[TRAIN] 复制数据集文件: {src_file} -> {dst_file}") - else: - train_logger.warning(f"[TRAIN] 数据集文件不存在: {src_file}") - - # 复制 dataset_info.json 到 llamafactory datasets 目录 - src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json') - dst_info_json = os.path.join(llamafactory_datasets_dir, 'dataset_info.json') - try: - if os.path.exists(src_info_json): - shutil.copy2(src_info_json, dst_info_json) - train_logger.info(f"[TRAIN] 已复制 dataset_info.json 到 llamafactory 目录") - else: - train_logger.warning(f"[TRAIN] dataset_info.json 不存在: {src_info_json}") - except Exception as e: - train_logger.warning(f"[TRAIN] 复制 dataset_info.json 失败: {e}") + train_logger.info(f"[TRAIN] 数据集名称: {dataset_key}") # 获取选中的 GPU 索引 gpus = data.get('gpus', []) @@ -145,6 +125,9 @@ def start_training(): env = os.environ.copy() env['CUDA_VISIBLE_DEVICES'] = cuda_devices env['TF_CPP_MIN_LOG_LEVEL'] = '2' # 减少 TensorFlow 日志 + env['LLAMAFACTORY_DIR'] = '/app/base' # 指定 llamafactory 根目录 + env['PYTHONUNBUFFERED'] = '1' # 强制 Python 不缓冲输出,实时写入日志 + env['TRANSFORMERS_VERBOSITY'] = 'INFO' # 设置 transformers 日志级别 # 构建 llamafactory-cli 命令(传入数据集名称用于 --dataset 参数) cmd = build_train_command(data, model_path, dataset_key) @@ -154,57 +137,93 @@ def start_training(): # 在返回的命令中显示 GPU 配置 cmd_str_with_gpu = f"CUDA_VISIBLE_DEVICES={cuda_devices} {cmd_str}" - # 生成训练日志文件路径(按日期分目录) + # 生成训练日志文件路径(存储在 logs 目录下的日期子目录中) from datetime import datetime today = datetime.now().strftime('%Y-%m-%d') - task_id_str = str(data.get('task_id', 'unknown')) - log_dir = os.path.join(llamafactory_dir, 'logs', today) - train_output_log = os.path.join(log_dir, f'train_{task_id_str}.log') + now_str = datetime.now().strftime('%Y%m%d_%H%M%S') # 时间戳用于排序 + task_id = data.get('task_id', 'unknown') + task_name = data.get('name', 'unknown') + # 工作目录设为 /app/base(而非 llamafactory 目录) + work_dir = '/app/base' + # 使用 logs 目录下的日期子目录 + training_logs_dir = os.path.join('/app/base/logs', today) + os.makedirs(training_logs_dir, exist_ok=True) - # 确保日志目录存在 - os.makedirs(log_dir, exist_ok=True) + # 日志文件路径: logs/{日期}/{task_id}_{task_name}.log + log_file = os.path.join(training_logs_dir, f'{task_id}_{task_name}.log') train_logger.info(f"[TRAIN] 启动训练进程...") + # 用于存储实际进程 PID + actual_pid = None + final_log_path = log_file + # 使用线程在后台运行训练进程 def run_training(): - with open(train_output_log, 'w', encoding='utf-8') as log_file: + nonlocal actual_pid, final_log_path + + # 从 data 中获取 template 和 train_method(与 build_train_command 保持一致) + template = data.get('template', 'default') + train_method = data.get('train_method', 'lora') + + # 创建输出目录(如果不存在) + output_model_name = data.get('output_model_name', f"{template}/{train_method}") + if not output_model_name.startswith('/'): + output_model_name = f"/app/base/saves/{output_model_name}" + output_dir = output_model_name + os.makedirs(output_dir, exist_ok=True) + train_logger.info(f"[TRAIN] 输出目录: {output_dir}") + train_logger.info(f"[TRAIN] 完整训练命令: {' '.join(cmd)}") + + with open(log_file, 'w', encoding='utf-8') as f: + # 设置 cwd 为 /app,但通过 LLAMAFACTORY_DIR 环境变量指定 llamafactory 位置 process = subprocess.Popen( cmd, - cwd=llamafactory_dir, - stdout=log_file, + cwd=work_dir, + stdout=f, stderr=subprocess.STDOUT, env=env ) - train_logger.info(f"[TRAIN] 训练进程 PID: {process.pid}") + actual_pid = process.pid + train_logger.info(f"[TRAIN] 训练进程 PID: {actual_pid}") + train_logger.info(f"[TRAIN] 日志文件: {log_file}") + + # 更新数据库中的 PID(立即更新,方便停止任务) + update_fine_tune_status(task_id, 'running', actual_pid) + # 等待进程完成 process.wait() train_logger.info(f"[TRAIN] 训练进程已结束,退出码: {process.returncode}") # 更新任务状态 final_status = 'completed' if process.returncode == 0 else 'failed' - update_fine_tune_status(data.get('task_id'), final_status, process.pid) + update_fine_tune_status(task_id, final_status, actual_pid) # 启动后台线程 training_thread = threading.Thread(target=run_training, daemon=True) training_thread.start() - # 立即返回,不等待进程完成 - pid = None # 此时还不知道实际 PID,稍后可从日志获取 - train_logger.info(f"[TRAIN] 训练任务已在后台启动") - train_logger.info(f"[TRAIN] 训练日志输出到: {train_output_log}") + # 等待 PID 并更新到数据库 + for i in range(10): # 最多等待1秒 + time.sleep(0.1) + if actual_pid: + break - # 更新任务状态为运行中 - update_fine_tune_status(data.get('task_id'), 'running', 0) + # 立即返回,不等待进程完成 + train_logger.info(f"[TRAIN] 训练任务已在后台启动,PID: {actual_pid}") + + train_logger.info(f"[TRAIN] 训练日志输出到: {log_file}") return jsonify({ 'code': 0, 'message': f'训练任务已启动 (GPU: {cuda_devices})', 'data': { - 'task_id': data.get('task_id'), + 'task_id': task_id, + 'pid': actual_pid, 'gpu_ids': cuda_devices, 'command': cmd_str_with_gpu, - 'log_file': train_output_log + 'log_file': log_file, + 'training_logs_dir': training_logs_dir } }) @@ -258,10 +277,11 @@ def build_train_command(data, model_path, dataset_name=None): train_method = data.get('train_method', 'lora') cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')]) - # 输出目录 - output_dir = data.get('output_model_name', f"./saves/{template}/{train_method}") - if not output_dir.startswith('./'): - output_dir = f"./saves/{output_dir}" + # 输出目录(确保是绝对路径) + output_model_name = data.get('output_model_name', f"{template}/{train_method}") + if not output_model_name.startswith('/'): + output_model_name = f"/app/base/saves/{output_model_name}" + output_dir = output_model_name cmd.extend(['--output_dir', output_dir]) # 常用参数 @@ -274,10 +294,11 @@ def build_train_command(data, model_path, dataset_name=None): '--per_device_eval_batch_size', '1', '--gradient_accumulation_steps', str(data.get('gradient_accumulation_steps', 8)), '--lr_scheduler_type', data.get('lr_scheduler_type', 'cosine'), - '--logging_steps', '50', + '--logging_steps', '5', '--warmup_steps', str(data.get('warmup_steps', 20)), - '--save_steps', '100', - '--eval_steps', str(data.get('eval_steps', 100)), + '--save_steps', str(data.get('save_steps', 100)), + '--log_level', 'info', # 设置日志级别为 info + '--log_level_replica', 'info', # 设置副本日志级别 ]) # 学习率 @@ -295,9 +316,10 @@ def build_train_command(data, model_path, dataset_name=None): if data.get('max_samples'): cmd.extend(['--max_samples', str(data.get('max_samples'))]) + # 启用 TensorBoard 日志(用于可视化训练曲线) + cmd.append('--plot_loss') + # 其他选项 - if data.get('plot_loss'): - cmd.append('--plot_loss') if data.get('fp16'): cmd.append('--fp16') @@ -386,6 +408,57 @@ def stop_training(task_id): return jsonify({'code': 1, 'message': str(e)}) +@fine_tune_bp.route('/', methods=['DELETE']) +def delete_training_task(task_id): + """删除训练任务及对应的日志文件""" + try: + from .model_manage import get_db_connection + + # 获取任务信息(用于删除日志文件) + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT name, process_id FROM fine_tune WHERE id = %s", (task_id,)) + task_result = cursor.fetchone() + conn.close() + + if not task_result: + return jsonify({'code': 1, 'message': '任务不存在'}) + + task_name = task_result.get('name', 'unknown') + + # 删除日志文件 (logs/{日期}/{task_id}_{task_name}.log) + try: + from datetime import datetime + today = datetime.now().strftime('%Y-%m-%d') + + # 可能的日志文件路径 + log_paths = [ + f'/app/base/logs/{today}/{task_id}_{task_name}.log', + f'/app/base/logs/{task_id}_{task_name}.log', + ] + + for log_path in log_paths: + if os.path.exists(log_path): + os.remove(log_path) + logger.info(f"已删除日志文件: {log_path}") + except Exception as log_err: + logger.warning(f"删除日志文件失败: {log_err}") + + # 删除数据库中的任务记录 + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("DELETE FROM fine_tune WHERE id = %s", (task_id,)) + conn.commit() + conn.close() + + logger.info(f"已删除训练任务 {task_id}: {task_name}") + return jsonify({'code': 0, 'message': '删除成功'}) + + except Exception as e: + logger.error(f"删除训练任务失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) + + @fine_tune_bp.route('/status/', methods=['GET']) def get_training_status(task_id): """获取训练任务状态""" @@ -402,12 +475,25 @@ def get_training_status(task_id): conn.close() if result: + # 检查 PID 是否仍在运行 + actual_status = result['status'] + pid = result.get('process_id') + if pid and actual_status == 'running': + try: + # 检查进程是否存在 + os.kill(pid, 0) + # 进程仍在运行 + actual_status = 'running' + except (OSError, ProcessLookupError): + # 进程已结束,尝试更新状态 + actual_status = 'completed' # 假设完成(实际可能失败) + return jsonify({ 'code': 0, 'data': { 'task_id': result['id'], 'name': result['name'], - 'status': result['status'], + 'status': actual_status, 'progress': result['progress'], 'pid': result.get('process_id') } @@ -420,6 +506,254 @@ def get_training_status(task_id): return jsonify({'code': 1, 'message': str(e)}) +@fine_tune_bp.route('/check-pid/', methods=['GET']) +def check_pid_status(pid): + """检查 PID 是否仍在运行""" + try: + if pid <= 0: + return jsonify({ + 'code': 0, + 'data': { + 'exists': False, + 'message': '无效的 PID' + } + }) + + try: + # 发送信号 0 来检查进程是否存在(不会实际终止进程) + os.kill(pid, 0) + return jsonify({ + 'code': 0, + 'data': { + 'exists': True, + 'message': '进程仍在运行' + } + }) + except (OSError, ProcessLookupError): + # 进程不存在 + return jsonify({ + 'code': 0, + 'data': { + 'exists': False, + 'message': '进程已结束' + } + }) + except Exception as e: + logger.error(f"检查 PID 状态失败: {e}") + return jsonify({ + 'code': 0, + 'data': { + 'exists': False, + 'message': f'检查失败: {str(e)}' + } + }) + + +@fine_tune_bp.route('/log/', methods=['GET']) +def get_training_log(task_id): + """获取训练任务日志内容(支持实时读取)""" + try: + from .model_manage import get_db_connection + + # 获取任务信息和进程ID + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute( + "SELECT name, process_id, status FROM fine_tune WHERE id = %s", + (task_id,) + ) + result = cursor.fetchone() + conn.close() + + if not result: + return jsonify({'code': 1, 'message': '任务不存在'}) + + process_id = result.get('process_id') + task_name = result['name'] + status = result['status'] + + if not process_id: + return jsonify({'code': 1, 'message': '任务尚未启动'}) + + # 构建日志文件路径 - 新格式: logs/{日期}/{task_id}_{task_name}.log + from datetime import datetime + today = datetime.now().strftime('%Y-%m-%d') + training_logs_dir = os.path.join('/app/base/logs', today) + + # 查找日志文件 (新格式: {task_id}_{task_name}.log) + log_file = os.path.join(training_logs_dir, f'{task_id}_{task_name}.log') + + if not os.path.exists(log_file): + # 如果没找到,返回空日志 + return jsonify({ + 'code': 0, + 'data': { + 'content': '', + 'status': status, + 'message': '日志文件尚未创建' + } + }) + + # 读取日志文件内容 + try: + with open(log_file, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + return jsonify({ + 'code': 0, + 'data': { + 'content': content, + 'status': status, + 'log_file': log_file + } + }) + except Exception as e: + return jsonify({ + 'code': 0, + 'data': { + 'content': '', + 'status': status, + 'message': f'读取日志失败: {str(e)}' + } + }) + + except Exception as e: + logger.error(f"获取训练日志失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) + + +import re + + +@fine_tune_bp.route('/progress/', methods=['GET']) +def get_training_progress(task_id): + """获取训练任务进度(从日志中解析 llamafactory 的进度信息)""" + try: + from .model_manage import get_db_connection + + # 获取任务信息和进程ID + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute( + "SELECT name, process_id, status FROM fine_tune WHERE id = %s", + (task_id,) + ) + result = cursor.fetchone() + conn.close() + + if not result: + return jsonify({'code': 1, 'message': '任务不存在'}) + + process_id = result.get('process_id') + task_name = result['name'] + status = result['status'] + + if not process_id: + return jsonify({ + 'code': 0, + 'data': { + 'progress': 0, + 'step': '', + 'eta': '', + 'speed': '', + 'status': status, + 'message': '任务尚未启动' + } + }) + + # 构建日志文件路径 - 新格式: logs/{日期}/{task_id}_{task_name}.log + from datetime import datetime + today = datetime.now().strftime('%Y-%m-%d') + training_logs_dir = os.path.join('/app/base/logs', today) + + # 查找日志文件 (新格式: {task_id}_{task_name}.log) + log_file = os.path.join(training_logs_dir, f'{task_id}_{task_name}.log') + + # TensorBoard 日志目录(使用默认值) + tensorboard_log_dir = '/app/base/saves' + + if not os.path.exists(log_file): + return jsonify({ + 'code': 0, + 'data': { + 'step': '', + 'elapsed': '', + 'eta': '', + 'speed': '', + 'status': status, + 'message': '日志文件尚未创建', + 'tensorboard_url': '' + } + }) + + # 读取日志文件最后部分,解析进度信息 + try: + with open(log_file, 'r', encoding='utf-8', errors='ignore') as f: + # 读取最后 10KB 内容 + f.seek(0, 2) # 跳到文件末尾 + file_size = f.tell() + read_size = min(10240, file_size) + f.seek(max(0, file_size - read_size)) + content = f.read() + + # 匹配 llamafactory 进度格式: 52%|█████▏ | 17/33 [02:16<02:08, 8.04s/it] + progress_pattern = r'\s*(\d+)%\|[█░▌▋▒█▏▎▏▐▀■□▪▫‣▶➜➡→]+\s*\|\s*(\d+)/(\d+)\s+\[(\d+):(\d+)<(\d+):(\d+),\s*([\d.]+)s/it\]' + match = re.search(progress_pattern, content) + + step_info = '' + elapsed = '' + eta = '' + speed = '' + message = '等待训练开始' + + if match: + current_step = int(match.group(2)) + total_steps = int(match.group(3)) + elapsed_min = int(match.group(4)) + elapsed_sec = int(match.group(5)) + eta_min = int(match.group(6)) + eta_sec = int(match.group(7)) + speed_val = float(match.group(8)) + + step_info = f'{current_step}/{total_steps}' + elapsed = f'{elapsed_min:02d}:{elapsed_sec:02d}' + eta = f'{eta_min:02d}:{eta_sec:02d}' + speed = f'{speed_val}s/it' + message = '训练进行中' + + return jsonify({ + 'code': 0, + 'data': { + 'step': step_info, + 'elapsed': elapsed, + 'eta': eta, + 'speed': speed, + 'status': status, + 'message': message, + 'tensorboard_log_dir': tensorboard_log_dir, + 'tensorboard_url': '' + } + }) + + except Exception as e: + return jsonify({ + 'code': 0, + 'data': { + 'step': '', + 'elapsed': '', + 'eta': '', + 'speed': '', + 'status': status, + 'message': f'读取进度失败: {str(e)}', + 'tensorboard_log_dir': tensorboard_log_dir, + 'tensorboard_url': '' + } + }) + + except Exception as e: + logger.error(f"获取训练进度失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) + + def get_db_connection(): """获取数据库连接""" import pymysql @@ -441,3 +775,129 @@ def get_db_connection(): charset=db_config.get('charset', 'utf8mb4'), cursorclass=pymysql.cursors.DictCursor ) + + +@fine_tune_bp.route('/check-name', methods=['GET']) +def check_task_name(): + """检查任务名称是否重复""" + try: + name = request.args.get('name', '').strip() + if not name: + return jsonify({'code': 1, 'message': '任务名称不能为空'}) + + # 验证任务名称格式:只能包含英文、数字、下划线 + import re + if not re.match(r'^[a-zA-Z0-9_]+$', name): + return jsonify({'code': 1, 'message': '任务名称只能包含英文、数字和下划线'}) + + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT id FROM fine_tune WHERE name = %s", (name,)) + result = cursor.fetchone() + conn.close() + + if result: + return jsonify({ + 'code': 0, + 'data': { + 'exists': True, + 'message': '任务名称已存在' + } + }) + else: + return jsonify({ + 'code': 0, + 'data': { + 'exists': False, + 'message': '任务名称可用' + } + }) + + except Exception as e: + logger.error(f"检查任务名称失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) + + +# TensorBoard 服务进程 +tensorboard_process = None + + +@fine_tune_bp.route('/tensorboard/start', methods=['POST']) +def start_tensorboard(): + """启动 TensorBoard 服务""" + global tensorboard_process + try: + import subprocess + import os + + # 检查是否已有进程在运行 + if tensorboard_process and tensorboard_process.poll() is None: + return jsonify({ + 'code': 0, + 'data': { + 'url': 'http://10.10.10.177:6006', + 'status': 'already_running', + 'message': 'TensorBoard 服务已运行' + } + }) + + # 获取日志目录 + log_dir = '/app/base/saves' + + # 检查目录是否存在 + if not os.path.exists(log_dir): + return jsonify({'code': 1, 'message': f'日志目录不存在: {log_dir}'}) + + # 启动 TensorBoard(后台运行) + cmd = ['tensorboard', '--logdir', log_dir, '--port', '6006', '--bind_all'] + tensorboard_process = subprocess.Popen( + cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + preexec_fn=os.setsid + ) + + logger.info(f"TensorBoard 服务已启动: {cmd}") + + return jsonify({ + 'code': 0, + 'data': { + 'url': 'http://10.10.10.177:6006', + 'status': 'started', + 'message': 'TensorBoard 服务已启动' + } + }) + + except Exception as e: + logger.error(f"启动 TensorBoard 失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) + + +@fine_tune_bp.route('/tensorboard/stop', methods=['POST']) +def stop_tensorboard(): + """停止 TensorBoard 服务""" + global tensorboard_process + try: + import subprocess + import signal + + if tensorboard_process and tensorboard_process.poll() is None: + # 使用 os.killpg 终止进程组 + try: + os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM) + except Exception: + pass + tensorboard_process = None + logger.info("TensorBoard 服务已停止") + + return jsonify({ + 'code': 0, + 'data': { + 'status': 'stopped', + 'message': 'TensorBoard 服务已停止' + } + }) + + except Exception as e: + logger.error(f"停止 TensorBoard 失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) diff --git a/src/api/logs.py b/src/api/logs.py index 9315991..97e3a97 100644 --- a/src/api/logs.py +++ b/src/api/logs.py @@ -169,3 +169,231 @@ def get_log_content(): }) except Exception as e: return jsonify({'code': 1, 'message': f'读取日志文件失败: {str(e)}'}) + + +# ============ 训练日志相关 API ============ + +# 训练日志保存在 logs/{日期} 目录下 +TRAINING_LOGS_BASE_DIR = '/app/base/logs' +# 本地开发时的备用路径(Windows) +LOCAL_TRAINING_LOGS_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs') + +# 添加调试日志 +logs_logger.info(f"[DEBUG] TRAINING_LOGS_BASE_DIR: {TRAINING_LOGS_BASE_DIR}") +logs_logger.info(f"[DEBUG] LOCAL_TRAINING_LOGS_BASE_DIR: {LOCAL_TRAINING_LOGS_BASE_DIR}") + + +@logs_bp.route('/training-log-files', methods=['GET']) +def get_training_log_files(): + """获取训练日志文件列表 - 从 logs/{日期} 目录下的 .log 文件""" + try: + # 确定基础目录 + logs_base_dir = TRAINING_LOGS_BASE_DIR + if not os.path.exists(logs_base_dir): + logs_base_dir = LOCAL_TRAINING_LOGS_BASE_DIR + + logs_logger.info(f"[DEBUG] logs_base_dir: {logs_base_dir}, exists: {os.path.exists(logs_base_dir)}") + + if not os.path.exists(logs_base_dir): + return jsonify({'code': 0, 'data': []}) + + # 遍历所有日期目录,收集训练日志文件 + log_files = [] + date_dirs = [] + + try: + # 获取所有日期目录(格式: YYYY-MM-DD) + for item in os.listdir(logs_base_dir): + item_path = os.path.join(logs_base_dir, item) + if os.path.isdir(item_path): + # 验证是否为日期目录 + try: + datetime.strptime(item, '%Y-%m-%d') + date_dirs.append(item) + except ValueError: + pass + except Exception as list_err: + logs_logger.error(f"[DEBUG] Failed to list base directory: {list_err}") + return jsonify({'code': 0, 'data': []}) + + # 按日期排序(最新的在前面) + date_dirs.sort(reverse=True) + + logs_logger.info(f"[DEBUG] Date directories: {date_dirs}") + + # 遍历每个日期目录,查找 .log 文件 + for date_dir in date_dirs: + date_full_path = os.path.join(logs_base_dir, date_dir) + try: + files = os.listdir(date_full_path) + except Exception as list_err: + logs_logger.warning(f"[DEBUG] Failed to list {date_full_path}: {list_err}") + continue + + for file_name in files: + if not file_name.endswith('.log'): + continue + + file_path = os.path.join(date_full_path, file_name) + try: + size = os.path.getsize(file_path) + except Exception as size_err: + logs_logger.warning(f"[DEBUG] Failed to get size of {file_path}: {size_err}") + continue + + # 文件名格式: {task_id}_{task_name}.log + # 例如: 889_testing.log + parts = file_name.replace('.log', '').split('_', 1) + if len(parts) >= 2: + task_id = parts[0] + task_name = parts[1] + try: + dt = datetime.strptime(date_dir, '%Y-%m-%d') + # 使用日期目录的时间作为排序键 + sort_key = dt.timestamp() + display_date = date_dir + except: + sort_key = 0 + display_date = date_dir + else: + task_id = 'unknown' + task_name = file_name.replace('.log', '') + sort_key = 0 + display_date = date_dir + + # 构建相对路径 (日期/文件名) + relative_path = f"{date_dir}/{file_name}" + + log_files.append({ + 'name': task_name, + 'file': relative_path, + 'task_id': task_id, + 'date': display_date, + 'size': format_file_size(size), + 'sort_key': sort_key + }) + + # 按时间戳排序(最新的在前面) + log_files.sort(key=lambda x: x['sort_key'] if x['sort_key'] else 0, reverse=True) + + logs_logger.info(f"[DEBUG] Found {len(log_files)} training log files") + + return jsonify({'code': 0, 'data': log_files}) + except Exception as e: + logs_logger.error(f"[DEBUG] 获取训练日志列表失败: {e}") + return jsonify({'code': 1, 'message': f'获取训练日志列表失败: {str(e)}'}) + + +@logs_bp.route('/training-log-content', methods=['GET']) +def get_training_log_content(): + """获取训练日志文件内容 - 从 logs/{日期}/ 目录""" + file_name = request.args.get('file') + if not file_name: + return jsonify({'code': 1, 'message': '缺少文件参数'}) + + logs_logger.info(f"[DEBUG] ============ get_training_log_content ============") + logs_logger.info(f"[DEBUG] file: {file_name}") + + # 防止目录遍历攻击 + file_name = file_name.replace('..', '').replace('//', '/') + + # file 格式: 日期/文件名,例如: 2026-01-28/889_testing.log + # 解析日期和文件名 + parts = file_name.split('/') + if len(parts) < 2: + return jsonify({'code': 1, 'message': '无效的文件路径格式'}) + + date_dir = parts[0] + log_file_name = '/'.join(parts[1:]) + + # 验证日期格式 + try: + datetime.strptime(date_dir, '%Y-%m-%d') + except ValueError: + return jsonify({'code': 1, 'message': '无效的日期格式'}) + + # 确定基础目录 + container_base_dir = TRAINING_LOGS_BASE_DIR # /app/base/logs + local_base_dir = LOCAL_TRAINING_LOGS_BASE_DIR # 项目目录下的 logs + + container_full_path = os.path.join(container_base_dir, date_dir, log_file_name) + local_full_path = os.path.join(local_base_dir, date_dir, log_file_name) + + logs_logger.info(f"[DEBUG] container_base_dir: {container_base_dir}, exists: {os.path.exists(container_base_dir)}") + logs_logger.info(f"[DEBUG] local_base_dir: {local_base_dir}, exists: {os.path.exists(local_base_dir)}") + logs_logger.info(f"[DEBUG] container_full_path: {container_full_path}, exists: {os.path.exists(container_full_path)}") + logs_logger.info(f"[DEBUG] local_full_path: {local_full_path}, exists: {os.path.exists(local_full_path)}") + + # 选择最终路径 + full_path = None + if os.path.exists(container_full_path): + full_path = container_full_path + logs_logger.info(f"[DEBUG] Using container path") + elif os.path.exists(local_full_path): + full_path = local_full_path + logs_logger.info(f"[DEBUG] Using local path") + else: + logs_logger.error(f"[DEBUG] File not found: {file_name}") + return jsonify({'code': 1, 'message': f'日志文件不存在: {file_name}'}) + + logs_logger.info(f"[DEBUG] Final full_path: {full_path}") + + # 尝试直接读取文件 + try: + max_size = 10 * 1024 * 1024 + content = '' + read_success = False + + try: + with open(full_path, 'rb') as f: + f.seek(0, 2) + size = f.tell() + f.seek(0) + + if size > max_size: + f.seek(size - max_size) + content = '... (日志文件较大,已显示最后 10MB 内容) ...\n\n' + f.read().decode('utf-8', errors='ignore') + else: + content = f.read().decode('utf-8', errors='ignore') + read_success = True + except (PermissionError, OSError) as pe: + logs_logger.warning(f"[DEBUG] 直接读取失败: {pe},尝试共享模式读取") + import mmap + try: + with open(full_path, 'rb') as f: + mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + try: + f.seek(0, 2) + size = f.tell() + if size > max_size: + content = '... (日志文件较大,已显示最后 10MB 内容) ...\n\n' + \ + mm[-max_size:].decode('utf-8', errors='ignore') + else: + content = mm[:].decode('utf-8', errors='ignore') + read_success = True + finally: + mm.close() + except Exception as e2: + logs_logger.error(f"[DEBUG] 共享模式读取失败: {e2}") + return jsonify({ + 'code': 2, + 'message': f'日志文件正在被训练进程占用,训练结束后可查看完整内容', + 'data': { + 'file': log_file_name, + 'size': format_file_size(0), + 'content': '' + } + }) + + if read_success: + return jsonify({ + 'code': 0, + 'data': { + 'file': log_file_name, + 'size': format_file_size(size), + 'content': content + } + }) + except Exception as e: + logs_logger.error(f"[DEBUG] 读取日志文件失败: {e}") + return jsonify({'code': 1, 'message': f'读取日志文件失败: {str(e)}'}) diff --git a/src/api/model_manage.py b/src/api/model_manage.py index 7dda27f..ac3fe62 100644 --- a/src/api/model_manage.py +++ b/src/api/model_manage.py @@ -193,3 +193,65 @@ def get_local_models(): except Exception as e: logger.error(f"获取本地模型列表失败: {e}") return jsonify({'code': 1, 'message': str(e)}) + + +# ============ 已训练模型列表接口 ============ + +@model_manage_bp.route('/trained-models', methods=['GET']) +def get_trained_models(): + """获取已训练模型列表(从/app/base/saves目录)""" + import logging + logger = logging.getLogger(__name__) + + try: + # 使用 /app/base/saves 目录(容器内路径) + saves_base_path = '/app/base/saves' + # 本地开发时的备用路径 + local_saves_path = os.path.join(PROJECT_ROOT, 'saves') + + # 选择存在的路径 + base_path = saves_base_path if os.path.exists(saves_base_path) else local_saves_path + + logger.info(f"[DEBUG] 已训练模型目录: {base_path}, exists: {os.path.exists(base_path)}") + + models = [] + if os.path.exists(base_path): + for item in os.listdir(base_path): + item_path = os.path.join(base_path, item) + if os.path.isdir(item_path): + # 检查是否是模板目录(包含训练方法的子目录) + sub_items = [] + if os.path.exists(item_path): + for sub_item in os.listdir(item_path): + sub_path = os.path.join(item_path, sub_item) + if os.path.isdir(sub_path): + # 检查是否包含模型文件(adapter_model.bin 或 pytorch_model.bin 等) + has_model = False + for f in os.listdir(sub_path): + if f.endswith('.bin') or f.endswith('.safetensors'): + has_model = True + break + if has_model: + sub_items.append({ + 'name': sub_item, + 'path': sub_path + }) + + models.append({ + 'name': item, + 'path': item_path, + 'train_methods': sub_items + }) + + logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型") + + return jsonify({ + 'code': 0, + 'data': { + 'models': models, + 'base_path': base_path + } + }) + except Exception as e: + logger.error(f"获取已训练模型列表失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) diff --git a/src/main.py b/src/main.py index f451abf..b916c68 100644 --- a/src/main.py +++ b/src/main.py @@ -33,6 +33,9 @@ def load_config(): CONFIG = load_config() +# 训练日志路径 +TRAINING_LOGS_DIR = CONFIG.get('training_logs_path', '/app/base/training_logs') + # ============ 日志系统配置 ============ LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs') @@ -339,9 +342,10 @@ def init_database(): # 为 fine_tune 表添加训练相关列 columns_to_add = [ + ("description", "TEXT COMMENT '任务描述'"), ("train_dataset_id", "INT COMMENT '训练数据集ID'"), ("valid_dataset_id", "INT COMMENT '验证数据集ID'"), - ("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"), + ("save_steps", "INT DEFAULT 100 COMMENT '保存步数'"), ("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"), ("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"), ("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"), @@ -379,8 +383,18 @@ def init_database(): app = Flask(__name__) app.config['SECRET_KEY'] = CONFIG['secret_key'] app.config['CORS_HEADERS'] = 'Content-Type' -# 允许所有来源 -CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"]) + +# 允许所有来源 - 支持跨域请求 +CORS(app, resources={ + r"/api/*": { + "origins": "*", + "methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"], + "allow_headers": ["Content-Type", "Authorization", "X-Requested-With"], + "expose_headers": ["Content-Length", "Content-Range"], + "supports_credentials": False, + "max_age": 86400 # 缓存预检请求结果 24 小时 + } +}, vary_header=True) # 注册蓝图 register_blueprints(app) @@ -674,6 +688,168 @@ def get_fine_tune(): return jsonify({'code': 0, 'data': generic_get_all('fine_tune')}) +@app.route('/api/fine-tune/', methods=['GET']) +def get_fine_tune_by_id(id): + """获取单个训练任务详情""" + try: + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT * FROM fine_tune WHERE id = %s", (id,)) + task = cursor.fetchone() + + if not task: + cursor.close() + conn.close() + return jsonify({'code': 1, 'message': '任务不存在'}) + + # 获取列名并转换为字典(get_db_connection已使用DictCursor,task已是字典) + if isinstance(task, dict): + task_dict = task + else: + columns = [desc[0] for desc in cursor.description] + task_dict = dict(zip(columns, task)) + + cursor.close() + conn.close() + + # 处理 datetime 序列化 + for key, value in task_dict.items(): + if isinstance(value, datetime): + task_dict[key] = value.strftime('%Y-%m-%d %H:%M:%S') + + return jsonify({'code': 0, 'data': task_dict}) + except Exception as e: + return jsonify({'code': 1, 'message': str(e)}) + + +@app.route('/api/fine-tune/progress/', methods=['GET']) +def get_fine_tune_progress(id): + """获取训练任务的进度(通过解析日志文件)""" + try: + # 获取任务信息 + conn = get_db_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute("SELECT id, process_id, name, status FROM fine_tune WHERE id = %s", (id,)) + task = cursor.fetchone() + conn.close() + + if not task: + return jsonify({'code': 1, 'message': '任务不存在'}) + + process_id = task.get('process_id') + task_name = task.get('name', '') + + if not process_id: + return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}}) + + # 查找日志文件 - 优先使用容器路径,如果不存在则使用本地路径 + training_logs_dir = TRAINING_LOGS_DIR + if not os.path.exists(training_logs_dir): + training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs') + + if not os.path.exists(training_logs_dir): + return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}}) + + log_file = None + + # 优先按 process_id 查找 + for file_name in os.listdir(training_logs_dir): + if file_name.endswith('.log') and file_name.startswith(f'{process_id}_'): + log_file = os.path.join(training_logs_dir, file_name) + break + + # 如果没找到,尝试按任务名称查找 + if not log_file and task_name: + for file_name in os.listdir(training_logs_dir): + if file_name.endswith('.log') and task_name in file_name: + log_file = os.path.join(training_logs_dir, file_name) + break + + if not log_file or not os.path.exists(log_file): + return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}}) + + # 读取日志文件内容 + try: + with open(log_file, 'r', encoding='utf-8') as f: + content = f.read() + except Exception as e: + return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}}) + + # 解析进度 + progress = 0 + step_info = '' + speed_info = '' + eta_info = '' + + import re + + # 处理 Windows 格式的日志(\r 覆盖行),将 \r 替换为换行 + content = content.replace('\r', '\n') + + # 日志格式: " 3%|▎ | 1/33 [00:09<05:10, 9.69s/it]" + # 或: " 30%|███ | 10/33 [01:22<03:00, 7.86s/it]" + # 匹配 "数字%|进度条| step/total [elapsed', methods=['DELETE']) def delete_fine_tune(id): + # 删除前获取任务信息(用于删除日志文件) + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT process_id, name FROM fine_tune WHERE id = %s", (id,)) + task_info = cursor.fetchone() + conn.close() + + # 删除相关的日志文件 + if task_info and task_info.get('process_id'): + from datetime import datetime + process_id = task_info['process_id'] + task_name = task_info.get('name', 'unknown') + + # 优先使用容器路径,如果不存在则使用本地路径 + training_logs_dir = TRAINING_LOGS_DIR + if not os.path.exists(training_logs_dir): + training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs') + + try: + if os.path.exists(training_logs_dir): + for file_name in os.listdir(training_logs_dir): + # 查找以 PID 开头的日志文件 + if file_name.endswith('.log') and file_name.startswith(f'{process_id}_'): + log_file = os.path.join(training_logs_dir, file_name) + try: + os.remove(log_file) + print(f"[INFO] 已删除日志文件: {log_file}") + except Exception as e: + print(f"[WARN] 删除日志文件失败: {log_file}, 错误: {e}") + except Exception as e: + print(f"[WARN] 查找或删除日志文件时出错: {e}") + + # 删除数据库记录 generic_delete('fine_tune', id) return jsonify({'code': 0, 'message': '删除成功'}) diff --git a/start_all.sh b/start_all.sh index 8e1d10d..58f54af 100644 --- a/start_all.sh +++ b/start_all.sh @@ -62,7 +62,9 @@ start_api() { return 1 fi - python src/main.py & + LOG_DIR="$SCRIPT_DIR/logs/$(date +%Y-%m-%d)" + mkdir -p "$LOG_DIR" + python src/main.py > "$LOG_DIR/api.log" 2>&1 & API_PID=$! echo "✅ 后端服务已启动 (PID: $API_PID, 端口: $API_PORT)" echo "$API_PID" > /tmp/ygft_api.pid diff --git a/web/pages/fine-tune-create.html b/web/pages/fine-tune-create.html index 79920fc..ab9eb46 100644 --- a/web/pages/fine-tune-create.html +++ b/web/pages/fine-tune-create.html @@ -219,10 +219,21 @@

基本信息

- +

0 / 50

+ +
+
+
+ +
+ +

0 / 200

@@ -466,16 +477,16 @@ - eval_steps + save_steps * - + [10, 10000] - 每训练多少步进行一次模型评估,建议设置为100的倍数 + 每训练多少步进行一次模型保存,建议设置为100的倍数 @@ -616,14 +627,7 @@

训练产出

- -
- -
- -

0 / 50

-
-
+

训练完成后,模型将保存为: 任务名称

@@ -678,16 +682,38 @@ }); }); - // 任务名称字数统计 + // 任务名称字数统计和实时预览(只能输入英文、数字、下划线) const nameInput = document.querySelector('input[name="name"]'); + const nameFormatError = document.getElementById('nameFormatError'); + const nameRegex = /^[a-zA-Z0-9_]*$/; + nameInput.addEventListener('input', () => { + const value = nameInput.value; + // 验证格式 + if (value.length > 0 && !nameRegex.test(value)) { + nameInput.classList.add('border-red-500'); + nameInput.classList.remove('border-gray-300'); + nameFormatError.classList.remove('hidden'); + } else { + nameInput.classList.remove('border-red-500'); + nameInput.classList.add('border-gray-300'); + nameFormatError.classList.add('hidden'); + } + // 过滤非法字符:只允许英文、数字、下划线 + const filteredValue = value.replace(/[^a-zA-Z0-9_]/g, ''); + if (value !== filteredValue) { + nameInput.value = filteredValue; + } document.getElementById('nameCount').textContent = nameInput.value.length; + // 更新模型名称预览 + document.getElementById('modelNamePreview').textContent = nameInput.value || '任务名称'; + updateCommandPreview(); }); - // 模型名称字数统计 - const modelNameInput = document.querySelector('input[name="output_model_name"]'); - modelNameInput.addEventListener('input', () => { - document.getElementById('modelNameCount').textContent = modelNameInput.value.length; + // 任务描述字数统计 + const descInput = document.querySelector('textarea[name="description"]'); + descInput.addEventListener('input', () => { + document.getElementById('descriptionCount').textContent = descInput.value.length; }); // 加载数据集列表 @@ -774,7 +800,7 @@ 'batch_size': 1, 'learning_rate': 0.0001, 'n_epochs': 1, - 'eval_steps': 100, + 'save_steps': 100, 'lr_scheduler_type': 'cosine', 'max_length': 512, 'warmup_ratio': 0.05, @@ -1014,7 +1040,7 @@ batch_size: parseInt(formData.get('batch_size')) || 1, learning_rate: parseFloat(formData.get('learning_rate')) || 0.0001, n_epochs: parseFloat(formData.get('n_epochs')) || 1.0, - eval_steps: parseInt(formData.get('eval_steps')) || 100, + save_steps: parseInt(formData.get('save_steps')) || 100, lr_scheduler_type: formData.get('lr_scheduler_type') || 'cosine', max_length: parseInt(formData.get('max_length')) || 512, warmup_ratio: parseFloat(formData.get('warmup_ratio')) || 0.05, @@ -1024,15 +1050,18 @@ lora_rank: formData.get('lora_rank') || '8' }; + const taskName = formData.get('name'); + const data = { - name: formData.get('name'), + name: taskName, + description: formData.get('description'), base_model: formData.get('base_model'), template: formData.get('template'), train_type: formData.get('train_type'), train_method: formData.get('train_method'), gpus: selectedGPUs, train_dataset_id: formData.get('train_dataset_id'), - output_model_name: formData.get('output_model_name'), + output_model_name: taskName, // 使用任务名称作为模型名称 ...trainParams, status: 'pending', progress: 0 @@ -1042,6 +1071,26 @@ showMessage('提示', '请输入任务名称', 'warning'); return; } + + // 验证任务名称格式 + const nameRegex = /^[a-zA-Z0-9_]+$/; + if (!nameRegex.test(data.name)) { + showMessage('提示', '任务名称只能包含英文、数字和下划线', 'warning'); + return; + } + + // 检查任务名称是否重复 + try { + const checkResponse = await fetch(`${API_BASE}/fine-tune/check-name?name=${encodeURIComponent(data.name)}`); + const checkResult = await checkResponse.json(); + if (checkResult.code === 0 && checkResult.data.exists) { + showMessage('提示', '任务名称已存在,请使用其他名称', 'warning'); + return; + } + } catch (error) { + console.error('检查任务名称失败:', error); + } + if (selectedGPUs.length === 0) { showMessage('提示', '请选择至少一个GPU硬件', 'warning'); return; @@ -1060,6 +1109,12 @@ } try { + // 显示加载中状态 + const submitBtn = document.querySelector('button[onclick="submitForm()"]'); + const originalText = submitBtn.innerHTML; + submitBtn.disabled = true; + submitBtn.innerHTML = '训练任务创建中...'; + // 第一步:创建训练任务记录 const createResponse = await fetch(`${API_BASE}/fine-tune`, { method: 'POST', @@ -1068,6 +1123,8 @@ }); const createResult = await createResponse.json(); if (createResult.code !== 0) { + submitBtn.disabled = false; + submitBtn.innerHTML = originalText; showMessage('错误', createResult.message || '创建任务失败', 'error'); return; } @@ -1077,12 +1134,13 @@ // 第二步:启动训练 const startData = { task_id: taskId, + name: data.name, // 任务名称,用于日志文件名和模型名称 base_model: data.base_model, template: data.template, train_type: data.train_type, train_method: data.train_method, train_dataset_id: data.train_dataset_id, - output_model_name: data.output_model_name, + output_model_name: data.name, // 使用任务名称作为模型名称 ...trainParams }; @@ -1093,9 +1151,12 @@ }); const startResult = await startResponse.json(); + // 恢复按钮状态 + submitBtn.disabled = false; + submitBtn.innerHTML = originalText; + if (startResult.code === 0) { - const cmd = startResult.data?.command || ''; - showMessage('成功', `训练任务已启动!

${cmd}`, 'success', () => { + showMessage('成功', '训练任务已启动!', 'success', () => { window.location.href = 'main.html'; }); } else { @@ -1108,6 +1169,12 @@ showMessage('错误', startResult.message || '启动训练失败', 'error'); } } catch (error) { + // 恢复按钮状态 + const submitBtn = document.querySelector('button[onclick="submitForm()"]'); + if (submitBtn) { + submitBtn.disabled = false; + submitBtn.innerHTML = '开始训练'; + } showMessage('错误', '操作失败: ' + error.message, 'error'); } } @@ -1146,9 +1213,10 @@ const trainMethod = formData.get('train_method') || 'lora'; const methodMap = { 'lora': 'lora', 'full': 'full' }; - // 获取输出模型名称 - const outputModelName = formData.get('output_model_name') || `${template}/${trainMethod}`; - const outputDir = outputModelName.startsWith('./') ? outputModelName : `./saves/${outputModelName}`; + // 获取输出模型名称(使用任务名称) + const taskName = formData.get('name') || 'task_name'; + const outputModelName = taskName; + const outputDir = outputModelName.startsWith('/') ? outputModelName : `/app/base/saves/${outputModelName}`; // 获取数据集名称 const trainDatasetSelect = form.querySelector('select[name="train_dataset_id"]'); @@ -1167,7 +1235,7 @@ const nEpochs = parseFloat(formData.get('n_epochs')) || 1.0; const maxLength = parseInt(formData.get('max_length')) || 512; const warmupSteps = parseInt(formData.get('warmup_steps')) || 20; - const evalSteps = parseInt(formData.get('eval_steps')) || 100; + const saveSteps = parseInt(formData.get('save_steps')) || 100; const gradientAccumulationSteps = parseInt(formData.get('gradient_accumulation_steps')) || 8; const lrSchedulerType = formData.get('lr_scheduler_type') || 'cosine'; @@ -1204,10 +1272,10 @@ cmd += ` --lr_scheduler_type ${lrSchedulerType} \\\n`; cmd += ` --logging_steps 50 \\\n`; cmd += ` --warmup_steps ${warmupSteps} \\\n`; - cmd += ` --save_steps 100 \\\n`; - cmd += ` --eval_steps ${evalSteps} \\\n`; + cmd += ` --save_steps ${saveSteps} \\\n`; cmd += ` --learning_rate ${learningRate} \\\n`; - cmd += ` --num_train_epochs ${nEpochs}`; + cmd += ` --num_train_epochs ${nEpochs} \\\n`; + cmd += ` --plot_loss`; return cmd; } diff --git a/web/pages/main.html b/web/pages/main.html index 84f0a75..0907d9c 100644 --- a/web/pages/main.html +++ b/web/pages/main.html @@ -260,6 +260,11 @@
+ + @@ -304,28 +309,6 @@
+ + + + + + + +
+
+
+ 训练日志 +
+
+ +
+
+
+ + +
+
+

加载中...

+ 加载中 +
+
+
+
基础模型
+
-
+
+
+
数据集
+
-
+
+
+
创建时间
+
-
+
+
+
进程ID
+
-
+
+
+
最后更新
+
-
+
+
+
+ + +
+

训练曲线

+
+
+ +
+
+ +
+
+ +
+
+
+ + +
+
+

实时日志

+
+ +
+
+
+
+
+ 加载日志中... +
+
+
+ + + +