From a560d24e2f400ff5f49b249433a53d2cedcfb496 Mon Sep 17 00:00:00 2001 From: "WIN-JHFT4D3SIVT\\caoxiaozhu" Date: Wed, 28 Jan 2026 10:31:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=BE=AE=E8=B0=83=E5=B7=B2?= =?UTF-8?q?=E7=BB=8F=E8=B0=83=E9=80=9A=20=E5=A2=9E=E5=8A=A0=E4=BA=86?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E9=A2=84=E8=A7=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/__init__.py | 2 + src/api/datasets.py | 101 ++++++- src/api/fine_tune.py | 443 +++++++++++++++++++++++++++++ src/api/model_manage.py | 6 +- src/main.py | 60 +++- web/pages/fine-tune-create.html | 376 +++++++++++++++++++----- web/pages/main.html | 4 +- web/pages/model-manage-create.html | 2 +- 8 files changed, 898 insertions(+), 96 deletions(-) create mode 100644 src/api/fine_tune.py diff --git a/src/api/__init__.py b/src/api/__init__.py index 8e90cb1..851ad6b 100644 --- a/src/api/__init__.py +++ b/src/api/__init__.py @@ -6,6 +6,7 @@ from .model_manage import model_manage_bp from .model_chat import model_chat_bp from .dimension import dimension_bp from .logs import logs_bp +from .fine_tune import fine_tune_bp # 注册所有蓝图 def register_blueprints(app): @@ -15,3 +16,4 @@ def register_blueprints(app): app.register_blueprint(model_chat_bp) app.register_blueprint(dimension_bp) app.register_blueprint(logs_bp) + app.register_blueprint(fine_tune_bp) diff --git a/src/api/datasets.py b/src/api/datasets.py index 6a8e714..c954fd0 100644 --- a/src/api/datasets.py +++ b/src/api/datasets.py @@ -5,6 +5,7 @@ import io import os import time import zipfile +import json from flask import Blueprint, request, jsonify, send_from_directory, Response from werkzeug.utils import secure_filename @@ -52,6 +53,45 @@ def allowed_file(filename): return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS +def update_dataset_info_json(dataset_name=None, actual_filename=None, remove_filename=None): + """更新 datasets/dataset_info.json 文件 + + Args: + dataset_name: 数据集名称(用于作为 key) + actual_filename: 实际保存的文件名(含时间戳前缀),用于 file_name + remove_filename: 要移除的文件名,为None表示不移除 + """ + info_path = os.path.join(DATASET_FOLDER, 'dataset_info.json') + + # 读取现有配置 + dataset_info = {} + if os.path.exists(info_path): + try: + with open(info_path, 'r', encoding='utf-8') as f: + dataset_info = json.load(f) + except Exception as e: + print(f"读取 dataset_info.json 失败: {e}") + + # 移除旧条目(根据移除的文件名) + if remove_filename: + key = os.path.splitext(remove_filename)[0] + if key in dataset_info: + del dataset_info[key] + print(f"从 dataset_info.json 移除: {key}") + + # 添加新条目 + if dataset_name and actual_filename: + dataset_info[dataset_name] = {"file_name": actual_filename} + print(f"更新 dataset_info.json: {dataset_name} -> {actual_filename}") + + # 写入文件 + try: + with open(info_path, 'w', encoding='utf-8') as f: + json.dump(dataset_info, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"写入 dataset_info.json 失败: {e}") + + def generic_get_by_id(table_name, id_val): """通用按ID查询""" conn = get_db_connection() @@ -149,17 +189,40 @@ def delete_dataset(id): """删除数据集""" conn = get_db_connection() cursor = conn.cursor() - # 获取文件路径列表 - cursor.execute("SELECT file_path FROM dataset_files WHERE dataset_id = %s", (id,)) + + # 获取数据集名称(用于从 dataset_info.json 移除条目) + cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (id,)) + dataset_result = cursor.fetchone() + dataset_name = dataset_result['name'] if dataset_result else None + + # 获取文件信息列表(包含原始文件名) + cursor.execute("SELECT file_name, file_path FROM dataset_files WHERE dataset_id = %s", (id,)) files = cursor.fetchall() - # 删除文件 + for f in files: file_path = f.get('file_path') - if file_path and os.path.exists(file_path): - try: - os.remove(file_path) - except Exception as e: - print(f"删除文件失败: {file_path}, {e}") + # 尝试多个可能的路径 + paths_to_try = [] + if file_path: + paths_to_try.append(file_path) + # 尝试 PROJECT_ROOT 相对路径 + rel_path = file_path.replace('/app/base', PROJECT_ROOT, 1) if file_path.startswith('/app/base') else None + if rel_path: + paths_to_try.append(rel_path) + + for path in paths_to_try: + if path and os.path.exists(path): + try: + os.remove(path) + print(f"已删除文件: {path}") + break + except Exception as e: + print(f"删除文件失败: {path}, {e}") + + # 使用数据集名称从 dataset_info.json 移除条目 + if dataset_name: + update_dataset_info_json(remove_filename=dataset_name) + # 删除数据库记录 cursor.execute("DELETE FROM dataset_files WHERE dataset_id = %s", (id,)) cursor.execute("DELETE FROM dataset_manage WHERE id = %s", (id,)) @@ -218,6 +281,12 @@ def upload_dataset_file(dataset_id): file.save(file_path) file_size = os.path.getsize(file_path) + # 获取数据集名称用于作为 dataset_info.json 的 key + dataset_name = dataset.get('name') if dataset else None + + # 更新 dataset_info.json,使用数据集名称作为 key,实际保存的文件名作为 file_name + update_dataset_info_json(dataset_name=dataset_name, actual_filename=new_filename) + # 获取文件扩展名(安全处理无扩展名的情况) parts = filename.rsplit('.', 1) ext = parts[1].lower() if len(parts) > 1 else 'unknown' @@ -304,8 +373,8 @@ def delete_dataset_file(file_id): conn = get_db_connection() cursor = conn.cursor() - # 获取文件信息 - cursor.execute("SELECT dataset_id, file_path FROM dataset_files WHERE id = %s", (file_id,)) + # 获取文件信息(包含 dataset_id) + cursor.execute("SELECT dataset_id, file_name, file_path FROM dataset_files WHERE id = %s", (file_id,)) file_info = cursor.fetchone() if not file_info: @@ -313,6 +382,13 @@ def delete_dataset_file(file_id): conn.close() return jsonify({'code': 1, 'message': '文件不存在'}) + dataset_id = file_info['dataset_id'] + + # 获取数据集名称(用于从 dataset_info.json 移除条目) + cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (dataset_id,)) + dataset_result = cursor.fetchone() + dataset_name = dataset_result['name'] if dataset_result else None + # 删除物理文件 file_path = file_info['file_path'] if file_path and os.path.exists(file_path): @@ -324,8 +400,11 @@ def delete_dataset_file(file_id): # 删除数据库记录 cursor.execute("DELETE FROM dataset_files WHERE id = %s", (file_id,)) + # 使用数据集名称从 dataset_info.json 移除条目 + if dataset_name: + update_dataset_info_json(remove_filename=dataset_name) + # 更新数据集的文件数量和大小 - dataset_id = file_info['dataset_id'] cursor.execute("SELECT COUNT(*) as count, SUM(file_size) as total_size FROM dataset_files WHERE dataset_id = %s", (dataset_id,)) result = cursor.fetchone() file_count = result['count'] or 0 diff --git a/src/api/fine_tune.py b/src/api/fine_tune.py new file mode 100644 index 0000000..1195d33 --- /dev/null +++ b/src/api/fine_tune.py @@ -0,0 +1,443 @@ +""" +精调训练 API 路由 +调用 llamafactory-cli 执行训练任务 +""" +import os +import subprocess +import json +import threading +import time +from flask import Blueprint, request, jsonify +import logging + +logger = logging.getLogger(__name__) +train_logger = logging.getLogger('train') # 专门的训练日志 logger,输出到 train.log + +# 创建蓝图 +fine_tune_bp = Blueprint('fine_tune', __name__, url_prefix='/api/fine-tune') + + +# 训练类型映射 +TRAIN_TYPE_MAP = { + 'SFT': 'sft', + 'DPO': 'dpo', + 'CPT': 'cpt' +} + +# 训练方法映射 +FINETUNING_TYPE_MAP = { + 'lora': 'lora', + 'full': 'full' +} + + +@fine_tune_bp.route('/start', methods=['POST']) +def start_training(): + """启动 llamafactory 训练任务""" + try: + data = request.json + train_logger.info(f"[TRAIN] ========== 开始训练任务 ==========") + train_logger.info(f"[TRAIN] 收到启动训练请求: base_model={data.get('base_model')}, train_dataset_id={data.get('train_dataset_id')}") + + # 必填参数验证 + required_fields = ['base_model', 'template', 'train_dataset_id'] + for field in required_fields: + if not data.get(field): + return jsonify({'code': 1, 'message': f'缺少必要参数: {field}'}) + + # 获取模型信息 + model_path = data.get('base_model') + # 尝试转换为整数 + try: + model_id = int(model_path) if str(model_path).isdigit() else None + except (ValueError, TypeError): + model_id = None + + if model_id: + # 如果是 model_id,需要获取模型路径 + from .model_manage import get_db_connection + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT id, name, path FROM model_manage WHERE id = %s", (model_id,)) + result = cursor.fetchone() + conn.close() + logger.info(f"模型查询结果: {result}") + if result and result.get('path'): + model_path = result['path'] + logger.info(f"从数据库获取的模型路径: {model_path}") + else: + return jsonify({'code': 1, 'message': '模型不存在或路径为空'}) + elif not model_path: + return jsonify({'code': 1, 'message': f'模型路径为空'}) + + train_logger.info(f"[TRAIN] 模型路径: {model_path}") + + # 设置工作目录为 llamafactory 目录 + llamafactory_dir = '/app/src/llamafactory' + + # 处理数据集文件:将数据集复制到 llamafactory 的 datasets 目录 + dataset_id = data.get('train_dataset_id') + try: + dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None + except (ValueError, TypeError): + dataset_id_int = None + + llamafactory_datasets_dir = os.path.join(llamafactory_dir, 'datasets') + os.makedirs(llamafactory_datasets_dir, exist_ok=True) + + # 获取数据集名称(用于 --dataset 参数) + dataset_key = None + if dataset_id_int: + from .datasets import get_db_connection as get_dataset_conn + conn = get_dataset_conn() + cursor = conn.cursor() + cursor.execute("SELECT dm.name FROM dataset_manage dm WHERE dm.id = %s", (dataset_id_int,)) + dataset_result = cursor.fetchone() + conn.close() + + dataset_key = dataset_result['name'] if dataset_result else None + + if dataset_key: + # 从 dataset_info.json 读取实际文件名 + src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json') + actual_file_name = None + if os.path.exists(src_info_json): + import json as json_lib + with open(src_info_json, 'r', encoding='utf-8') as f: + dataset_info = json_lib.load(f) + if dataset_key in dataset_info: + actual_file_name = dataset_info[dataset_key].get('file_name') + train_logger.info(f"[TRAIN] 从 dataset_info.json 获取文件名: {dataset_key} -> {actual_file_name}") + + # 复制数据集文件到 llamafactory 目录 + if actual_file_name: + src_file = os.path.join('/app/base', 'datasets', actual_file_name) + dst_file = os.path.join(llamafactory_datasets_dir, actual_file_name) + if os.path.exists(src_file): + import shutil + shutil.copy2(src_file, dst_file) + train_logger.info(f"[TRAIN] 复制数据集文件: {src_file} -> {dst_file}") + else: + train_logger.warning(f"[TRAIN] 数据集文件不存在: {src_file}") + + # 复制 dataset_info.json 到 llamafactory datasets 目录 + src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json') + dst_info_json = os.path.join(llamafactory_datasets_dir, 'dataset_info.json') + try: + if os.path.exists(src_info_json): + shutil.copy2(src_info_json, dst_info_json) + train_logger.info(f"[TRAIN] 已复制 dataset_info.json 到 llamafactory 目录") + else: + train_logger.warning(f"[TRAIN] dataset_info.json 不存在: {src_info_json}") + except Exception as e: + train_logger.warning(f"[TRAIN] 复制 dataset_info.json 失败: {e}") + + # 获取选中的 GPU 索引 + gpus = data.get('gpus', []) + if gpus: + gpu_ids = [gpu.get('id', '').replace('gpu', '') for gpu in gpus] + gpu_ids = [g for g in gpu_ids if g.isdigit()] + cuda_devices = ','.join(gpu_ids) + else: + cuda_devices = '0' + + # 设置环境变量 + env = os.environ.copy() + env['CUDA_VISIBLE_DEVICES'] = cuda_devices + env['TF_CPP_MIN_LOG_LEVEL'] = '2' # 减少 TensorFlow 日志 + + # 构建 llamafactory-cli 命令(传入数据集名称用于 --dataset 参数) + cmd = build_train_command(data, model_path, dataset_key) + cmd_str = ' '.join(cmd) + train_logger.info(f"[TRAIN] 执行训练命令: {cmd_str}") + + # 在返回的命令中显示 GPU 配置 + cmd_str_with_gpu = f"CUDA_VISIBLE_DEVICES={cuda_devices} {cmd_str}" + + # 生成训练日志文件路径(按日期分目录) + from datetime import datetime + today = datetime.now().strftime('%Y-%m-%d') + task_id_str = str(data.get('task_id', 'unknown')) + log_dir = os.path.join(llamafactory_dir, 'logs', today) + train_output_log = os.path.join(log_dir, f'train_{task_id_str}.log') + + # 确保日志目录存在 + os.makedirs(log_dir, exist_ok=True) + + train_logger.info(f"[TRAIN] 启动训练进程...") + + # 使用线程在后台运行训练进程 + def run_training(): + with open(train_output_log, 'w', encoding='utf-8') as log_file: + process = subprocess.Popen( + cmd, + cwd=llamafactory_dir, + stdout=log_file, + stderr=subprocess.STDOUT, + env=env + ) + train_logger.info(f"[TRAIN] 训练进程 PID: {process.pid}") + # 等待进程完成 + process.wait() + train_logger.info(f"[TRAIN] 训练进程已结束,退出码: {process.returncode}") + + # 更新任务状态 + final_status = 'completed' if process.returncode == 0 else 'failed' + update_fine_tune_status(data.get('task_id'), final_status, process.pid) + + # 启动后台线程 + training_thread = threading.Thread(target=run_training, daemon=True) + training_thread.start() + + # 立即返回,不等待进程完成 + pid = None # 此时还不知道实际 PID,稍后可从日志获取 + train_logger.info(f"[TRAIN] 训练任务已在后台启动") + train_logger.info(f"[TRAIN] 训练日志输出到: {train_output_log}") + + # 更新任务状态为运行中 + update_fine_tune_status(data.get('task_id'), 'running', 0) + + return jsonify({ + 'code': 0, + 'message': f'训练任务已启动 (GPU: {cuda_devices})', + 'data': { + 'task_id': data.get('task_id'), + 'gpu_ids': cuda_devices, + 'command': cmd_str_with_gpu, + 'log_file': train_output_log + } + }) + + except Exception as e: + import traceback + train_logger.error(f"[TRAIN] 启动训练任务失败: {e}") + train_logger.error(f"[TRAIN] 详细错误: {traceback.format_exc()}") + return jsonify({'code': 1, 'message': str(e)}) + + +def build_train_command(data, model_path, dataset_name=None): + """构建 llamafactory-cli train 命令""" + # llamafactory-cli 路径(已在系统 PATH 中) + cmd = ['llamafactory-cli', 'train'] + + # 训练阶段 + train_type = data.get('train_type', 'SFT') + cmd.extend(['--stage', TRAIN_TYPE_MAP.get(train_type, 'sft')]) + cmd.append('--do_train') + + # 模型路径 + cmd.extend(['--model_name_or_path', model_path]) + + # 数据集 - 使用数据集名称(dataset_manage.name),不是实际文件名 + if dataset_name: + cmd.extend(['--dataset', dataset_name]) + train_logger.info(f"[TRAIN] 使用数据集名称: {dataset_name}") + else: + # 回退到原有逻辑 + dataset_id = data.get('train_dataset_id') + try: + dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None + except (ValueError, TypeError): + dataset_id_int = None + + if dataset_id_int: + dataset_name = get_dataset_name(dataset_id_int) + train_logger.info(f"[TRAIN] 从数据库获取的数据集名称: {dataset_name}") + else: + dataset_name = dataset_id + cmd.extend(['--dataset', dataset_name]) + + # 数据集目录 + cmd.extend(['--dataset_dir', './datasets']) # llamafactory 工作目录下的 datasets 目录 + + # 模板 + template = data.get('template') + cmd.extend(['--template', template]) + + # 训练方法 + train_method = data.get('train_method', 'lora') + cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')]) + + # 输出目录 + output_dir = data.get('output_model_name', f"./saves/{template}/{train_method}") + if not output_dir.startswith('./'): + output_dir = f"./saves/{output_dir}" + cmd.extend(['--output_dir', output_dir]) + + # 常用参数 + cmd.extend([ + '--overwrite_cache', + '--overwrite_output_dir', + '--cutoff_len', str(data.get('max_length', 512)), + '--preprocessing_num_workers', '16', + '--per_device_train_batch_size', str(data.get('batch_size', 1)), + '--per_device_eval_batch_size', '1', + '--gradient_accumulation_steps', str(data.get('gradient_accumulation_steps', 8)), + '--lr_scheduler_type', data.get('lr_scheduler_type', 'cosine'), + '--logging_steps', '50', + '--warmup_steps', str(data.get('warmup_steps', 20)), + '--save_steps', '100', + '--eval_steps', str(data.get('eval_steps', 100)), + ]) + + # 学习率 + cmd.extend(['--learning_rate', str(data.get('learning_rate', 0.0001))]) + + # 训练轮数 + cmd.extend(['--num_train_epochs', str(data.get('n_epochs', 1.0))]) + + # 验证集比例 + val_ratio = data.get('valid_ratio', 0) + if val_ratio > 0: + cmd.extend(['--val_size', str(val_ratio / 100)]) + + # 最大样本数 + if data.get('max_samples'): + cmd.extend(['--max_samples', str(data.get('max_samples'))]) + + # 其他选项 + if data.get('plot_loss'): + cmd.append('--plot_loss') + + if data.get('fp16'): + cmd.append('--fp16') + + if data.get('load_best_model_at_end'): + cmd.append('--load_best_model_at_end') + + return cmd + + +def get_dataset_name(dataset_id): + """根据数据集 ID 获取数据集名称""" + try: + from .datasets import get_db_connection + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT id, name FROM dataset_manage WHERE id = %s", (dataset_id,)) + result = cursor.fetchone() + conn.close() + logger.info(f"数据集查询结果: {result}") + if result and result.get('name'): + return result['name'] + logger.warning(f"未找到数据集 ID={dataset_id},使用默认值") + return 'default' + except Exception as e: + logger.error(f"查询数据集失败: {e}") + return 'default' + + +def update_fine_tune_status(task_id, status, pid=None): + """更新训练任务状态""" + try: + from .model_manage import get_db_connection + conn = get_db_connection() + cursor = conn.cursor() + + if status == 'running' and pid: + cursor.execute( + "UPDATE fine_tune SET status = %s, process_id = %s WHERE id = %s", + (status, pid, task_id) + ) + else: + cursor.execute( + "UPDATE fine_tune SET status = %s WHERE id = %s", + (status, task_id) + ) + + conn.commit() + conn.close() + except Exception as e: + logger.error(f"更新任务状态失败: {e}") + + +@fine_tune_bp.route('/stop/', methods=['POST']) +def stop_training(task_id): + """停止训练任务""" + try: + from .model_manage import get_db_connection + + # 获取进程 ID + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT process_id FROM fine_tune WHERE id = %s", (task_id,)) + result = cursor.fetchone() + conn.close() + + if result and result.get('process_id'): + pid = result['process_id'] + try: + # 尝试终止进程 + import signal + os.kill(pid, signal.SIGTERM) + logger.info(f"已终止训练进程 PID: {pid}") + except ProcessLookupError: + logger.warning(f"进程 {pid} 不存在") + except PermissionError: + logger.error(f"没有权限终止进程 {pid}") + + # 更新状态 + update_fine_tune_status(task_id, 'stopped') + + return jsonify({'code': 0, 'message': '训练任务已停止'}) + + except Exception as e: + logger.error(f"停止训练任务失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) + + +@fine_tune_bp.route('/status/', methods=['GET']) +def get_training_status(task_id): + """获取训练任务状态""" + try: + from .model_manage import get_db_connection + + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute( + "SELECT id, name, status, progress, process_id FROM fine_tune WHERE id = %s", + (task_id,) + ) + result = cursor.fetchone() + conn.close() + + if result: + return jsonify({ + 'code': 0, + 'data': { + 'task_id': result['id'], + 'name': result['name'], + 'status': result['status'], + 'progress': result['progress'], + 'pid': result.get('process_id') + } + }) + else: + return jsonify({'code': 1, 'message': '任务不存在'}) + + except Exception as e: + logger.error(f"获取任务状态失败: {e}") + return jsonify({'code': 1, 'message': str(e)}) + + +def get_db_connection(): + """获取数据库连接""" + import pymysql + import yaml + + PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') + + with open(CONFIG_PATH, 'r', encoding='utf-8') as f: + CONFIG = yaml.safe_load(f) + + db_config = CONFIG['database'] + return pymysql.connect( + host=db_config['host'], + port=db_config['port'], + user=db_config['username'], + password=db_config['password'], + database=db_config['name'], + charset=db_config.get('charset', 'utf8mb4'), + cursorclass=pymysql.cursors.DictCursor + ) diff --git a/src/api/model_manage.py b/src/api/model_manage.py index f1ad21a..7dda27f 100644 --- a/src/api/model_manage.py +++ b/src/api/model_manage.py @@ -6,8 +6,12 @@ import pymysql import yaml from flask import Blueprint, request, jsonify -# 获取项目根目录 +# 获取项目根目录 - 优先使用环境变量,否则从文件路径计算 +MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base') PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# 如果 PROJECT_ROOT 是 /app 或 /app/src/llamafactory,则使用挂载路径 +if PROJECT_ROOT in ('/app', '/app/src/llamafactory'): + PROJECT_ROOT = MOUNT_BASE # 创建蓝图 model_manage_bp = Blueprint('model_manage', __name__, url_prefix='/api/model-manage') diff --git a/src/main.py b/src/main.py index cbd0394..f451abf 100644 --- a/src/main.py +++ b/src/main.py @@ -86,6 +86,15 @@ def setup_logger(name='app'): datefmt='%H:%M:%S' )) + # 5. 训练日志处理器 - 专门记录训练输出 + train_log_path = os.path.join(log_dir, 'train.log') + train_handler = RotatingFileHandler(train_log_path, maxBytes=100*1024*1024, backupCount=5, encoding='utf-8') + train_handler.setLevel(logging.INFO) + train_handler.setFormatter(logging.Formatter( + '[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + )) + # 添加处理器到 logger logger.addHandler(all_handler) logger.addHandler(error_handler) @@ -98,6 +107,13 @@ def setup_logger(name='app'): request_logger.addHandler(request_handler) request_logger.addHandler(console_handler) + # 为训练日志创建单独的 logger + train_logger = logging.getLogger('train') + train_logger.setLevel(logging.INFO) + train_logger.handlers.clear() + train_logger.addHandler(train_handler) + train_logger.addHandler(console_handler) + return logger @@ -137,6 +153,7 @@ def init_database(): id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, base_model VARCHAR(255), + template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等', train_type VARCHAR(50), train_method VARCHAR(50), gpus JSON COMMENT 'GPU硬件选择,支持多卡训练', @@ -144,6 +161,7 @@ def init_database(): valid_split VARCHAR(50), valid_ratio INT DEFAULT 10, output_model_name VARCHAR(255), + process_id INT COMMENT '训练进程ID', status VARCHAR(50) DEFAULT 'pending', progress INT DEFAULT 0, create_time DATETIME DEFAULT CURRENT_TIMESTAMP, @@ -305,6 +323,44 @@ def init_database(): except Exception: pass # 列已存在时不输出任何信息 + # 为 fine_tune 表添加 template 列 + try: + cursor.execute("ALTER TABLE fine_tune ADD COLUMN template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等'") + logger.debug("fine_tune 表添加 template 列成功") + except Exception: + pass # 列已存在时不输出任何信息 + + # 为 fine_tune 表添加 process_id 列 + try: + cursor.execute("ALTER TABLE fine_tune ADD COLUMN process_id INT COMMENT '训练进程ID'") + logger.debug("fine_tune 表添加 process_id 列成功") + except Exception: + pass # 列已存在时不输出任何信息 + + # 为 fine_tune 表添加训练相关列 + columns_to_add = [ + ("train_dataset_id", "INT COMMENT '训练数据集ID'"), + ("valid_dataset_id", "INT COMMENT '验证数据集ID'"), + ("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"), + ("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"), + ("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"), + ("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"), + ("batch_size", "INT DEFAULT 1 COMMENT '批次大小'"), + ("learning_rate", "FLOAT DEFAULT 0.0001 COMMENT '学习率'"), + ("n_epochs", "FLOAT DEFAULT 1.0 COMMENT '训练轮数'"), + ("max_length", "INT DEFAULT 512 COMMENT '最大长度'"), + ("lora_alpha", "VARCHAR(10) DEFAULT '32' COMMENT 'LoRA alpha'"), + ("lora_rank", "VARCHAR(10) DEFAULT '8' COMMENT 'LoRA rank'"), + ("lora_dropout", "FLOAT DEFAULT 0.1 COMMENT 'LoRA dropout'"), + ("valid_ratio", "INT DEFAULT 10 COMMENT '验证集比例'"), + ] + for col_name, col_def in columns_to_add: + try: + cursor.execute(f"ALTER TABLE fine_tune ADD COLUMN {col_name} {col_def}") + logger.debug(f"fine_tune 表添加 {col_name} 列成功") + except Exception: + pass # 列已存在时不输出任何信息 + # 插入默认管理员用户 cursor.execute("SELECT * FROM users WHERE username = 'admin'") if not cursor.fetchone(): @@ -323,8 +379,8 @@ def init_database(): app = Flask(__name__) app.config['SECRET_KEY'] = CONFIG['secret_key'] app.config['CORS_HEADERS'] = 'Content-Type' -# 使用字符串形式的 origins -CORS(app, origins="*", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"], supports_credentials=False) +# 允许所有来源 +CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"]) # 注册蓝图 register_blueprints(app) diff --git a/web/pages/fine-tune-create.html b/web/pages/fine-tune-create.html index d6f6361..79920fc 100644 --- a/web/pages/fine-tune-create.html +++ b/web/pages/fine-tune-create.html @@ -293,6 +293,92 @@ + +
+ + +

