模型微调已经调通

增加了参数预览
This commit is contained in:
2026-01-28 10:31:09 +08:00
parent 8a638b6372
commit a560d24e2f
8 changed files with 898 additions and 96 deletions

View File

@@ -6,6 +6,7 @@ from .model_manage import model_manage_bp
from .model_chat import model_chat_bp
from .dimension import dimension_bp
from .logs import logs_bp
from .fine_tune import fine_tune_bp
# 注册所有蓝图
def register_blueprints(app):
@@ -15,3 +16,4 @@ def register_blueprints(app):
app.register_blueprint(model_chat_bp)
app.register_blueprint(dimension_bp)
app.register_blueprint(logs_bp)
app.register_blueprint(fine_tune_bp)

View File

@@ -5,6 +5,7 @@ import io
import os
import time
import zipfile
import json
from flask import Blueprint, request, jsonify, send_from_directory, Response
from werkzeug.utils import secure_filename
@@ -52,6 +53,45 @@ def allowed_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def update_dataset_info_json(dataset_name=None, actual_filename=None, remove_filename=None):
"""更新 datasets/dataset_info.json 文件
Args:
dataset_name: 数据集名称(用于作为 key
actual_filename: 实际保存的文件名(含时间戳前缀),用于 file_name
remove_filename: 要移除的文件名为None表示不移除
"""
info_path = os.path.join(DATASET_FOLDER, 'dataset_info.json')
# 读取现有配置
dataset_info = {}
if os.path.exists(info_path):
try:
with open(info_path, 'r', encoding='utf-8') as f:
dataset_info = json.load(f)
except Exception as e:
print(f"读取 dataset_info.json 失败: {e}")
# 移除旧条目(根据移除的文件名)
if remove_filename:
key = os.path.splitext(remove_filename)[0]
if key in dataset_info:
del dataset_info[key]
print(f"从 dataset_info.json 移除: {key}")
# 添加新条目
if dataset_name and actual_filename:
dataset_info[dataset_name] = {"file_name": actual_filename}
print(f"更新 dataset_info.json: {dataset_name} -> {actual_filename}")
# 写入文件
try:
with open(info_path, 'w', encoding='utf-8') as f:
json.dump(dataset_info, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"写入 dataset_info.json 失败: {e}")
def generic_get_by_id(table_name, id_val):
"""通用按ID查询"""
conn = get_db_connection()
@@ -149,17 +189,40 @@ def delete_dataset(id):
"""删除数据集"""
conn = get_db_connection()
cursor = conn.cursor()
# 获取文件路径列表
cursor.execute("SELECT file_path FROM dataset_files WHERE dataset_id = %s", (id,))
# 获取数据集名称(用于从 dataset_info.json 移除条目)
cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (id,))
dataset_result = cursor.fetchone()
dataset_name = dataset_result['name'] if dataset_result else None
# 获取文件信息列表(包含原始文件名)
cursor.execute("SELECT file_name, file_path FROM dataset_files WHERE dataset_id = %s", (id,))
files = cursor.fetchall()
# 删除文件
for f in files:
file_path = f.get('file_path')
if file_path and os.path.exists(file_path):
try:
os.remove(file_path)
except Exception as e:
print(f"删除文件失败: {file_path}, {e}")
# 尝试多个可能的路径
paths_to_try = []
if file_path:
paths_to_try.append(file_path)
# 尝试 PROJECT_ROOT 相对路径
rel_path = file_path.replace('/app/base', PROJECT_ROOT, 1) if file_path.startswith('/app/base') else None
if rel_path:
paths_to_try.append(rel_path)
for path in paths_to_try:
if path and os.path.exists(path):
try:
os.remove(path)
print(f"已删除文件: {path}")
break
except Exception as e:
print(f"删除文件失败: {path}, {e}")
# 使用数据集名称从 dataset_info.json 移除条目
if dataset_name:
update_dataset_info_json(remove_filename=dataset_name)
# 删除数据库记录
cursor.execute("DELETE FROM dataset_files WHERE dataset_id = %s", (id,))
cursor.execute("DELETE FROM dataset_manage WHERE id = %s", (id,))
@@ -218,6 +281,12 @@ def upload_dataset_file(dataset_id):
file.save(file_path)
file_size = os.path.getsize(file_path)
# 获取数据集名称用于作为 dataset_info.json 的 key
dataset_name = dataset.get('name') if dataset else None
# 更新 dataset_info.json使用数据集名称作为 key实际保存的文件名作为 file_name
update_dataset_info_json(dataset_name=dataset_name, actual_filename=new_filename)
# 获取文件扩展名(安全处理无扩展名的情况)
parts = filename.rsplit('.', 1)
ext = parts[1].lower() if len(parts) > 1 else 'unknown'
@@ -304,8 +373,8 @@ def delete_dataset_file(file_id):
conn = get_db_connection()
cursor = conn.cursor()
# 获取文件信息
cursor.execute("SELECT dataset_id, file_path FROM dataset_files WHERE id = %s", (file_id,))
# 获取文件信息(包含 dataset_id
cursor.execute("SELECT dataset_id, file_name, file_path FROM dataset_files WHERE id = %s", (file_id,))
file_info = cursor.fetchone()
if not file_info:
@@ -313,6 +382,13 @@ def delete_dataset_file(file_id):
conn.close()
return jsonify({'code': 1, 'message': '文件不存在'})
dataset_id = file_info['dataset_id']
# 获取数据集名称(用于从 dataset_info.json 移除条目)
cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (dataset_id,))
dataset_result = cursor.fetchone()
dataset_name = dataset_result['name'] if dataset_result else None
# 删除物理文件
file_path = file_info['file_path']
if file_path and os.path.exists(file_path):
@@ -324,8 +400,11 @@ def delete_dataset_file(file_id):
# 删除数据库记录
cursor.execute("DELETE FROM dataset_files WHERE id = %s", (file_id,))
# 使用数据集名称从 dataset_info.json 移除条目
if dataset_name:
update_dataset_info_json(remove_filename=dataset_name)
# 更新数据集的文件数量和大小
dataset_id = file_info['dataset_id']
cursor.execute("SELECT COUNT(*) as count, SUM(file_size) as total_size FROM dataset_files WHERE dataset_id = %s", (dataset_id,))
result = cursor.fetchone()
file_count = result['count'] or 0

