模型微调已经调通
增加了参数预览
This commit is contained in:
@@ -6,6 +6,7 @@ from .model_manage import model_manage_bp
|
||||
from .model_chat import model_chat_bp
|
||||
from .dimension import dimension_bp
|
||||
from .logs import logs_bp
|
||||
from .fine_tune import fine_tune_bp
|
||||
|
||||
# 注册所有蓝图
|
||||
def register_blueprints(app):
|
||||
@@ -15,3 +16,4 @@ def register_blueprints(app):
|
||||
app.register_blueprint(model_chat_bp)
|
||||
app.register_blueprint(dimension_bp)
|
||||
app.register_blueprint(logs_bp)
|
||||
app.register_blueprint(fine_tune_bp)
|
||||
|
||||
@@ -5,6 +5,7 @@ import io
|
||||
import os
|
||||
import time
|
||||
import zipfile
|
||||
import json
|
||||
from flask import Blueprint, request, jsonify, send_from_directory, Response
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
@@ -52,6 +53,45 @@ def allowed_file(filename):
|
||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
|
||||
|
||||
def update_dataset_info_json(dataset_name=None, actual_filename=None, remove_filename=None):
    """Add and/or remove an entry in ``datasets/dataset_info.json``.

    The file is a JSON object mapping a dataset key to
    ``{"file_name": <stored file name>}``. Removal is applied before
    insertion, and the file is rewritten on every call.

    Args:
        dataset_name: Key to register; only used together with
            ``actual_filename``.
        actual_filename: File name as saved on disk (including the timestamp
            prefix); stored under ``file_name``.
        remove_filename: Key of the entry to drop, or a file name whose stem
            is the key. ``None`` (or empty) means nothing is removed.
    """
    info_path = os.path.join(DATASET_FOLDER, 'dataset_info.json')

    # Load the current mapping; any read/parse problem falls back to an empty one.
    dataset_info = {}
    if os.path.exists(info_path):
        try:
            with open(info_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            print(f"读取 dataset_info.json 失败: {e}")
        else:
            if isinstance(loaded, dict):
                dataset_info = loaded
            else:
                # A non-object root (list, string, ...) cannot hold entries and
                # would make the item assignment below fail.
                print("读取 dataset_info.json 失败: 顶层结构不是 JSON 对象")

    # Drop the old entry. Try the value exactly as given first: callers pass
    # the dataset name, and stripping an "extension" from a name such as
    # "corpus.v1" would look up the wrong key. The stem is kept as a fallback
    # for callers that pass a real file name.
    if remove_filename:
        candidates = (remove_filename, os.path.splitext(remove_filename)[0])
        for key in candidates:
            if key in dataset_info:
                del dataset_info[key]
                print(f"从 dataset_info.json 移除: {key}")
                break

    # Register the new entry.
    if dataset_name and actual_filename:
        dataset_info[dataset_name] = {"file_name": actual_filename}
        print(f"更新 dataset_info.json: {dataset_name} -> {actual_filename}")

    # Persist; a write failure is reported but not raised (best effort).
    try:
        with open(info_path, 'w', encoding='utf-8') as f:
            json.dump(dataset_info, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print(f"写入 dataset_info.json 失败: {e}")
|
||||
|
||||
|
||||
def generic_get_by_id(table_name, id_val):
|
||||
"""通用按ID查询"""
|
||||
conn = get_db_connection()
|
||||
@@ -149,17 +189,40 @@ def delete_dataset(id):
|
||||
"""删除数据集"""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
# 获取文件路径列表
|
||||
cursor.execute("SELECT file_path FROM dataset_files WHERE dataset_id = %s", (id,))
|
||||
|
||||
# 获取数据集名称(用于从 dataset_info.json 移除条目)
|
||||
cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (id,))
|
||||
dataset_result = cursor.fetchone()
|
||||
dataset_name = dataset_result['name'] if dataset_result else None
|
||||
|
||||
# 获取文件信息列表(包含原始文件名)
|
||||
cursor.execute("SELECT file_name, file_path FROM dataset_files WHERE dataset_id = %s", (id,))
|
||||
files = cursor.fetchall()
|
||||
# 删除文件
|
||||
|
||||
for f in files:
|
||||
file_path = f.get('file_path')
|
||||
if file_path and os.path.exists(file_path):
|
||||
# 尝试多个可能的路径
|
||||
paths_to_try = []
|
||||
if file_path:
|
||||
paths_to_try.append(file_path)
|
||||
# 尝试 PROJECT_ROOT 相对路径
|
||||
rel_path = file_path.replace('/app/base', PROJECT_ROOT, 1) if file_path.startswith('/app/base') else None
|
||||
if rel_path:
|
||||
paths_to_try.append(rel_path)
|
||||
|
||||
for path in paths_to_try:
|
||||
if path and os.path.exists(path):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
os.remove(path)
|
||||
print(f"已删除文件: {path}")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"删除文件失败: {file_path}, {e}")
|
||||
print(f"删除文件失败: {path}, {e}")
|
||||
|
||||
# 使用数据集名称从 dataset_info.json 移除条目
|
||||
if dataset_name:
|
||||
update_dataset_info_json(remove_filename=dataset_name)
|
||||
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM dataset_files WHERE dataset_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM dataset_manage WHERE id = %s", (id,))
|
||||
@@ -218,6 +281,12 @@ def upload_dataset_file(dataset_id):
|
||||
file.save(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
# 获取数据集名称用于作为 dataset_info.json 的 key
|
||||
dataset_name = dataset.get('name') if dataset else None
|
||||
|
||||
# 更新 dataset_info.json,使用数据集名称作为 key,实际保存的文件名作为 file_name
|
||||
update_dataset_info_json(dataset_name=dataset_name, actual_filename=new_filename)
|
||||
|
||||
# 获取文件扩展名(安全处理无扩展名的情况)
|
||||
parts = filename.rsplit('.', 1)
|
||||
ext = parts[1].lower() if len(parts) > 1 else 'unknown'
|
||||
@@ -304,8 +373,8 @@ def delete_dataset_file(file_id):
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取文件信息
|
||||
cursor.execute("SELECT dataset_id, file_path FROM dataset_files WHERE id = %s", (file_id,))
|
||||
# 获取文件信息(包含 dataset_id)
|
||||
cursor.execute("SELECT dataset_id, file_name, file_path FROM dataset_files WHERE id = %s", (file_id,))
|
||||
file_info = cursor.fetchone()
|
||||
|
||||
if not file_info:
|
||||
@@ -313,6 +382,13 @@ def delete_dataset_file(file_id):
|
||||
conn.close()
|
||||
return jsonify({'code': 1, 'message': '文件不存在'})
|
||||
|
||||
dataset_id = file_info['dataset_id']
|
||||
|
||||
# 获取数据集名称(用于从 dataset_info.json 移除条目)
|
||||
cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (dataset_id,))
|
||||
dataset_result = cursor.fetchone()
|
||||
dataset_name = dataset_result['name'] if dataset_result else None
|
||||
|
||||
# 删除物理文件
|
||||
file_path = file_info['file_path']
|
||||
if file_path and os.path.exists(file_path):
|
||||
@@ -324,8 +400,11 @@ def delete_dataset_file(file_id):
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM dataset_files WHERE id = %s", (file_id,))
|
||||
|
||||
# 使用数据集名称从 dataset_info.json 移除条目
|
||||
if dataset_name:
|
||||
update_dataset_info_json(remove_filename=dataset_name)
|
||||
|
||||
# 更新数据集的文件数量和大小
|
||||
dataset_id = file_info['dataset_id']
|
||||
cursor.execute("SELECT COUNT(*) as count, SUM(file_size) as total_size FROM dataset_files WHERE dataset_id = %s", (dataset_id,))
|
||||
result = cursor.fetchone()
|
||||
file_count = result['count'] or 0
|
||||
|
||||
443
src/api/fine_tune.py
Normal file
443
src/api/fine_tune.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
精调训练 API 路由
|
||||
调用 llamafactory-cli 执行训练任务
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from flask import Blueprint, request, jsonify
|
||||
import logging
|
||||
|
||||
# Module logger plus the shared 'train' logger used for training output.
logger = logging.getLogger(__name__)
train_logger = logging.getLogger('train')