选择与您的模型匹配的对话模板,确保训练数据格式正确

+
+
@@ -512,7 +598,7 @@ *训练集
-
- -
- -
- - -
-
- 从当前训练集随机分割 - - % 作为验证集 -
- -
@@ -571,14 +625,19 @@ - -
-
- 模型加密 - 安全升级 + +
+
+ 训练命令预览 + +
+
+
请选择完整配置后查看预览命令
-

为保障您的数据安全,平台会为导出的模型文件开启 OSS 服务端加密

+
@@ -591,11 +650,6 @@ 取消
- @@ -648,6 +702,9 @@ // 加载GPU列表 loadGPUList(); + // 初始化训练命令预览 + initCommandPreview(); + // 设置侧边栏当前页高亮 const currentPage = 'fine-tune'; document.querySelectorAll('.nav-link').forEach(link => { @@ -683,20 +740,6 @@ } } - // 切换验证集切分方式 - function toggleValidSplit() { - const validSplit = document.querySelector('input[name="valid_split"]:checked').value; - const autoSection = document.getElementById('autoSplitSection'); - const customSection = document.getElementById('customSplitSection'); - if (validSplit === 'auto') { - autoSection.classList.remove('hidden'); - customSection.classList.add('hidden'); - } else { - autoSection.classList.add('hidden'); - customSection.classList.remove('hidden'); - } - } - // 切换训练方法 - 显示/隐藏LoRA参数 function toggleTrainMethod() { const trainMethod = document.querySelector('input[name="train_method"]:checked').value; @@ -782,12 +825,6 @@ trainSelect.innerHTML = '' + result.data.map(d => ``).join(''); } - // 更新验证集下拉框 - const validSelect = document.getElementById('validDatasetSelect'); - if (validSelect) { - validSelect.innerHTML = '' + - result.data.map(d => ``).join(''); - } } } catch (e) { console.error('加载数据集失败:', e); @@ -968,22 +1005,35 @@ async function submitForm() { const form = document.getElementById('createForm'); const formData = new FormData(form); - const validSplit = formData.get('valid_split'); // 获取选中的GPU const selectedGPUs = getSelectedGPUs(); + // 收集训练参数 + const trainParams = { + batch_size: parseInt(formData.get('batch_size')) || 1, + learning_rate: parseFloat(formData.get('learning_rate')) || 0.0001, + n_epochs: parseFloat(formData.get('n_epochs')) || 1.0, + eval_steps: parseInt(formData.get('eval_steps')) || 100, + lr_scheduler_type: formData.get('lr_scheduler_type') || 'cosine', + max_length: parseInt(formData.get('max_length')) || 512, + warmup_ratio: parseFloat(formData.get('warmup_ratio')) || 0.05, + weight_decay: parseFloat(formData.get('weight_decay')) || 0.01, + lora_alpha: formData.get('lora_alpha') || '32', + lora_dropout: parseFloat(formData.get('lora_dropout')) || 0.1, + lora_rank: formData.get('lora_rank') || '8' + }; + const data = { name: formData.get('name'), base_model: formData.get('base_model'), + template: formData.get('template'), train_type: formData.get('train_type'), train_method: formData.get('train_method'), - gpus: selectedGPUs, // 添加GPU选择 + gpus: selectedGPUs, train_dataset_id: formData.get('train_dataset_id'), - valid_split: validSplit, - valid_ratio: parseInt(formData.get('valid_ratio')) || 10, - valid_dataset_id: validSplit === 'custom' ? formData.get('valid_dataset_id') : null, output_model_name: formData.get('output_model_name'), + ...trainParams, status: 'pending', progress: 0 }; @@ -1000,33 +1050,201 @@ showMessage('提示', '请选择基础模型', 'warning'); return; } + if (!data.template) { + showMessage('提示', '请选择训练模板', 'warning'); + return; + } if (!data.train_dataset_id) { showMessage('提示', '请选择训练集', 'warning'); return; } - if (validSplit === 'custom' && !data.valid_dataset_id) { - showMessage('提示', '请选择验证集', 'warning'); - return; - } try { - const response = await fetch(`${API_BASE}/fine-tune`, { + // 第一步:创建训练任务记录 + const createResponse = await fetch(`${API_BASE}/fine-tune`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(data) }); - const result = await response.json(); - if (result.code === 0) { - showMessage('成功', '创建成功!', 'success', () => { + const createResult = await createResponse.json(); + if (createResult.code !== 0) { + showMessage('错误', createResult.message || '创建任务失败', 'error'); + return; + } + + const taskId = createResult.id; + + // 第二步:启动训练 + const startData = { + task_id: taskId, + base_model: data.base_model, + template: data.template, + train_type: data.train_type, + train_method: data.train_method, + train_dataset_id: data.train_dataset_id, + output_model_name: data.output_model_name, + ...trainParams + }; + + const startResponse = await fetch(`${API_BASE}/fine-tune/start`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(startData) + }); + const startResult = await startResponse.json(); + + if (startResult.code === 0) { + const cmd = startResult.data?.command || ''; + showMessage('成功', `训练任务已启动!