443
src/api/fine_tune.py Normal file
View File

@@ -0,0 +1,443 @@
"""
精调训练 API 路由
调用 llamafactory-cli 执行训练任务
"""
import os
import subprocess
import json
import threading
import time
from flask import Blueprint, request, jsonify
import logging
logger = logging.getLogger(__name__)
train_logger = logging.getLogger('train') # 专门的训练日志 logger输出到 train.log
# 创建蓝图
fine_tune_bp = Blueprint('fine_tune', __name__, url_prefix='/api/fine-tune')
# 训练类型映射
TRAIN_TYPE_MAP = {
'SFT': 'sft',
'DPO': 'dpo',
'CPT': 'cpt'
}
# 训练方法映射
FINETUNING_TYPE_MAP = {
'lora': 'lora',
'full': 'full'
}
@fine_tune_bp.route('/start', methods=['POST'])
def start_training():
"""启动 llamafactory 训练任务"""
try:
data = request.json
train_logger.info(f"[TRAIN] ========== 开始训练任务 ==========")
train_logger.info(f"[TRAIN] 收到启动训练请求: base_model={data.get('base_model')}, train_dataset_id={data.get('train_dataset_id')}")
# 必填参数验证
required_fields = ['base_model', 'template', 'train_dataset_id']
for field in required_fields:
if not data.get(field):
return jsonify({'code': 1, 'message': f'缺少必要参数: {field}'})
# 获取模型信息
model_path = data.get('base_model')
# 尝试转换为整数
try:
model_id = int(model_path) if str(model_path).isdigit() else None
except (ValueError, TypeError):
model_id = None
if model_id:
# 如果是 model_id需要获取模型路径
from .model_manage import get_db_connection
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT id, name, path FROM model_manage WHERE id = %s", (model_id,))
result = cursor.fetchone()
conn.close()
logger.info(f"模型查询结果: {result}")
if result and result.get('path'):
model_path = result['path']
logger.info(f"从数据库获取的模型路径: {model_path}")
else:
return jsonify({'code': 1, 'message': '模型不存在或路径为空'})
elif not model_path:
return jsonify({'code': 1, 'message': f'模型路径为空'})
train_logger.info(f"[TRAIN] 模型路径: {model_path}")
# 设置工作目录为 llamafactory 目录
llamafactory_dir = '/app/src/llamafactory'
# 处理数据集文件:将数据集复制到 llamafactory 的 datasets 目录
dataset_id = data.get('train_dataset_id')
try:
dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
except (ValueError, TypeError):
dataset_id_int = None
llamafactory_datasets_dir = os.path.join(llamafactory_dir, 'datasets')
os.makedirs(llamafactory_datasets_dir, exist_ok=True)
# 获取数据集名称(用于 --dataset 参数)
dataset_key = None
if dataset_id_int:
from .datasets import get_db_connection as get_dataset_conn
conn = get_dataset_conn()
cursor = conn.cursor()
cursor.execute("SELECT dm.name FROM dataset_manage dm WHERE dm.id = %s", (dataset_id_int,))
dataset_result = cursor.fetchone()
conn.close()
dataset_key = dataset_result['name'] if dataset_result else None
if dataset_key:
# 从 dataset_info.json 读取实际文件名
src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
actual_file_name = None
if os.path.exists(src_info_json):
import json as json_lib
with open(src_info_json, 'r', encoding='utf-8') as f:
dataset_info = json_lib.load(f)
if dataset_key in dataset_info:
actual_file_name = dataset_info[dataset_key].get('file_name')
train_logger.info(f"[TRAIN] 从 dataset_info.json 获取文件名: {dataset_key} -> {actual_file_name}")
# 复制数据集文件到 llamafactory 目录
if actual_file_name:
src_file = os.path.join('/app/base', 'datasets', actual_file_name)
dst_file = os.path.join(llamafactory_datasets_dir, actual_file_name)
if os.path.exists(src_file):
import shutil
shutil.copy2(src_file, dst_file)
train_logger.info(f"[TRAIN] 复制数据集文件: {src_file} -> {dst_file}")
else:
train_logger.warning(f"[TRAIN] 数据集文件不存在: {src_file}")
# 复制 dataset_info.json 到 llamafactory datasets 目录
src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
dst_info_json = os.path.join(llamafactory_datasets_dir, 'dataset_info.json')
try:
if os.path.exists(src_info_json):
shutil.copy2(src_info_json, dst_info_json)
train_logger.info(f"[TRAIN] 已复制 dataset_info.json 到 llamafactory 目录")
else:
train_logger.warning(f"[TRAIN] dataset_info.json 不存在: {src_info_json}")
except Exception as e:
train_logger.warning(f"[TRAIN] 复制 dataset_info.json 失败: {e}")
# 获取选中的 GPU 索引
gpus = data.get('gpus', [])
if gpus:
gpu_ids = [gpu.get('id', '').replace('gpu', '') for gpu in gpus]
gpu_ids = [g for g in gpu_ids if g.isdigit()]
cuda_devices = ','.join(gpu_ids)
else:
cuda_devices = '0'
# 设置环境变量
env = os.environ.copy()
env['CUDA_VISIBLE_DEVICES'] = cuda_devices
env['TF_CPP_MIN_LOG_LEVEL'] = '2' # 减少 TensorFlow 日志
# 构建 llamafactory-cli 命令(传入数据集名称用于 --dataset 参数)
cmd = build_train_command(data, model_path, dataset_key)
cmd_str = ' '.join(cmd)
train_logger.info(f"[TRAIN] 执行训练命令: {cmd_str}")
# 在返回的命令中显示 GPU 配置
cmd_str_with_gpu = f"CUDA_VISIBLE_DEVICES={cuda_devices} {cmd_str}"
# 生成训练日志文件路径(按日期分目录)
from datetime import datetime
today = datetime.now().strftime('%Y-%m-%d')
task_id_str = str(data.get('task_id', 'unknown'))
log_dir = os.path.join(llamafactory_dir, 'logs', today)
train_output_log = os.path.join(log_dir, f'train_{task_id_str}.log')
# 确保日志目录存在
os.makedirs(log_dir, exist_ok=True)
train_logger.info(f"[TRAIN] 启动训练进程...")
# 使用线程在后台运行训练进程
def run_training():
with open(train_output_log, 'w', encoding='utf-8') as log_file:
process = subprocess.Popen(
cmd,
cwd=llamafactory_dir,
stdout=log_file,
stderr=subprocess.STDOUT,
env=env
)
train_logger.info(f"[TRAIN] 训练进程 PID: {process.pid}")
# 等待进程完成
process.wait()
train_logger.info(f"[TRAIN] 训练进程已结束,退出码: {process.returncode}")
# 更新任务状态
final_status = 'completed' if process.returncode == 0 else 'failed'
update_fine_tune_status(data.get('task_id'), final_status, process.pid)
# 启动后台线程
training_thread = threading.Thread(target=run_training, daemon=True)
training_thread.start()
# 立即返回,不等待进程完成
pid = None # 此时还不知道实际 PID稍后可从日志获取
train_logger.info(f"[TRAIN] 训练任务已在后台启动")
train_logger.info(f"[TRAIN] 训练日志输出到: {train_output_log}")
# 更新任务状态为运行中
update_fine_tune_status(data.get('task_id'), 'running', 0)
return jsonify({
'code': 0,
'message': f'训练任务已启动 (GPU: {cuda_devices})',
'data': {
'task_id': data.get('task_id'),
'gpu_ids': cuda_devices,
'command': cmd_str_with_gpu,
'log_file': train_output_log
}
})
except Exception as e:
import traceback
train_logger.error(f"[TRAIN] 启动训练任务失败: {e}")
train_logger.error(f"[TRAIN] 详细错误: {traceback.format_exc()}")
return jsonify({'code': 1, 'message': str(e)})
def build_train_command(data, model_path, dataset_name=None):
"""构建 llamafactory-cli train 命令"""
# llamafactory-cli 路径(已在系统 PATH 中)
cmd = ['llamafactory-cli', 'train']
# 训练阶段
train_type = data.get('train_type', 'SFT')
cmd.extend(['--stage', TRAIN_TYPE_MAP.get(train_type, 'sft')])
cmd.append('--do_train')
# 模型路径
cmd.extend(['--model_name_or_path', model_path])
# 数据集 - 使用数据集名称dataset_manage.name不是实际文件名
if dataset_name:
cmd.extend(['--dataset', dataset_name])
train_logger.info(f"[TRAIN] 使用数据集名称: {dataset_name}")
else:
# 回退到原有逻辑
dataset_id = data.get('train_dataset_id')
try:
dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
except (ValueError, TypeError):
dataset_id_int = None
if dataset_id_int:
dataset_name = get_dataset_name(dataset_id_int)
train_logger.info(f"[TRAIN] 从数据库获取的数据集名称: {dataset_name}")
else:
dataset_name = dataset_id
cmd.extend(['--dataset', dataset_name])
# 数据集目录
cmd.extend(['--dataset_dir', './datasets']) # llamafactory 工作目录下的 datasets 目录
# 模板
template = data.get('template')
cmd.extend(['--template', template])
# 训练方法
train_method = data.get('train_method', 'lora')
cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')])
# 输出目录
output_dir = data.get('output_model_name', f"./saves/{template}/{train_method}")
if not output_dir.startswith('./'):
output_dir = f"./saves/{output_dir}"
cmd.extend(['--output_dir', output_dir])
# 常用参数
cmd.extend([
'--overwrite_cache',
'--overwrite_output_dir',
'--cutoff_len', str(data.get('max_length', 512)),
'--preprocessing_num_workers', '16',
'--per_device_train_batch_size', str(data.get('batch_size', 1)),
'--per_device_eval_batch_size', '1',
'--gradient_accumulation_steps', str(data.get('gradient_accumulation_steps', 8)),
'--lr_scheduler_type', data.get('lr_scheduler_type', 'cosine'),
'--logging_steps', '50',
'--warmup_steps', str(data.get('warmup_steps', 20)),
'--save_steps', '100',
'--eval_steps', str(data.get('eval_steps', 100)),
])
# 学习率
cmd.extend(['--learning_rate', str(data.get('learning_rate', 0.0001))])
# 训练轮数
cmd.extend(['--num_train_epochs', str(data.get('n_epochs', 1.0))])
# 验证集比例
val_ratio = data.get('valid_ratio', 0)
if val_ratio > 0:
cmd.extend(['--val_size', str(val_ratio / 100)])
# 最大样本数
if data.get('max_samples'):
cmd.extend(['--max_samples', str(data.get('max_samples'))])
# 其他选项
if data.get('plot_loss'):
cmd.append('--plot_loss')
if data.get('fp16'):
cmd.append('--fp16')
if data.get('load_best_model_at_end'):
cmd.append('--load_best_model_at_end')
return cmd
def get_dataset_name(dataset_id):
"""根据数据集 ID 获取数据集名称"""
try:
from .datasets import get_db_connection
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT id, name FROM dataset_manage WHERE id = %s", (dataset_id,))
result = cursor.fetchone()
conn.close()
logger.info(f"数据集查询结果: {result}")
if result and result.get('name'):
return result['name']
logger.warning(f"未找到数据集 ID={dataset_id},使用默认值")
return 'default'
except Exception as e:
logger.error(f"查询数据集失败: {e}")
return 'default'
def update_fine_tune_status(task_id, status, pid=None):
"""更新训练任务状态"""
try:
from .model_manage import get_db_connection
conn = get_db_connection()
cursor = conn.cursor()
if status == 'running' and pid:
cursor.execute(
"UPDATE fine_tune SET status = %s, process_id = %s WHERE id = %s",
(status, pid, task_id)
)
else:
cursor.execute(
"UPDATE fine_tune SET status = %s WHERE id = %s",
(status, task_id)
)
conn.commit()
conn.close()
except Exception as e:
logger.error(f"更新任务状态失败: {e}")
@fine_tune_bp.route('/stop/<int:task_id>', methods=['POST'])
def stop_training(task_id):
"""停止训练任务"""
try:
from .model_manage import get_db_connection
# 获取进程 ID
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT process_id FROM fine_tune WHERE id = %s", (task_id,))
result = cursor.fetchone()
conn.close()
if result and result.get('process_id'):
pid = result['process_id']
try:
# 尝试终止进程
import signal
os.kill(pid, signal.SIGTERM)
logger.info(f"已终止训练进程 PID: {pid}")
except ProcessLookupError:
logger.warning(f"进程 {pid} 不存在")
except PermissionError:
logger.error(f"没有权限终止进程 {pid}")
# 更新状态
update_fine_tune_status(task_id, 'stopped')
return jsonify({'code': 0, 'message': '训练任务已停止'})
except Exception as e:
logger.error(f"停止训练任务失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
@fine_tune_bp.route('/status/<int:task_id>', methods=['GET'])
def get_training_status(task_id):
"""获取训练任务状态"""
try:
from .model_manage import get_db_connection
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
"SELECT id, name, status, progress, process_id FROM fine_tune WHERE id = %s",
(task_id,)
)
result = cursor.fetchone()
conn.close()
if result:
return jsonify({
'code': 0,
'data': {
'task_id': result['id'],
'name': result['name'],
'status': result['status'],
'progress': result['progress'],
'pid': result.get('process_id')
}
})
else:
return jsonify({'code': 1, 'message': '任务不存在'})
except Exception as e:
logger.error(f"获取任务状态失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
def get_db_connection():
"""获取数据库连接"""
import pymysql
import yaml
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
db_config = CONFIG['database']
return pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)