# Blueprint exposing the fine-tuning endpoints under /api/fine-tune.
fine_tune_bp = Blueprint('fine_tune', __name__, url_prefix='/api/fine-tune')

# Training type from the request -> llamafactory ``--stage`` value.
TRAIN_TYPE_MAP = dict(SFT='sft', DPO='dpo', CPT='cpt')

# Training method from the request -> llamafactory ``--finetuning_type`` value.
FINETUNING_TYPE_MAP = dict(lora='lora', full='full')
|
||||
|
||||
|
||||
def _parse_cuda_devices(gpus):
    """Build the CUDA_VISIBLE_DEVICES value from the request's ``gpus`` list.

    Entries are expected to look like ``{"id": "gpu0"}``; the ``gpu`` prefix
    is stripped and only purely numeric indices are kept. Returns ``'0'`` when
    nothing usable was selected, because an empty CUDA_VISIBLE_DEVICES would
    hide every GPU from the training process.
    """
    if not isinstance(gpus, (list, tuple)):
        return '0'
    gpu_ids = []
    for gpu in gpus:
        raw_id = gpu.get('id', '') if isinstance(gpu, dict) else gpu
        index = str(raw_id).replace('gpu', '')
        if index.isdigit():
            gpu_ids.append(index)
    return ','.join(gpu_ids) if gpu_ids else '0'


@fine_tune_bp.route('/start', methods=['POST'])
def start_training():
    """Start a llamafactory-cli training run in a background thread.

    Reads a JSON body that must contain ``base_model`` (a model id or a path),
    ``template`` and ``train_dataset_id``. The dataset file and
    ``dataset_info.json`` are copied into the llamafactory ``datasets``
    directory, the command is built with ``build_train_command`` and launched
    with ``subprocess.Popen``; its output goes to a per-task log file.

    Returns:
        A JSON response: ``code`` 0 with the task id, GPU ids, command line
        and log file path once the thread has been started, or ``code`` 1
        with an error message.
    """
    # Function-scope imports executed unconditionally. ``shutil`` used to be
    # imported only inside the "dataset file exists" branch, which made it an
    # unbound local when that branch was skipped and silently prevented
    # dataset_info.json from being copied.
    import re
    import shutil
    from datetime import datetime

    try:
        data = request.json
        if not isinstance(data, dict):
            return jsonify({'code': 1, 'message': '请求体必须是 JSON 对象'})

        train_logger.info("[TRAIN] ========== 开始训练任务 ==========")
        train_logger.info(f"[TRAIN] 收到启动训练请求: base_model={data.get('base_model')}, train_dataset_id={data.get('train_dataset_id')}")

        # Required parameters.
        for field in ('base_model', 'template', 'train_dataset_id'):
            if not data.get(field):
                return jsonify({'code': 1, 'message': f'缺少必要参数: {field}'})

        task_id = data.get('task_id')

        # ``base_model`` is either a numeric model id (resolved through the
        # model_manage table) or already a path.
        model_path = data.get('base_model')
        try:
            model_id = int(model_path) if str(model_path).isdigit() else None
        except (ValueError, TypeError):
            model_id = None

        if model_id:
            from .model_manage import get_db_connection
            conn = get_db_connection()
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT id, name, path FROM model_manage WHERE id = %s", (model_id,))
                result = cursor.fetchone()
            finally:
                # Close even when the query raises.
                conn.close()
            logger.info(f"模型查询结果: {result}")
            if result and result.get('path'):
                model_path = result['path']
                logger.info(f"从数据库获取的模型路径: {model_path}")
            else:
                return jsonify({'code': 1, 'message': '模型不存在或路径为空'})

        train_logger.info(f"[TRAIN] 模型路径: {model_path}")

        # llamafactory working directory and its datasets directory.
        llamafactory_dir = '/app/src/llamafactory'
        llamafactory_datasets_dir = os.path.join(llamafactory_dir, 'datasets')
        os.makedirs(llamafactory_datasets_dir, exist_ok=True)

        dataset_id = data.get('train_dataset_id')
        try:
            dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
        except (ValueError, TypeError):
            dataset_id_int = None

        src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
        dst_info_json = os.path.join(llamafactory_datasets_dir, 'dataset_info.json')

        # The dataset name is the key in dataset_info.json and the value
        # passed to --dataset.
        dataset_key = None
        if dataset_id_int:
            from .datasets import get_db_connection as get_dataset_conn
            conn = get_dataset_conn()
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT dm.name FROM dataset_manage dm WHERE dm.id = %s", (dataset_id_int,))
                dataset_result = cursor.fetchone()
            finally:
                conn.close()

            dataset_key = dataset_result['name'] if dataset_result else None

            if dataset_key:
                # Look up the stored file name for this dataset.
                actual_file_name = None
                if os.path.exists(src_info_json):
                    with open(src_info_json, 'r', encoding='utf-8') as f:
                        dataset_info = json.load(f)
                    if dataset_key in dataset_info:
                        actual_file_name = dataset_info[dataset_key].get('file_name')
                        train_logger.info(f"[TRAIN] 从 dataset_info.json 获取文件名: {dataset_key} -> {actual_file_name}")

                # Copy the dataset file next to llamafactory.
                if actual_file_name:
                    src_file = os.path.join('/app/base', 'datasets', actual_file_name)
                    dst_file = os.path.join(llamafactory_datasets_dir, actual_file_name)
                    if os.path.exists(src_file):
                        shutil.copy2(src_file, dst_file)
                        train_logger.info(f"[TRAIN] 复制数据集文件: {src_file} -> {dst_file}")
                    else:
                        train_logger.warning(f"[TRAIN] 数据集文件不存在: {src_file}")

        # Copy dataset_info.json itself (best effort: failure is only logged).
        try:
            if os.path.exists(src_info_json):
                shutil.copy2(src_info_json, dst_info_json)
                train_logger.info("[TRAIN] 已复制 dataset_info.json 到 llamafactory 目录")
            else:
                train_logger.warning(f"[TRAIN] dataset_info.json 不存在: {src_info_json}")
        except Exception as e:
            train_logger.warning(f"[TRAIN] 复制 dataset_info.json 失败: {e}")

        # GPU selection and process environment.
        cuda_devices = _parse_cuda_devices(data.get('gpus', []))
        env = os.environ.copy()
        env['CUDA_VISIBLE_DEVICES'] = cuda_devices
        env['TF_CPP_MIN_LOG_LEVEL'] = '2'  # quieter TensorFlow logging

        # Build the llamafactory-cli command.
        cmd = build_train_command(data, model_path, dataset_key)
        cmd_str = ' '.join(cmd)
        train_logger.info(f"[TRAIN] 执行训练命令: {cmd_str}")

        # Command as shown to the client, including the GPU selection.
        cmd_str_with_gpu = f"CUDA_VISIBLE_DEVICES={cuda_devices} {cmd_str}"

        # Per-day log directory, one log file per task. The task id comes from
        # the request, so anything that is not a safe file-name character is
        # replaced to keep the log inside log_dir.
        today = datetime.now().strftime('%Y-%m-%d')
        task_id_str = re.sub(r'[^0-9A-Za-z_.-]', '_', str(data.get('task_id', 'unknown')))
        log_dir = os.path.join(llamafactory_dir, 'logs', today)
        train_output_log = os.path.join(log_dir, f'train_{task_id_str}.log')
        os.makedirs(log_dir, exist_ok=True)

        train_logger.info("[TRAIN] 启动训练进程...")

        def run_training():
            """Run the training process to completion and record its outcome."""
            try:
                with open(train_output_log, 'w', encoding='utf-8') as log_file:
                    process = subprocess.Popen(
                        cmd,
                        cwd=llamafactory_dir,
                        stdout=log_file,
                        stderr=subprocess.STDOUT,
                        env=env
                    )
                    train_logger.info(f"[TRAIN] 训练进程 PID: {process.pid}")
                    # Store the real PID; the stop endpoint reads process_id
                    # from the fine_tune table to find the process.
                    update_fine_tune_status(task_id, 'running', process.pid)
                    process.wait()
            except Exception as e:
                # E.g. llamafactory-cli missing or the log file not writable:
                # without this the task would stay 'running' forever.
                train_logger.error(f"[TRAIN] 训练进程启动失败: {e}")
                update_fine_tune_status(task_id, 'failed')
                return

            train_logger.info(f"[TRAIN] 训练进程已结束,退出码: {process.returncode}")
            # NOTE(review): a task stopped through the stop endpoint will most
            # likely exit non-zero and be recorded as 'failed' here - confirm
            # whether 'stopped' should be preserved.
            final_status = 'completed' if process.returncode == 0 else 'failed'
            update_fine_tune_status(task_id, final_status, process.pid)

        # Mark the task as running BEFORE the thread starts, so a run that
        # fails immediately cannot have its final status overwritten.
        update_fine_tune_status(task_id, 'running', 0)

        training_thread = threading.Thread(target=run_training, daemon=True)
        training_thread.start()

        train_logger.info("[TRAIN] 训练任务已在后台启动")
        train_logger.info(f"[TRAIN] 训练日志输出到: {train_output_log}")

        # Return right away; the process keeps running in the background.
        return jsonify({
            'code': 0,
            'message': f'训练任务已启动 (GPU: {cuda_devices})',
            'data': {
                'task_id': task_id,
                'gpu_ids': cuda_devices,
                'command': cmd_str_with_gpu,
                'log_file': train_output_log
            }
        })

    except Exception as e:
        import traceback
        train_logger.error(f"[TRAIN] 启动训练任务失败: {e}")
        train_logger.error(f"[TRAIN] 详细错误: {traceback.format_exc()}")
        return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