${cmd}`, 'success', () => { window.location.href = 'main.html'; }); } else { - showMessage('错误', result.message || '创建失败', 'error'); + // 更新任务状态为失败 + await fetch(`${API_BASE}/fine-tune/${taskId}`, { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ status: 'failed' }) + }); + showMessage('错误', startResult.message || '启动训练失败', 'error'); } } catch (error) { - showMessage('错误', '创建失败: ' + error.message, 'error'); + showMessage('错误', '操作失败: ' + error.message, 'error'); } } + + // 生成训练命令预览 + function buildCommandPreview() { + const form = document.getElementById('createForm'); + const formData = new FormData(form); + + // 获取选中的GPU + const selectedGPUs = getSelectedGPUs(); + let gpuIds = '0'; + if (selectedGPUs.length > 0) { + gpuIds = selectedGPUs.map(g => g.id.replace('gpu', '')).filter(g => /^\d+$/.test(g)).join(','); + } + + // 获取模型路径 + const baseModelSelect = form.querySelector('select[name="base_model"]'); + let modelPath = formData.get('base_model') || ''; + if (baseModelSelect && baseModelSelect.selectedOptions.length > 0) { + const selectedOption = baseModelSelect.selectedOptions[0]; + const pathValue = selectedOption.getAttribute('data-path'); + if (pathValue) { + modelPath = pathValue; + } + } + + // 获取模板 + const template = formData.get('template') || 'qwen3'; + + // 获取训练类型 + const trainType = formData.get('train_type') || 'SFT'; + const stageMap = { 'SFT': 'sft', 'DPO': 'dpo', 'CPT': 'cpt' }; + + // 获取训练方法 + const trainMethod = formData.get('train_method') || 'lora'; + const methodMap = { 'lora': 'lora', 'full': 'full' }; + + // 获取输出模型名称 + const outputModelName = formData.get('output_model_name') || `${template}/${trainMethod}`; + const outputDir = outputModelName.startsWith('./') ? outputModelName : `./saves/${outputModelName}`; + + // 获取数据集名称 + const trainDatasetSelect = form.querySelector('select[name="train_dataset_id"]'); + let datasetName = formData.get('train_dataset_id') || 'dataset_name'; + if (trainDatasetSelect && trainDatasetSelect.selectedOptions.length > 0) { + const selectedOption = trainDatasetSelect.selectedOptions[0]; + const datasetValue = selectedOption.getAttribute('data-name'); + if (datasetValue) { + datasetName = datasetValue; + } + } + + // 获取训练参数 + const batchSize = parseInt(formData.get('batch_size')) || 1; + const learningRate = parseFloat(formData.get('learning_rate')) || 0.0001; + const nEpochs = parseFloat(formData.get('n_epochs')) || 1.0; + const maxLength = parseInt(formData.get('max_length')) || 512; + const warmupSteps = parseInt(formData.get('warmup_steps')) || 20; + const evalSteps = parseInt(formData.get('eval_steps')) || 100; + const gradientAccumulationSteps = parseInt(formData.get('gradient_accumulation_steps')) || 8; + const lrSchedulerType = formData.get('lr_scheduler_type') || 'cosine'; + + // LoRA参数 + const loraAlpha = formData.get('lora_alpha') || '32'; + const loraDropout = parseFloat(formData.get('lora_dropout')) || 0.1; + const loraRank = formData.get('lora_rank') || '8'; + + // 构建命令 + let cmd = `CUDA_VISIBLE_DEVICES=${gpuIds} llamafactory-cli train \\\n`; + cmd += ` --stage ${stageMap[trainType] || 'sft'} \\\n`; + cmd += ` --do_train \\\n`; + cmd += ` --model_name_or_path ${modelPath} \\\n`; + cmd += ` --dataset ${datasetName} \\\n`; + cmd += ` --dataset_dir ./datasets \\\n`; + cmd += ` --template ${template} \\\n`; + cmd += ` --finetuning_type ${methodMap[trainMethod] || 'lora'} \\\n`; + + // LoRA参数(仅lora方法时显示) + if (trainMethod === 'lora') { + cmd += ` --lora_alpha ${loraAlpha} \\\n`; + cmd += ` --lora_dropout ${loraDropout} \\\n`; + cmd += ` --lora_rank ${loraRank} \\\n`; + } + + cmd += ` --output_dir ${outputDir} \\\n`; + cmd += ` --overwrite_cache \\\n`; + cmd += ` --overwrite_output_dir \\\n`; + cmd += ` --cutoff_len ${maxLength} \\\n`; + cmd += ` --preprocessing_num_workers 16 \\\n`; + cmd += ` --per_device_train_batch_size ${batchSize} \\\n`; + cmd += ` --per_device_eval_batch_size 1 \\\n`; + cmd += ` --gradient_accumulation_steps ${gradientAccumulationSteps} \\\n`; + cmd += ` --lr_scheduler_type ${lrSchedulerType} \\\n`; + cmd += ` --logging_steps 50 \\\n`; + cmd += ` --warmup_steps ${warmupSteps} \\\n`; + cmd += ` --save_steps 100 \\\n`; + cmd += ` --eval_steps ${evalSteps} \\\n`; + cmd += ` --learning_rate ${learningRate} \\\n`; + cmd += ` --num_train_epochs ${nEpochs}`; + + return cmd; + } + + // 更新命令预览 + function updateCommandPreview() { + const preview = document.getElementById('commandPreview'); + const cmd = buildCommandPreview(); + preview.textContent = cmd; + } + + // 监听表单变化自动更新预览 + function initCommandPreview() { + const form = document.getElementById('createForm'); + + // 监听所有 input 和 select 的变化 + const inputs = form.querySelectorAll('input, select'); + inputs.forEach(input => { + input.addEventListener('change', () => setTimeout(updateCommandPreview, 100)); + if (input.type === 'text' || input.type === 'number') { + input.addEventListener('input', () => setTimeout(updateCommandPreview, 100)); + } + }); + + // 监听卡片式单选框的点击事件 (训练类型、训练方法) + document.querySelectorAll('.card-radio').forEach(card => { + card.addEventListener('click', () => setTimeout(updateCommandPreview, 100)); + }); + + // 监听 GPU 卡片的点击事件 + document.querySelectorAll('.gpu-card').forEach(card => { + card.addEventListener('click', () => setTimeout(updateCommandPreview, 100)); + }); + + // 初始化时更新一次 + setTimeout(updateCommandPreview, 500); + } diff --git a/web/pages/main.html b/web/pages/main.html index a889fb1..84f0a75 100644 --- a/web/pages/main.html +++ b/web/pages/main.html @@ -417,9 +417,9 @@ } } - // 页面加载时获取监控数据,并每5秒刷新 + // 页面加载时获取监控数据,并每30秒刷新 fetchSystemMetrics(); - setInterval(fetchSystemMetrics, 5000); + setInterval(fetchSystemMetrics, 30000); // 各功能模块的表格配置 const tableConfigs = { diff --git a/web/pages/model-manage-create.html b/web/pages/model-manage-create.html index 6171b3a..22de632 100644 --- a/web/pages/model-manage-create.html +++ b/web/pages/model-manage-create.html @@ -517,7 +517,7 @@ if (select.options.length > 1) return; try { - const response = await fetch(`${API_BASE}/local-models`); + const response = await fetch(`${API_BASE}/model-manage/local-models`); const result = await response.json(); if (result.code === 0 && result.data && result.data.models) {