View File

@@ -6,8 +6,12 @@ import pymysql
import yaml
from flask import Blueprint, request, jsonify
# 获取项目根目录
# 获取项目根目录 - 优先使用环境变量,否则从文件路径计算
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 如果 PROJECT_ROOT 是 /app 或 /app/src/llamafactory则使用挂载路径
if PROJECT_ROOT in ('/app', '/app/src/llamafactory'):
PROJECT_ROOT = MOUNT_BASE
# 创建蓝图
model_manage_bp = Blueprint('model_manage', __name__, url_prefix='/api/model-manage')

View File

@@ -86,6 +86,15 @@ def setup_logger(name='app'):
datefmt='%H:%M:%S'
))
# 5. 训练日志处理器 - 专门记录训练输出
train_log_path = os.path.join(log_dir, 'train.log')
train_handler = RotatingFileHandler(train_log_path, maxBytes=100*1024*1024, backupCount=5, encoding='utf-8')
train_handler.setLevel(logging.INFO)
train_handler.setFormatter(logging.Formatter(
'[%(asctime)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
))
# 添加处理器到 logger
logger.addHandler(all_handler)
logger.addHandler(error_handler)
@@ -98,6 +107,13 @@ def setup_logger(name='app'):
request_logger.addHandler(request_handler)
request_logger.addHandler(console_handler)
# 为训练日志创建单独的 logger
train_logger = logging.getLogger('train')
train_logger.setLevel(logging.INFO)
train_logger.handlers.clear()
train_logger.addHandler(train_handler)
train_logger.addHandler(console_handler)
return logger
@@ -137,6 +153,7 @@ def init_database():
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(255) NOT NULL,
base_model VARCHAR(255),
template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等',
train_type VARCHAR(50),
train_method VARCHAR(50),
gpus JSON COMMENT 'GPU硬件选择支持多卡训练',
@@ -144,6 +161,7 @@ def init_database():
valid_split VARCHAR(50),
valid_ratio INT DEFAULT 10,
output_model_name VARCHAR(255),
process_id INT COMMENT '训练进程ID',
status VARCHAR(50) DEFAULT 'pending',
progress INT DEFAULT 0,
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
@@ -305,6 +323,44 @@ def init_database():
except Exception:
pass # 列已存在时不输出任何信息
# 为 fine_tune 表添加 template 列
try:
cursor.execute("ALTER TABLE fine_tune ADD COLUMN template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等'")
logger.debug("fine_tune 表添加 template 列成功")
except Exception:
pass # 列已存在时不输出任何信息
# 为 fine_tune 表添加 process_id 列
try:
cursor.execute("ALTER TABLE fine_tune ADD COLUMN process_id INT COMMENT '训练进程ID'")
logger.debug("fine_tune 表添加 process_id 列成功")
except Exception:
pass # 列已存在时不输出任何信息
# 为 fine_tune 表添加训练相关列
columns_to_add = [
("train_dataset_id", "INT COMMENT '训练数据集ID'"),
("valid_dataset_id", "INT COMMENT '验证数据集ID'"),
("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"),
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"),
("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"),
("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"),
("batch_size", "INT DEFAULT 1 COMMENT '批次大小'"),
("learning_rate", "FLOAT DEFAULT 0.0001 COMMENT '学习率'"),
("n_epochs", "FLOAT DEFAULT 1.0 COMMENT '训练轮数'"),
("max_length", "INT DEFAULT 512 COMMENT '最大长度'"),
("lora_alpha", "VARCHAR(10) DEFAULT '32' COMMENT 'LoRA alpha'"),
("lora_rank", "VARCHAR(10) DEFAULT '8' COMMENT 'LoRA rank'"),
("lora_dropout", "FLOAT DEFAULT 0.1 COMMENT 'LoRA dropout'"),
("valid_ratio", "INT DEFAULT 10 COMMENT '验证集比例'"),
]
for col_name, col_def in columns_to_add:
try:
cursor.execute(f"ALTER TABLE fine_tune ADD COLUMN {col_name} {col_def}")
logger.debug(f"fine_tune 表添加 {col_name} 列成功")
except Exception:
pass # 列已存在时不输出任何信息
# 插入默认管理员用户
cursor.execute("SELECT * FROM users WHERE username = 'admin'")
if not cursor.fetchone():
@@ -323,8 +379,8 @@ def init_database():
app = Flask(__name__)
app.config['SECRET_KEY'] = CONFIG['secret_key']
app.config['CORS_HEADERS'] = 'Content-Type'
# 使用字符串形式的 origins
CORS(app, origins="*", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"], supports_credentials=False)
# 允许所有来源
CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])
# 注册蓝图
register_blueprints(app)