def _train_param(data, key, default):
    """Return ``data[key]``, or ``default`` when it is missing, None or ''."""
    value = data.get(key)
    return default if value is None or value == '' else value


def build_train_command(data, model_path, dataset_name=None):
    """Build the ``llamafactory-cli train`` argument list.

    Args:
        data: Request parameters (dict). Missing, null or empty options fall
            back to the defaults used below.
        model_path: Value for ``--model_name_or_path``.
        dataset_name: Value for ``--dataset``. When empty, the name is looked
            up from ``data['train_dataset_id']`` (numeric ids go through
            ``get_dataset_name``; anything else is used as given).

    Returns:
        The command as a list of arguments, suitable for ``subprocess.Popen``.
    """
    # llamafactory-cli is expected to be on PATH.
    cmd = ['llamafactory-cli', 'train']

    # Training stage.
    train_type = data.get('train_type', 'SFT')
    cmd.extend(['--stage', TRAIN_TYPE_MAP.get(train_type, 'sft')])
    cmd.append('--do_train')

    # Base model.
    cmd.extend(['--model_name_or_path', model_path])

    # Dataset: the dataset name (dataset_manage.name), not the stored file name.
    if dataset_name:
        train_logger.info(f"[TRAIN] 使用数据集名称: {dataset_name}")
    else:
        dataset_id = data.get('train_dataset_id')
        try:
            dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
        except (ValueError, TypeError):
            dataset_id_int = None

        if dataset_id_int:
            dataset_name = get_dataset_name(dataset_id_int)
            train_logger.info(f"[TRAIN] 从数据库获取的数据集名称: {dataset_name}")
        else:
            dataset_name = dataset_id
    cmd.extend(['--dataset', dataset_name])

    # Dataset directory, relative to the llamafactory working directory.
    cmd.extend(['--dataset_dir', './datasets'])

    # Chat template.
    template = data.get('template')
    cmd.extend(['--template', template])

    # Fine-tuning method.
    train_method = _train_param(data, 'train_method', 'lora')
    cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')])

    # Output directory. A null/empty output_model_name falls back to the
    # default instead of crashing on ``.startswith``; names not starting with
    # './' are placed under ./saves/.
    output_dir = str(_train_param(data, 'output_model_name', f"./saves/{template}/{train_method}"))
    if not output_dir.startswith('./'):
        output_dir = f"./saves/{output_dir}"
    cmd.extend(['--output_dir', output_dir])

    # Common options. _train_param keeps an explicit null from turning into
    # the literal string 'None' on the command line.
    cmd.extend([
        '--overwrite_cache',
        '--overwrite_output_dir',
        '--cutoff_len', str(_train_param(data, 'max_length', 512)),
        '--preprocessing_num_workers', '16',
        '--per_device_train_batch_size', str(_train_param(data, 'batch_size', 1)),
        '--per_device_eval_batch_size', '1',
        '--gradient_accumulation_steps', str(_train_param(data, 'gradient_accumulation_steps', 8)),
        '--lr_scheduler_type', str(_train_param(data, 'lr_scheduler_type', 'cosine')),
        '--logging_steps', '50',
        '--warmup_steps', str(_train_param(data, 'warmup_steps', 20)),
        '--save_steps', '100',
        '--eval_steps', str(_train_param(data, 'eval_steps', 100)),
    ])

    # Learning rate and number of epochs.
    cmd.extend(['--learning_rate', str(_train_param(data, 'learning_rate', 0.0001))])
    cmd.extend(['--num_train_epochs', str(_train_param(data, 'n_epochs', 1.0))])

    # Validation split: the request carries a percentage, --val_size takes a
    # fraction. Form values may arrive as strings or null, so coerce first
    # (comparing them to 0 directly raises TypeError).
    try:
        val_ratio = float(_train_param(data, 'valid_ratio', 0))
    except (TypeError, ValueError):
        val_ratio = 0.0
    if val_ratio > 0:
        cmd.extend(['--val_size', str(val_ratio / 100)])

    # Optional sample cap.
    if data.get('max_samples'):
        cmd.extend(['--max_samples', str(data.get('max_samples'))])

    # Boolean switches.
    if data.get('plot_loss'):
        cmd.append('--plot_loss')

    if data.get('fp16'):
        cmd.append('--fp16')

    if data.get('load_best_model_at_end'):
        cmd.append('--load_best_model_at_end')

    return cmd
|
||||
|
||||
|
||||
def get_dataset_name(dataset_id):
    """Return the name of the dataset with the given id.

    Args:
        dataset_id: Primary key in the ``dataset_manage`` table.

    Returns:
        The dataset name, or the string ``'default'`` when the row is missing,
        its name is empty, or the lookup fails.
    """
    try:
        from .datasets import get_db_connection
        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT id, name FROM dataset_manage WHERE id = %s", (dataset_id,))
            result = cursor.fetchone()
        finally:
            # Release the connection even when the query raises.
            conn.close()
        logger.info(f"数据集查询结果: {result}")
        if result and result.get('name'):
            return result['name']
        logger.warning(f"未找到数据集 ID={dataset_id},使用默认值")
        return 'default'
    except Exception as e:
        logger.error(f"查询数据集失败: {e}")
        return 'default'
|
||||
|
||||
|
||||
def update_fine_tune_status(task_id, status, pid=None):
    """Write a task's status (and, when running, its PID) to ``fine_tune``.

    Args:
        task_id: Primary key of the task row. ``None`` matches no row, so the
            call returns without touching the database.
        status: New status string.
        pid: Process id; stored in ``process_id`` only when ``status`` is
            ``'running'`` and ``pid`` is truthy.

    Failures are logged and swallowed, so callers (including the background
    training thread) never see an exception from here.
    """
    if task_id is None:
        # "WHERE id = NULL" never matches; skip the round trip.
        return

    try:
        from .model_manage import get_db_connection
        conn = get_db_connection()
        try:
            cursor = conn.cursor()

            if status == 'running' and pid:
                cursor.execute(
                    "UPDATE fine_tune SET status = %s, process_id = %s WHERE id = %s",
                    (status, pid, task_id)
                )
            else:
                cursor.execute(
                    "UPDATE fine_tune SET status = %s WHERE id = %s",
                    (status, task_id)
                )

            conn.commit()
        finally:
            # Release the connection even when the update or commit raises.
            conn.close()
    except Exception as e:
        logger.error(f"更新任务状态失败: {e}")