View File

@@ -293,6 +293,92 @@
</div>
</div>
<!-- 选择训练模板 -->
<div class="mb-6">
<label class="block text-sm text-gray-600 mb-3">
<span class="text-red-500 mr-1">*</span>训练模板
</label>
<select name="template" id="templateSelect" class="form-select flex-1 max-w-md">
<optgroup label="Qwen 系列">
<option value="qwen">qwen (Qwen/Qwen2)</option>
<option value="qwen3">qwen3 (Qwen3)</option>
<option value="qwen3_nothink">qwen3_nothink (Qwen3-Thinking)</option>
<option value="qwen2_vl">qwen2_vl (Qwen2-VL)</option>
<option value="qwen3_vl">qwen3_vl (Qwen3-VL)</option>
<option value="qwen2_audio">qwen2_audio (Qwen2-Audio)</option>
<option value="qwen2_omni">qwen2_omni (Qwen2.5-Omni)</option>
<option value="qwen3_omni">qwen3_omni (Qwen3-Omni)</option>
</optgroup>
<optgroup label="LLaMA 系列">
<option value="llama">llama (LLaMA)</option>
<option value="llama2">llama2 (LLaMA 2)</option>
<option value="llama3">llama3 (LLaMA 3/3.3)</option>
<option value="llama4">llama4 (LLaMA 4)</option>
<option value="mllama">mllama (LLaMA 3.2 Vision)</option>
<option value="llava">llava (LLaVA-1.5)</option>
<option value="llava_next">llava_next (LLaVA-NeXT)</option>
<option value="llava_next_video">llava_next_video (LLaVA-NeXT-Video)</option>
</optgroup>
<optgroup label="DeepSeek 系列">
<option value="deepseek">deepseek (DeepSeek LLM/Code/MoE)</option>
<option value="deepseek3">deepseek3 (DeepSeek 3-3.2)</option>
<option value="deepseekr1">deepseekr1 (DeepSeek R1 Distill)</option>
</optgroup>
<optgroup label="GLM 系列">
<option value="glm4">glm4 (GLM-4/GLM-4-0414/GLM-Z1)</option>
<option value="glm4_moe">glm4_moe (GLM-4.5)</option>
<option value="glm4_5v">glm4_5v (GLM-4.5V)</option>
</optgroup>
<optgroup label="Gemma 系列">
<option value="gemma">gemma (Gemma/Gemma 2/CodeGemma)</option>
<option value="gemma2">gemma2 (Gemma 2)</option>
<option value="gemma3">gemma3 (Gemma 3)</option>
<option value="gemma3n">gemma3n (Gemma 3n)</option>
</optgroup>
<optgroup label="Phi 系列">
<option value="phi">phi (Phi-3/Phi-3.5)</option>
<option value="phi_small">phi_small (Phi-3-small)</option>
<option value="phi4_mini">phi4_mini (Phi-4-mini)</option>
<option value="phi4">phi4 (Phi-4)</option>
</optgroup>
<optgroup label="InternLM 系列">
<option value="intern2">intern2 (InternLM 2-3)</option>
<option value="intern_vl">intern_vl (InternVL 2.5-3.5)</option>
<option value="intern_s1">intern_s1 (Intern-S1-mini)</option>
</optgroup>
<optgroup label="Mistral 系列">
<option value="mistral">mistral (Mistral/Mixtral)</option>
<option value="ministral3">ministral3 (Ministral 3)</option>
</optgroup>
<optgroup label="其他系列">
<option value="yi">yi (Yi)</option>
<option value="baichuan">baichuan (Baichuan)</option>
<option value="falcon">falcon (Falcon)</option>
<option value="falcon_h1">falcon_h1 (Falcon H1)</option>
<option value="pixtral">pixtral (Pixtral)</option>
<option value="paligemma">paligemma (PaliGemma)</option>
<option value="minicpm_o">minicpm_o (MiniCPM-o-2.6)</option>
<option value="minicpm_v">minicpm_v (MiniCPM-V-2.6)</option>
<option value="seed_oss">seed_oss (Seed OSS)</option>
<option value="seed_coder">seed_coder (Seed Coder)</option>
<option value="kimi_vl">kimi_vl (Kimi-VL)</option>
<option value="hunyuan">hunyuan (Hunyuan)</option>
<option value="hunyuan_small">hunyuan_small (Hunyuan1.5)</option>
<option value="granite3">granite3 (Granite 3)</option>
<option value="granite4">granite4 (Granite 3-4)</option>
<option value="mimo">mimo (MiMo)</option>
<option value="mimo_v2">mimo_v2 (MiMo V2)</option>
<option value="lfm2">lfm2 (LFM 2.5)</option>
<option value="lfm2_vl">lfm2_vl (LFM 2.5 VL)</option>
<option value="bailing_v2">bailing_v2 (Ling 2.0)</option>
<option value="yuan">yuan (Yuan 2)</option>
<option value="ernie_nothink">ernie_nothink (ERNIE-4.5)</option>
<option value="gpt_oss">gpt_oss (GPT-OSS)</option>
</optgroup>
</select>
<p class="text-xs text-gray-400 mt-1">选择与您的模型匹配的对话模板,确保训练数据格式正确</p>
</div>
<!-- 训练方法 -->
<div class="mb-6">
<label class="block text-sm text-gray-600 mb-3">训练方法</label>
@@ -512,7 +598,7 @@
<span class="text-red-500 mr-1">*</span>训练集
</label>
<div class="flex items-center">
<select name="dataset_id" id="trainDatasetSelect" class="form-select flex-1 max-w-md">
<select name="train_dataset_id" id="trainDatasetSelect" class="form-select flex-1 max-w-md">
<option value="">请选择训练数据集</option>
</select>
<button type="button" class="ml-2 text-primary text-sm flex items-center hover:text-primary/80" onclick="loadDatasets()">
@@ -524,38 +610,6 @@
</div>
</div>
<!-- 验证集 -->
<div class="mb-6">
<label class="block text-sm text-gray-600 mb-3">验证集 <span class="text-red-500">*</span></label>
<div class="flex items-center space-x-6 mb-3">
<label class="flex items-center">
<input type="radio" name="valid_split" value="auto" checked class="mr-2" onchange="toggleValidSplit()">
<span class="text-sm">自动切分</span>
</label>
<label class="flex items-center">
<input type="radio" name="valid_split" value="custom" class="mr-2" onchange="toggleValidSplit()">
<span class="text-sm">选择数据集</span>
</label>
</div>
<div id="autoSplitSection" class="flex items-center">
<span class="text-sm text-gray-600 mr-2">从当前训练集随机分割</span>
<input type="number" name="valid_ratio" value="10" class="w-16 px-2 py-1 border border-gray-300 rounded text-sm text-center focus:border-primary focus:outline-none">
<span class="text-sm text-gray-600 ml-2">% 作为验证集</span>
</div>
<div id="customSplitSection" class="hidden">
<div class="flex items-center">
<select name="valid_dataset_id" id="validDatasetSelect" class="form-select flex-1 max-w-md">
<option value="">请选择验证数据集</option>
</select>
<button type="button" class="ml-2 text-primary text-sm flex items-center hover:text-primary/80" onclick="loadDatasets()">
<i class="fa fa-refresh"></i>
</button>
<button type="button" class="ml-3 bg-white border border-primary text-primary rounded px-3 py-1.5 text-sm hover:bg-primary/5" onclick="window.location.href='dataset-create.html?from=fine-tune'">
+ 新增数据集
</button>
</div>
</div>
</div>
</div>
<!-- 训练产出 -->
@@ -571,14 +625,19 @@
</div>
</div>
<!-- 模型加密 -->
<div class="mb-4">
<div class="flex items-center">
<span class="text-sm text-gray-600 mr-2">模型加密</span>
<span class="px-2 py-0.5 bg-green-100 text-green-700 text-xs rounded">安全升级</span>
<!-- 训练命令预览 -->
<div class="mt-4">
<div class="flex items-center mb-2">
<span class="text-sm font-medium text-gray-600">训练命令预览</span>
<button type="button" onclick="updateCommandPreview()" class="ml-2 px-2 py-0.5 bg-blue-50 text-blue-600 text-xs rounded hover:bg-blue-100">
<i class="fa fa-refresh mr-1"></i>刷新
</button>
</div>
<div class="bg-gray-900 rounded-lg p-3 overflow-x-auto">
<pre id="commandPreview" class="text-xs text-green-400 font-mono whitespace-pre-wrap break-all">请选择完整配置后查看预览命令</pre>
</div>
<p class="text-xs text-gray-400 mt-1">为保障您的数据安全,平台会为导出的模型文件开启 OSS 服务端加密</p>
</div>
</div>
<!-- 底部按钮 -->
@@ -591,11 +650,6 @@
取消
</a>
</div>
<div class="flex items-center text-sm">
<a href="#" class="text-primary hover:underline">训练费用 (预估)</a>
<span class="mx-2 text-gray-300">|</span>
<a href="#" class="text-primary hover:underline">计算详情</a>
</div>
</div>
</form>
</div>
@@ -648,6 +702,9 @@
// 加载GPU列表
loadGPUList();
// 初始化训练命令预览
initCommandPreview();
// 设置侧边栏当前页高亮
const currentPage = 'fine-tune';
document.querySelectorAll('.nav-link').forEach(link => {
@@ -683,20 +740,6 @@
}
}
// 切换验证集切分方式
function toggleValidSplit() {
const validSplit = document.querySelector('input[name="valid_split"]:checked').value;
const autoSection = document.getElementById('autoSplitSection');
const customSection = document.getElementById('customSplitSection');
if (validSplit === 'auto') {
autoSection.classList.remove('hidden');
customSection.classList.add('hidden');
} else {
autoSection.classList.add('hidden');
customSection.classList.remove('hidden');
}
}
// 切换训练方法 - 显示/隐藏LoRA参数
function toggleTrainMethod() {
const trainMethod = document.querySelector('input[name="train_method"]:checked').value;
@@ -782,12 +825,6 @@
trainSelect.innerHTML = '<option value="">请选择训练数据集</option>' +
result.data.map(d => `<option value="${d.id}">${d.name}</option>`).join('');
}
// 更新验证集下拉框
const validSelect = document.getElementById('validDatasetSelect');
if (validSelect) {
validSelect.innerHTML = '<option value="">请选择验证数据集</option>' +
result.data.map(d => `<option value="${d.id}">${d.name}</option>`).join('');
}
}
} catch (e) {
console.error('加载数据集失败:', e);
@@ -968,22 +1005,35 @@
async function submitForm() {
const form = document.getElementById('createForm');
const formData = new FormData(form);
const validSplit = formData.get('valid_split');
// 获取选中的GPU
const selectedGPUs = getSelectedGPUs();
// 收集训练参数
const trainParams = {
batch_size: parseInt(formData.get('batch_size')) || 1,
learning_rate: parseFloat(formData.get('learning_rate')) || 0.0001,
n_epochs: parseFloat(formData.get('n_epochs')) || 1.0,
eval_steps: parseInt(formData.get('eval_steps')) || 100,
lr_scheduler_type: formData.get('lr_scheduler_type') || 'cosine',
max_length: parseInt(formData.get('max_length')) || 512,
warmup_ratio: parseFloat(formData.get('warmup_ratio')) || 0.05,
weight_decay: parseFloat(formData.get('weight_decay')) || 0.01,
lora_alpha: formData.get('lora_alpha') || '32',
lora_dropout: parseFloat(formData.get('lora_dropout')) || 0.1,
lora_rank: formData.get('lora_rank') || '8'
};
const data = {
name: formData.get('name'),
base_model: formData.get('base_model'),
template: formData.get('template'),
train_type: formData.get('train_type'),
train_method: formData.get('train_method'),
gpus: selectedGPUs, // 添加GPU选择
gpus: selectedGPUs,
train_dataset_id: formData.get('train_dataset_id'),
valid_split: validSplit,
valid_ratio: parseInt(formData.get('valid_ratio')) || 10,
valid_dataset_id: validSplit === 'custom' ? formData.get('valid_dataset_id') : null,
output_model_name: formData.get('output_model_name'),
...trainParams,
status: 'pending',
progress: 0
};
@@ -1000,33 +1050,201 @@
showMessage('提示', '请选择基础模型', 'warning');
return;
}
if (!data.template) {
showMessage('提示', '请选择训练模板', 'warning');
return;
}
if (!data.train_dataset_id) {
showMessage('提示', '请选择训练集', 'warning');
return;
}
if (validSplit === 'custom' && !data.valid_dataset_id) {
showMessage('提示', '请选择验证集', 'warning');
return;
}
try {
const response = await fetch(`${API_BASE}/fine-tune`, {
// 第一步:创建训练任务记录
const createResponse = await fetch(`${API_BASE}/fine-tune`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(data)
});
const result = await response.json();
if (result.code === 0) {
showMessage('成功', '创建成功!', 'success', () => {
const createResult = await createResponse.json();
if (createResult.code !== 0) {
showMessage('错误', createResult.message || '创建任务失败', 'error');
return;
}
const taskId = createResult.id;
// 第二步:启动训练
const startData = {
task_id: taskId,
base_model: data.base_model,
template: data.template,
train_type: data.train_type,
train_method: data.train_method,
train_dataset_id: data.train_dataset_id,
output_model_name: data.output_model_name,
...trainParams
};
const startResponse = await fetch(`${API_BASE}/fine-tune/start`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(startData)
});
const startResult = await startResponse.json();
if (startResult.code === 0) {
const cmd = startResult.data?.command || '';
showMessage('成功', `训练任务已启动!<br><br><code class="text-xs bg-gray-100 p-1 rounded">${cmd}</code>`, 'success', () => {
window.location.href = 'main.html';
});
} else {
showMessage('错误', result.message || '创建失败', 'error');
// 更新任务状态为失败
await fetch(`${API_BASE}/fine-tune/${taskId}`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ status: 'failed' })
});
showMessage('错误', startResult.message || '启动训练失败', 'error');
}
} catch (error) {
showMessage('错误', '创建失败: ' + error.message, 'error');
showMessage('错误', '操作失败: ' + error.message, 'error');
}
}
// 生成训练命令预览
function buildCommandPreview() {
const form = document.getElementById('createForm');
const formData = new FormData(form);
// 获取选中的GPU
const selectedGPUs = getSelectedGPUs();
let gpuIds = '0';
if (selectedGPUs.length > 0) {
gpuIds = selectedGPUs.map(g => g.id.replace('gpu', '')).filter(g => /^\d+$/.test(g)).join(',');
}
// 获取模型路径
const baseModelSelect = form.querySelector('select[name="base_model"]');
let modelPath = formData.get('base_model') || '';
if (baseModelSelect && baseModelSelect.selectedOptions.length > 0) {
const selectedOption = baseModelSelect.selectedOptions[0];
const pathValue = selectedOption.getAttribute('data-path');
if (pathValue) {
modelPath = pathValue;
}
}
// 获取模板
const template = formData.get('template') || 'qwen3';
// 获取训练类型
const trainType = formData.get('train_type') || 'SFT';
const stageMap = { 'SFT': 'sft', 'DPO': 'dpo', 'CPT': 'cpt' };
// 获取训练方法
const trainMethod = formData.get('train_method') || 'lora';
const methodMap = { 'lora': 'lora', 'full': 'full' };
// 获取输出模型名称
const outputModelName = formData.get('output_model_name') || `${template}/${trainMethod}`;
const outputDir = outputModelName.startsWith('./') ? outputModelName : `./saves/${outputModelName}`;
// 获取数据集名称
const trainDatasetSelect = form.querySelector('select[name="train_dataset_id"]');
let datasetName = formData.get('train_dataset_id') || 'dataset_name';
if (trainDatasetSelect && trainDatasetSelect.selectedOptions.length > 0) {
const selectedOption = trainDatasetSelect.selectedOptions[0];
const datasetValue = selectedOption.getAttribute('data-name');
if (datasetValue) {
datasetName = datasetValue;
}
}
// 获取训练参数
const batchSize = parseInt(formData.get('batch_size')) || 1;
const learningRate = parseFloat(formData.get('learning_rate')) || 0.0001;
const nEpochs = parseFloat(formData.get('n_epochs')) || 1.0;
const maxLength = parseInt(formData.get('max_length')) || 512;
const warmupSteps = parseInt(formData.get('warmup_steps')) || 20;
const evalSteps = parseInt(formData.get('eval_steps')) || 100;
const gradientAccumulationSteps = parseInt(formData.get('gradient_accumulation_steps')) || 8;
const lrSchedulerType = formData.get('lr_scheduler_type') || 'cosine';
// LoRA参数
const loraAlpha = formData.get('lora_alpha') || '32';
const loraDropout = parseFloat(formData.get('lora_dropout')) || 0.1;
const loraRank = formData.get('lora_rank') || '8';
// 构建命令
let cmd = `CUDA_VISIBLE_DEVICES=${gpuIds} llamafactory-cli train \\\n`;
cmd += ` --stage ${stageMap[trainType] || 'sft'} \\\n`;
cmd += ` --do_train \\\n`;
cmd += ` --model_name_or_path ${modelPath} \\\n`;
cmd += ` --dataset ${datasetName} \\\n`;
cmd += ` --dataset_dir ./datasets \\\n`;
cmd += ` --template ${template} \\\n`;
cmd += ` --finetuning_type ${methodMap[trainMethod] || 'lora'} \\\n`;
// LoRA参数仅lora方法时显示
if (trainMethod === 'lora') {
cmd += ` --lora_alpha ${loraAlpha} \\\n`;
cmd += ` --lora_dropout ${loraDropout} \\\n`;
cmd += ` --lora_rank ${loraRank} \\\n`;
}
cmd += ` --output_dir ${outputDir} \\\n`;
cmd += ` --overwrite_cache \\\n`;
cmd += ` --overwrite_output_dir \\\n`;
cmd += ` --cutoff_len ${maxLength} \\\n`;
cmd += ` --preprocessing_num_workers 16 \\\n`;
cmd += ` --per_device_train_batch_size ${batchSize} \\\n`;
cmd += ` --per_device_eval_batch_size 1 \\\n`;
cmd += ` --gradient_accumulation_steps ${gradientAccumulationSteps} \\\n`;
cmd += ` --lr_scheduler_type ${lrSchedulerType} \\\n`;
cmd += ` --logging_steps 50 \\\n`;
cmd += ` --warmup_steps ${warmupSteps} \\\n`;
cmd += ` --save_steps 100 \\\n`;
cmd += ` --eval_steps ${evalSteps} \\\n`;
cmd += ` --learning_rate ${learningRate} \\\n`;
cmd += ` --num_train_epochs ${nEpochs}`;
return cmd;
}
// 更新命令预览
function updateCommandPreview() {
const preview = document.getElementById('commandPreview');
const cmd = buildCommandPreview();
preview.textContent = cmd;
}
// 监听表单变化自动更新预览
function initCommandPreview() {
const form = document.getElementById('createForm');
// 监听所有 input 和 select 的变化
const inputs = form.querySelectorAll('input, select');
inputs.forEach(input => {
input.addEventListener('change', () => setTimeout(updateCommandPreview, 100));
if (input.type === 'text' || input.type === 'number') {
input.addEventListener('input', () => setTimeout(updateCommandPreview, 100));
}
});
// 监听卡片式单选框的点击事件 (训练类型、训练方法)
document.querySelectorAll('.card-radio').forEach(card => {
card.addEventListener('click', () => setTimeout(updateCommandPreview, 100));
});
// 监听 GPU 卡片的点击事件
document.querySelectorAll('.gpu-card').forEach(card => {
card.addEventListener('click', () => setTimeout(updateCommandPreview, 100));
});
// 初始化时更新一次
setTimeout(updateCommandPreview, 500);
}
</script>
<!-- 自定义消息弹窗 -->

View File

@@ -417,9 +417,9 @@
}
}
// 页面加载时获取监控数据,并每5秒刷新
// 页面加载时获取监控数据,并每30秒刷新
fetchSystemMetrics();
setInterval(fetchSystemMetrics, 5000);
setInterval(fetchSystemMetrics, 30000);
// 各功能模块的表格配置
const tableConfigs = {

View File

@@ -517,7 +517,7 @@
if (select.options.length > 1) return;
try {
const response = await fetch(`${API_BASE}/local-models`);
const response = await fetch(`${API_BASE}/model-manage/local-models`);
const result = await response.json();
if (result.code === 0 && result.data && result.data.models) {