|
||||
|
||||
|
||||
@fine_tune_bp.route('/stop/<int:task_id>', methods=['POST'])
def stop_training(task_id):
    """Stop a training task by sending SIGTERM to its recorded process.

    Args:
        task_id: Primary key of the task in the ``fine_tune`` table.

    Returns:
        A JSON response: ``code`` 0 once the task is marked ``stopped``;
        ``code`` 1 when the task does not exist, the process could not be
        signalled for lack of permission, or another error occurred.
    """
    import signal

    try:
        from .model_manage import get_db_connection

        # Look up the recorded process id.
        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT process_id FROM fine_tune WHERE id = %s", (task_id,))
            result = cursor.fetchone()
        finally:
            # Release the connection even when the query raises.
            conn.close()

        if not result:
            # Same answer as the status endpoint for an unknown task.
            return jsonify({'code': 1, 'message': '任务不存在'})

        pid = result.get('process_id')
        if pid:
            try:
                # NOTE(review): only the recorded PID is signalled; if the CLI
                # spawns worker processes they may need separate handling -
                # confirm.
                os.kill(pid, signal.SIGTERM)
                logger.info(f"已终止训练进程 PID: {pid}")
            except ProcessLookupError:
                # Already gone: fall through and mark the task stopped.
                logger.warning(f"进程 {pid} 不存在")
            except PermissionError:
                # The process is still running; do not claim it was stopped.
                logger.error(f"没有权限终止进程 {pid}")
                return jsonify({'code': 1, 'message': f'没有权限终止进程 {pid}'})

        update_fine_tune_status(task_id, 'stopped')

        return jsonify({'code': 0, 'message': '训练任务已停止'})

    except Exception as e:
        logger.error(f"停止训练任务失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
@fine_tune_bp.route('/status/<int:task_id>', methods=['GET'])
def get_training_status(task_id):
    """Return the stored state of a training task.

    Args:
        task_id: Primary key of the task in the ``fine_tune`` table.

    Returns:
        A JSON response: ``code`` 0 with ``task_id``, ``name``, ``status``,
        ``progress`` and ``pid``; or ``code`` 1 when the task does not exist
        or the lookup fails.
    """
    try:
        from .model_manage import get_db_connection

        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT id, name, status, progress, process_id FROM fine_tune WHERE id = %s",
                (task_id,)
            )
            result = cursor.fetchone()
        finally:
            # Release the connection even when the query raises.
            conn.close()

        if not result:
            return jsonify({'code': 1, 'message': '任务不存在'})

        return jsonify({
            'code': 0,
            'data': {
                'task_id': result['id'],
                'name': result['name'],
                'status': result['status'],
                'progress': result['progress'],
                'pid': result.get('process_id')
            }
        })

    except Exception as e:
        logger.error(f"获取任务状态失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
def get_db_connection():
    """Open a new MySQL connection described by ``config.yaml``.

    The config file is read on every call from the directory three levels
    above this file; its ``database`` section supplies host, port, username,
    password, database name and an optional charset (default ``utf8mb4``).

    Returns:
        A pymysql connection whose cursors return rows as dicts.
    """
    import pymysql
    import yaml

    this_file = os.path.abspath(__file__)
    root_dir = os.path.dirname(os.path.dirname(os.path.dirname(this_file)))
    config_file_path = os.path.join(root_dir, 'config.yaml')

    with open(config_file_path, 'r', encoding='utf-8') as config_file:
        settings = yaml.safe_load(config_file)

    db = settings['database']
    connect_kwargs = {
        'host': db['host'],
        'port': db['port'],
        'user': db['username'],
        'password': db['password'],
        'database': db['name'],
        'charset': db.get('charset', 'utf8mb4'),
        'cursorclass': pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**connect_kwargs)
|
||||
@@ -6,8 +6,12 @@ import pymysql
|
||||
import yaml
|
||||
from flask import Blueprint, request, jsonify
|
||||
|
||||
# 获取项目根目录
|
||||
# 获取项目根目录 - 优先使用环境变量,否则从文件路径计算
|
||||
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# 如果 PROJECT_ROOT 是 /app 或 /app/src/llamafactory,则使用挂载路径
|
||||
if PROJECT_ROOT in ('/app', '/app/src/llamafactory'):
|
||||
PROJECT_ROOT = MOUNT_BASE
|
||||
|
||||
# 创建蓝图
|
||||
model_manage_bp = Blueprint('model_manage', __name__, url_prefix='/api/model-manage')
|
||||
|
||||
60
src/main.py
60
src/main.py
@@ -86,6 +86,15 @@ def setup_logger(name='app'):
|
||||
datefmt='%H:%M:%S'
|
||||
))
|
||||
|
||||
# 5. 训练日志处理器 - 专门记录训练输出
|
||||
train_log_path = os.path.join(log_dir, 'train.log')
|
||||
train_handler = RotatingFileHandler(train_log_path, maxBytes=100*1024*1024, backupCount=5, encoding='utf-8')
|
||||
train_handler.setLevel(logging.INFO)
|
||||
train_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
))
|
||||
|
||||
# 添加处理器到 logger
|
||||
logger.addHandler(all_handler)
|
||||
logger.addHandler(error_handler)
|
||||
@@ -98,6 +107,13 @@ def setup_logger(name='app'):
|
||||
request_logger.addHandler(request_handler)
|
||||
request_logger.addHandler(console_handler)
|
||||
|
||||
# 为训练日志创建单独的 logger
|
||||
train_logger = logging.getLogger('train')
|
||||
train_logger.setLevel(logging.INFO)
|
||||
train_logger.handlers.clear()
|
||||
train_logger.addHandler(train_handler)
|
||||
train_logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@@ -137,6 +153,7 @@ def init_database():
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
base_model VARCHAR(255),
|
||||
template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等',
|
||||
train_type VARCHAR(50),
|
||||
train_method VARCHAR(50),
|
||||
gpus JSON COMMENT 'GPU硬件选择,支持多卡训练',
|
||||
@@ -144,6 +161,7 @@ def init_database():
|
||||
valid_split VARCHAR(50),
|
||||
valid_ratio INT DEFAULT 10,
|
||||
output_model_name VARCHAR(255),
|
||||
process_id INT COMMENT '训练进程ID',
|
||||
status VARCHAR(50) DEFAULT 'pending',
|
||||
progress INT DEFAULT 0,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
@@ -305,6 +323,44 @@ def init_database():
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加 template 列
|
||||
try:
|
||||
cursor.execute("ALTER TABLE fine_tune ADD COLUMN template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等'")
|
||||
logger.debug("fine_tune 表添加 template 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加 process_id 列
|
||||
try:
|
||||
cursor.execute("ALTER TABLE fine_tune ADD COLUMN process_id INT COMMENT '训练进程ID'")
|
||||
logger.debug("fine_tune 表添加 process_id 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加训练相关列
|
||||
columns_to_add = [
|
||||
("train_dataset_id", "INT COMMENT '训练数据集ID'"),
|
||||
("valid_dataset_id", "INT COMMENT '验证数据集ID'"),
|
||||
("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"),
|
||||
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"),
|
||||
("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"),
|
||||
("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"),
|
||||
("batch_size", "INT DEFAULT 1 COMMENT '批次大小'"),
|
||||
("learning_rate", "FLOAT DEFAULT 0.0001 COMMENT '学习率'"),
|
||||
("n_epochs", "FLOAT DEFAULT 1.0 COMMENT '训练轮数'"),
|
||||
("max_length", "INT DEFAULT 512 COMMENT '最大长度'"),
|
||||
("lora_alpha", "VARCHAR(10) DEFAULT '32' COMMENT 'LoRA alpha'"),
|
||||
("lora_rank", "VARCHAR(10) DEFAULT '8' COMMENT 'LoRA rank'"),
|
||||
("lora_dropout", "FLOAT DEFAULT 0.1 COMMENT 'LoRA dropout'"),
|
||||
("valid_ratio", "INT DEFAULT 10 COMMENT '验证集比例'"),
|
||||
]
|
||||
for col_name, col_def in columns_to_add:
|
||||
try:
|
||||
cursor.execute(f"ALTER TABLE fine_tune ADD COLUMN {col_name} {col_def}")
|
||||
logger.debug(f"fine_tune 表添加 {col_name} 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 插入默认管理员用户
|
||||
cursor.execute("SELECT * FROM users WHERE username = 'admin'")
|
||||
if not cursor.fetchone():
|
||||
@@ -323,8 +379,8 @@ def init_database():
|
||||
app = Flask(__name__)
|
||||
app.config['SECRET_KEY'] = CONFIG['secret_key']
|
||||
app.config['CORS_HEADERS'] = 'Content-Type'
|
||||
# 使用字符串形式的 origins
|
||||
CORS(app, origins="*", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"], supports_credentials=False)
|
||||
# 允许所有来源
|
||||
CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])
|
||||
|
||||
# 注册蓝图
|
||||
register_blueprints(app)
|
||||
|
||||
@@ -293,6 +293,92 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 选择训练模板 -->
|
||||
<div class="mb-6">
|
||||
<label class="block text-sm text-gray-600 mb-3">
|
||||
<span class="text-red-500 mr-1">*</span>训练模板
|
||||
</label>
|
||||
<select name="template" id="templateSelect" class="form-select flex-1 max-w-md">
|
||||
<optgroup label="Qwen 系列">
|
||||
<option value="qwen">qwen (Qwen/Qwen2)</option>
|
||||
<option value="qwen3">qwen3 (Qwen3)</option>
|
||||
<option value="qwen3_nothink">qwen3_nothink (Qwen3-Thinking)</option>
|
||||
<option value="qwen2_vl">qwen2_vl (Qwen2-VL)</option>
|
||||
<option value="qwen3_vl">qwen3_vl (Qwen3-VL)</option>
|
||||
<option value="qwen2_audio">qwen2_audio (Qwen2-Audio)</option>
|
||||
<option value="qwen2_omni">qwen2_omni (Qwen2.5-Omni)</option>
|
||||
<option value="qwen3_omni">qwen3_omni (Qwen3-Omni)</option>
|
||||
</optgroup>
|
||||
<optgroup label="LLaMA 系列">
|
||||
<option value="llama">llama (LLaMA)</option>
|
||||
<option value="llama2">llama2 (LLaMA 2)</option>
|
||||
<option value="llama3">llama3 (LLaMA 3/3.3)</option>
|
||||
<option value="llama4">llama4 (LLaMA 4)</option>
|
||||
<option value="mllama">mllama (LLaMA 3.2 Vision)</option>
|
||||
<option value="llava">llava (LLaVA-1.5)</option>
|
||||
<option value="llava_next">llava_next (LLaVA-NeXT)</option>
|
||||
<option value="llava_next_video">llava_next_video (LLaVA-NeXT-Video)</option>
|
||||
</optgroup>
|
||||
<optgroup label="DeepSeek 系列">
|
||||
<option value="deepseek">deepseek (DeepSeek LLM/Code/MoE)</option>
|
||||
<option value="deepseek3">deepseek3 (DeepSeek 3-3.2)</option>
|
||||
<option value="deepseekr1">deepseekr1 (DeepSeek R1 Distill)</option>
|
||||
</optgroup>
|
||||
<optgroup label="GLM 系列">
|
||||
<option value="glm4">glm4 (GLM-4/GLM-4-0414/GLM-Z1)</option>
|
||||
<option value="glm4_moe">glm4_moe (GLM-4.5)</option>
|
||||
<option value="glm4_5v">glm4_5v (GLM-4.5V)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Gemma 系列">
|
||||
<option value="gemma">gemma (Gemma/Gemma 2/CodeGemma)</option>
|
||||
<option value="gemma2">gemma2 (Gemma 2)</option>
|
||||
<option value="gemma3">gemma3 (Gemma 3)</option>
|
||||
<option value="gemma3n">gemma3n (Gemma 3n)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Phi 系列">
|
||||
<option value="phi">phi (Phi-3/Phi-3.5)</option>
|
||||
<option value="phi_small">phi_small (Phi-3-small)</option>
|
||||
<option value="phi4_mini">phi4_mini (Phi-4-mini)</option>
|
||||
<option value="phi4">phi4 (Phi-4)</option>
|
||||
</optgroup>
|
||||
<optgroup label="InternLM 系列">
|
||||
<option value="intern2">intern2 (InternLM 2-3)</option>
|
||||
<option value="intern_vl">intern_vl (InternVL 2.5-3.5)</option>
|
||||
<option value="intern_s1">intern_s1 (Intern-S1-mini)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Mistral 系列">
|
||||
<option value="mistral">mistral (Mistral/Mixtral)</option>
|
||||
<option value="ministral3">ministral3 (Ministral 3)</option>
|
||||
</optgroup>
|
||||
<optgroup label="其他系列">
|
||||
<option value="yi">yi (Yi)</option>
|
||||
<option value="baichuan">baichuan (Baichuan)</option>
|
||||
<option value="falcon">falcon (Falcon)</option>
|
||||
<option value="falcon_h1">falcon_h1 (Falcon H1)</option>
|
||||
<option value="pixtral">pixtral (Pixtral)</option>
|
||||
<option value="paligemma">paligemma (PaliGemma)</option>
|
||||
<option value="minicpm_o">minicpm_o (MiniCPM-o-2.6)</option>
|
||||
<option value="minicpm_v">minicpm_v (MiniCPM-V-2.6)</option>
|
||||
<option value="seed_oss">seed_oss (Seed OSS)</option>
|
||||
<option value="seed_coder">seed_coder (Seed Coder)</option>
|
||||
<option value="kimi_vl">kimi_vl (Kimi-VL)</option>
|
||||
<option value="hunyuan">hunyuan (Hunyuan)</option>
|
||||
<option value="hunyuan_small">hunyuan_small (Hunyuan1.5)</option>
|
||||
<option value="granite3">granite3 (Granite 3)</option>
|
||||
<option value="granite4">granite4 (Granite 3-4)</option>
|
||||
<option value="mimo">mimo (MiMo)</option>
|
||||
<option value="mimo_v2">mimo_v2 (MiMo V2)</option>
|
||||
<option value="lfm2">lfm2 (LFM 2.5)</option>
|
||||
<option value="lfm2_vl">lfm2_vl (LFM 2.5 VL)</option>
|
||||
<option value="bailing_v2">bailing_v2 (Ling 2.0)</option>
|
||||
<option value="yuan">yuan (Yuan 2)</option>
|
||||
<option value="ernie_nothink">ernie_nothink (ERNIE-4.5)</option>
|
||||
<option value="gpt_oss">gpt_oss (GPT-OSS)</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
<p class="text-xs text-gray-400 mt-1">选择与您的模型匹配的对话模板,确保训练数据格式正确</p>
|
||||
</div>
|
||||
|
||||
<!-- 训练方法 -->
|
||||
<div class="mb-6">
|
||||
<label class="block text-sm text-gray-600 mb-3">训练方法</label>
|
||||
@@ -512,7 +598,7 @@
|
||||
<span class="text-red-500 mr-1">*</span>训练集
|
||||
</label>
|
||||
<div class="flex items-center">
|
||||
<select name="dataset_id" id="trainDatasetSelect" class="form-select flex-1 max-w-md">
|
||||
<select name="train_dataset_id" id="trainDatasetSelect" class="form-select flex-1 max-w-md">
|
||||
<option value="">请选择训练数据集</option>
|
||||
</select>
|
||||
<button type="button" class="ml-2 text-primary text-sm flex items-center hover:text-primary/80" onclick="loadDatasets()">
|
||||
@@ -524,38 +610,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 验证集 -->
|
||||
<div class="mb-6">
|
||||
<label class="block text-sm text-gray-600 mb-3">验证集 <span class="text-red-500">*</span></label>
|
||||
<div class="flex items-center space-x-6 mb-3">
|
||||
<label class="flex items-center">
|
||||
<input type="radio" name="valid_split" value="auto" checked class="mr-2" onchange="toggleValidSplit()">
|
||||
<span class="text-sm">自动切分</span>
|
||||
</label>
|
||||
<label class="flex items-center">
|
||||
<input type="radio" name="valid_split" value="custom" class="mr-2" onchange="toggleValidSplit()">
|
||||
<span class="text-sm">选择数据集</span>
|
||||
</label>
|
||||
</div>
|
||||
<div id="autoSplitSection" class="flex items-center">
|
||||
<span class="text-sm text-gray-600 mr-2">从当前训练集随机分割</span>
|
||||
<input type="number" name="valid_ratio" value="10" class="w-16 px-2 py-1 border border-gray-300 rounded text-sm text-center focus:border-primary focus:outline-none">
|
||||
<span class="text-sm text-gray-600 ml-2">% 作为验证集</span>
|
||||
</div>
|
||||
<div id="customSplitSection" class="hidden">
|
||||
<div class="flex items-center">
|
||||
<select name="valid_dataset_id" id="validDatasetSelect" class="form-select flex-1 max-w-md">
|
||||
<option value="">请选择验证数据集</option>
|
||||
</select>
|
||||
<button type="button" class="ml-2 text-primary text-sm flex items-center hover:text-primary/80" onclick="loadDatasets()">
|
||||
<i class="fa fa-refresh"></i>
|
||||
</button>
|
||||
<button type="button" class="ml-3 bg-white border border-primary text-primary rounded px-3 py-1.5 text-sm hover:bg-primary/5" onclick="window.location.href='dataset-create.html?from=fine-tune'">
|
||||
+ 新增数据集
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 训练产出 -->
|
||||
@@ -571,16 +625,21 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模型加密 -->
|
||||
<div class="mb-4">
|
||||
<div class="flex items-center">
|
||||
<span class="text-sm text-gray-600 mr-2">模型加密</span>
|
||||
<span class="px-2 py-0.5 bg-green-100 text-green-700 text-xs rounded">安全升级</span>
|
||||
<!-- 训练命令预览 -->
|
||||
<div class="mt-4">
|
||||
<div class="flex items-center mb-2">
|
||||
<span class="text-sm font-medium text-gray-600">训练命令预览</span>
|
||||
<button type="button" onclick="updateCommandPreview()" class="ml-2 px-2 py-0.5 bg-blue-50 text-blue-600 text-xs rounded hover:bg-blue-100">
|
||||
<i class="fa fa-refresh mr-1"></i>刷新
|
||||
</button>
|
||||
</div>
|
||||
<p class="text-xs text-gray-400 mt-1">为保障您的数据安全,平台会为导出的模型文件开启 OSS 服务端加密</p>
|
||||
<div class="bg-gray-900 rounded-lg p-3 overflow-x-auto">
|
||||
<pre id="commandPreview" class="text-xs text-green-400 font-mono whitespace-pre-wrap break-all">请选择完整配置后查看预览命令</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- 底部按钮 -->
|
||||
<div class="flex items-center justify-between pt-6 border-t border-gray-100">
|
||||
<div class="flex items-center space-x-3">
|
||||
@@ -591,11 +650,6 @@
|
||||
取消
|
||||
</a>
|
||||
</div>
|
||||
<div class="flex items-center text-sm">
|
||||
<a href="#" class="text-primary hover:underline">训练费用 (预估)</a>
|
||||
<span class="mx-2 text-gray-300">|</span>
|
||||
<a href="#" class="text-primary hover:underline">计算详情</a>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
@@ -648,6 +702,9 @@
|
||||
// 加载GPU列表
|
||||
loadGPUList();
|
||||
|
||||
// 初始化训练命令预览
|
||||
initCommandPreview();
|
||||
|
||||
// 设置侧边栏当前页高亮
|
||||
const currentPage = 'fine-tune';
|
||||
document.querySelectorAll('.nav-link').forEach(link => {
|
||||
@@ -683,20 +740,6 @@
|
||||
}
|
||||
}
|
||||
|
||||
// Switch the validation-set UI between the two split modes: show the
// auto-split section when the checked "valid_split" radio is 'auto',
// otherwise show the custom dataset section.
function toggleValidSplit() {
    const mode = document.querySelector('input[name="valid_split"]:checked').value;
    const useAutoSplit = mode === 'auto';

    // classList.toggle(name, force) adds the class when force is true and
    // removes it when false, so exactly one section stays visible.
    document.getElementById('autoSplitSection').classList.toggle('hidden', !useAutoSplit);
    document.getElementById('customSplitSection').classList.toggle('hidden', useAutoSplit);
}
|
||||
|
||||
// 切换训练方法 - 显示/隐藏LoRA参数
|
||||
function toggleTrainMethod() {
|
||||
const trainMethod = document.querySelector('input[name="train_method"]:checked').value;
|
||||
@@ -782,12 +825,6 @@
|
||||
trainSelect.innerHTML = '<option value="">请选择训练数据集</option>' +
|
||||
result.data.map(d => `<option value="${d.id}">${d.name}</option>`).join('');
|
||||
}
|
||||
// 更新验证集下拉框
|
||||
const validSelect = document.getElementById('validDatasetSelect');
|
||||
if (validSelect) {
|
||||
validSelect.innerHTML = '<option value="">请选择验证数据集</option>' +
|
||||
result.data.map(d => `<option value="${d.id}">${d.name}</option>`).join('');
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('加载数据集失败:', e);
|
||||
@@ -968,22 +1005,35 @@
|
||||
async function submitForm() {
|
||||
const form = document.getElementById('createForm');
|
||||
const formData = new FormData(form);
|
||||
const validSplit = formData.get('valid_split');
|
||||
|
||||
// 获取选中的GPU
|
||||
const selectedGPUs = getSelectedGPUs();
|
||||
|
||||
// 收集训练参数
|
||||
const trainParams = {
|
||||
batch_size: parseInt(formData.get('batch_size')) || 1,
|
||||
learning_rate: parseFloat(formData.get('learning_rate')) || 0.0001,
|
||||
n_epochs: parseFloat(formData.get('n_epochs')) || 1.0,
|
||||
eval_steps: parseInt(formData.get('eval_steps')) || 100,
|
||||
lr_scheduler_type: formData.get('lr_scheduler_type') || 'cosine',
|
||||
max_length: parseInt(formData.get('max_length')) || 512,
|
||||
warmup_ratio: parseFloat(formData.get('warmup_ratio')) || 0.05,
|
||||
weight_decay: parseFloat(formData.get('weight_decay')) || 0.01,
|
||||
lora_alpha: formData.get('lora_alpha') || '32',
|
||||
lora_dropout: parseFloat(formData.get('lora_dropout')) || 0.1,
|
||||
lora_rank: formData.get('lora_rank') || '8'
|
||||
};
|
||||
|
||||
const data = {
|
||||
name: formData.get('name'),
|
||||
base_model: formData.get('base_model'),
|
||||
template: formData.get('template'),
|
||||
train_type: formData.get('train_type'),
|
||||
train_method: formData.get('train_method'),
|
||||
gpus: selectedGPUs, // 添加GPU选择
|
||||
gpus: selectedGPUs,
|
||||
train_dataset_id: formData.get('train_dataset_id'),
|
||||
valid_split: validSplit,
|
||||
valid_ratio: parseInt(formData.get('valid_ratio')) || 10,
|
||||
valid_dataset_id: validSplit === 'custom' ? formData.get('valid_dataset_id') : null,
|
||||
output_model_name: formData.get('output_model_name'),
|
||||
...trainParams,
|
||||
status: 'pending',
|
||||
progress: 0
|
||||
};
|
||||
@@ -1000,33 +1050,201 @@
|
||||
showMessage('提示', '请选择基础模型', 'warning');
|
||||
return;
|
||||
}
|
||||
if (!data.template) {
|
||||
showMessage('提示', '请选择训练模板', 'warning');
|
||||
return;
|
||||
}
|
||||
if (!data.train_dataset_id) {
|
||||
showMessage('提示', '请选择训练集', 'warning');
|
||||
return;
|
||||
}
|
||||
if (validSplit === 'custom' && !data.valid_dataset_id) {
|
||||
showMessage('提示', '请选择验证集', 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/fine-tune`, {
|
||||
// 第一步:创建训练任务记录
|
||||
const createResponse = await fetch(`${API_BASE}/fine-tune`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(data)
|
||||
});
|
||||
const result = await response.json();
|
||||
if (result.code === 0) {
|
||||
showMessage('成功', '创建成功!', 'success', () => {
|
||||
const createResult = await createResponse.json();
|
||||
if (createResult.code !== 0) {
|
||||
showMessage('错误', createResult.message || '创建任务失败', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
const taskId = createResult.id;
|
||||
|
||||
// 第二步:启动训练
|
||||
const startData = {
|
||||
task_id: taskId,
|
||||
base_model: data.base_model,
|
||||
template: data.template,
|
||||
train_type: data.train_type,
|
||||
train_method: data.train_method,
|
||||
train_dataset_id: data.train_dataset_id,
|
||||
output_model_name: data.output_model_name,
|
||||
...trainParams
|
||||
};
|
||||
|
||||
const startResponse = await fetch(`${API_BASE}/fine-tune/start`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(startData)
|
||||
});
|
||||
const startResult = await startResponse.json();
|
||||
|
||||
if (startResult.code === 0) {
|
||||
const cmd = startResult.data?.command || '';
|
||||
showMessage('成功', `训练任务已启动!<br><br><code class="text-xs bg-gray-100 p-1 rounded">${cmd}</code>`, 'success', () => {
|
||||
window.location.href = 'main.html';
|
||||
});
|
||||
} else {
|
||||
showMessage('错误', result.message || '创建失败', 'error');
|
||||
// 更新任务状态为失败
|
||||
await fetch(`${API_BASE}/fine-tune/${taskId}`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ status: 'failed' })
|
||||
});
|
||||
showMessage('错误', startResult.message || '启动训练失败', 'error');
|
||||
}
|
||||
} catch (error) {
|
||||
showMessage('错误', '创建失败: ' + error.message, 'error');
|
||||
showMessage('错误', '操作失败: ' + error.message, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Build the text of the `llamafactory-cli train` command shown in the
 * "training command preview" panel.
 *
 * Reads the current values of the #createForm fields plus the GPU selection
 * returned by getSelectedGPUs() and formats them as a multi-line shell
 * command. Nothing is executed or sent anywhere; the result is display text.
 *
 * @returns {string} The command, one option per line, joined with " \"
 *     line continuations (no trailing continuation on the last line).
 */
function buildCommandPreview() {
    const form = document.getElementById('createForm');
    const formData = new FormData(form);

    // GPU ids: strip the "gpu" prefix and keep purely numeric ids only.
    // String() guards against numeric ids, and the final `|| '0'` covers both
    // "nothing selected" and "nothing parseable" so the preview never shows
    // an empty CUDA_VISIBLE_DEVICES= (which would hide every GPU).
    const gpuIds = getSelectedGPUs()
        .map(gpu => String(gpu.id).replace('gpu', ''))
        .filter(id => /^\d+$/.test(id))
        .join(',') || '0';

    // Model path: prefer the selected option's data-path attribute, otherwise
    // fall back to the raw form value.
    const modelSelect = form.querySelector('select[name="base_model"]');
    const modelOption = modelSelect ? modelSelect.selectedOptions[0] : null;
    const modelPath = (modelOption && modelOption.getAttribute('data-path'))
        || formData.get('base_model')
        || '';

    const template = formData.get('template') || 'qwen3';

    // Training type -> LLaMA-Factory --stage value. Continued pre-training is
    // the "pt" stage in LLaMA-Factory; "cpt" is not an accepted stage name.
    const stageByTrainType = { SFT: 'sft', DPO: 'dpo', CPT: 'pt' };
    const stage = stageByTrainType[formData.get('train_type') || 'SFT'] || 'sft';

    // Training method -> --finetuning_type; unknown values fall back to lora.
    const trainMethod = formData.get('train_method') || 'lora';
    const finetuningTypeByMethod = { lora: 'lora', full: 'full' };
    const finetuningType = finetuningTypeByMethod[trainMethod] || 'lora';

    // Output directory: names already starting with "./" are used verbatim,
    // anything else is placed under ./saves/.
    const outputModelName = formData.get('output_model_name') || `${template}/${trainMethod}`;
    const outputDir = outputModelName.startsWith('./')
        ? outputModelName
        : `./saves/${outputModelName}`;

    // Dataset name: prefer data-name, then the visible label of a real
    // (non-placeholder) option, and only then the raw id / placeholder text.
    // NOTE(review): assumes the option label is the dataset's registered
    // name -- confirm against the backend's dataset_info.json keys.
    const datasetSelect = form.querySelector('select[name="train_dataset_id"]');
    const datasetOption = datasetSelect ? datasetSelect.selectedOptions[0] : null;
    let datasetName = formData.get('train_dataset_id') || 'dataset_name';
    if (datasetOption) {
        const explicitName = datasetOption.getAttribute('data-name');
        if (explicitName) {
            datasetName = explicitName;
        } else if (datasetOption.value) {
            datasetName = datasetOption.textContent.trim() || datasetName;
        }
    }

    // Numeric hyper-parameters. The `|| default` form means an empty,
    // non-numeric, or zero entry falls back to the default; this is the same
    // pattern submitForm uses for the values it sends.
    const batchSize = parseInt(formData.get('batch_size')) || 1;
    const learningRate = parseFloat(formData.get('learning_rate')) || 0.0001;
    const nEpochs = parseFloat(formData.get('n_epochs')) || 1.0;
    const maxLength = parseInt(formData.get('max_length')) || 512;
    const warmupSteps = parseInt(formData.get('warmup_steps')) || 20;
    const evalSteps = parseInt(formData.get('eval_steps')) || 100;
    const gradientAccumulationSteps = parseInt(formData.get('gradient_accumulation_steps')) || 8;
    const lrSchedulerType = formData.get('lr_scheduler_type') || 'cosine';

    // LoRA options are only emitted when the LoRA method is selected.
    const loraArgs = trainMethod === 'lora'
        ? [
            `--lora_alpha ${formData.get('lora_alpha') || '32'}`,
            `--lora_dropout ${parseFloat(formData.get('lora_dropout')) || 0.1}`,
            `--lora_rank ${formData.get('lora_rank') || '8'}`
        ]
        : [];

    const args = [
        `--stage ${stage}`,
        '--do_train',
        `--model_name_or_path ${modelPath}`,
        `--dataset ${datasetName}`,
        '--dataset_dir ./datasets',
        `--template ${template}`,
        `--finetuning_type ${finetuningType}`,
        ...loraArgs,
        `--output_dir ${outputDir}`,
        '--overwrite_cache',
        '--overwrite_output_dir',
        `--cutoff_len ${maxLength}`,
        '--preprocessing_num_workers 16',
        `--per_device_train_batch_size ${batchSize}`,
        '--per_device_eval_batch_size 1',
        `--gradient_accumulation_steps ${gradientAccumulationSteps}`,
        `--lr_scheduler_type ${lrSchedulerType}`,
        '--logging_steps 50',
        `--warmup_steps ${warmupSteps}`,
        '--save_steps 100',
        `--eval_steps ${evalSteps}`,
        `--learning_rate ${learningRate}`,
        `--num_train_epochs ${nEpochs}`
    ];

    // One option per line; every line except the last ends with " \".
    return [`CUDA_VISIBLE_DEVICES=${gpuIds} llamafactory-cli train`]
        .concat(args.map(arg => ` ${arg}`))
        .join(' \\\n');
}
|
||||
|
||||
// Re-render the command preview: regenerate the command text from the form
// and write it into the #commandPreview element as plain text.
function updateCommandPreview() {
    const commandText = buildCommandPreview();
    document.getElementById('commandPreview').textContent = commandText;
}
|
||||
|
||||
/**
 * Keep the training-command preview in sync with the form.
 *
 * Uses delegated listeners instead of binding to each element found at call
 * time, so fields and cards added to the page after this call (for example
 * by a loader that renders GPU cards later) still refresh the preview.
 * Refreshes are debounced: a burst of events produces one update, 100 ms
 * after the last event.
 */
function initCommandPreview() {
    const form = document.getElementById('createForm');
    if (!form) {
        return; // nothing to observe on pages without the create form
    }

    let pendingRefresh = null;
    const scheduleRefresh = () => {
        // The short delay lets other handlers for the same event (card
        // selection toggles, etc.) update their state before it is read.
        clearTimeout(pendingRefresh);
        pendingRefresh = setTimeout(updateCommandPreview, 100);
    };

    // "change" and "input" both bubble, so one listener pair on the form
    // covers every field inside it, including ones inserted later.
    form.addEventListener('change', scheduleRefresh);
    form.addEventListener('input', scheduleRefresh);

    // Card-style radios (training type / method) and GPU cards are toggled by
    // click. Listening in the capture phase on the document means a
    // stopPropagation() inside a card's own handler cannot hide the click.
    document.addEventListener('click', event => {
        const target = event.target;
        if (target instanceof Element && target.closest('.card-radio, .gpu-card')) {
            scheduleRefresh();
        }
    }, true);

    // Initial render, delayed so the rest of the page setup can finish first.
    setTimeout(updateCommandPreview, 500);
}
|
||||
</script>
|
||||
|
||||
<!-- 自定义消息弹窗 -->
|
||||
|
||||
@@ -417,9 +417,9 @@
|
||||
}
|
||||
}
|
||||
|
||||
// 页面加载时获取监控数据,并每5秒刷新
|
||||
// 页面加载时获取监控数据,并每30秒刷新
|
||||
fetchSystemMetrics();
|
||||
setInterval(fetchSystemMetrics, 5000);
|
||||
setInterval(fetchSystemMetrics, 30000);
|
||||
|
||||
// 各功能模块的表格配置
|
||||
const tableConfigs = {
|
||||
|
||||
@@ -517,7 +517,7 @@
|
||||
if (select.options.length > 1) return;
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/local-models`);
|
||||
const response = await fetch(`${API_BASE}/model-manage/local-models`);
|
||||
const result = await response.json();
|
||||
|
||||
if (result.code === 0 && result.data && result.data.models) {
|
||||
|
||||
Reference in New Issue
Block a user