模型微调已经调通
增加了参数预览
This commit is contained in:
@@ -6,6 +6,7 @@ from .model_manage import model_manage_bp
|
||||
from .model_chat import model_chat_bp
|
||||
from .dimension import dimension_bp
|
||||
from .logs import logs_bp
|
||||
from .fine_tune import fine_tune_bp
|
||||
|
||||
# 注册所有蓝图
|
||||
def register_blueprints(app):
|
||||
@@ -15,3 +16,4 @@ def register_blueprints(app):
|
||||
app.register_blueprint(model_chat_bp)
|
||||
app.register_blueprint(dimension_bp)
|
||||
app.register_blueprint(logs_bp)
|
||||
app.register_blueprint(fine_tune_bp)
|
||||
|
||||
@@ -5,6 +5,7 @@ import io
|
||||
import os
|
||||
import time
|
||||
import zipfile
|
||||
import json
|
||||
from flask import Blueprint, request, jsonify, send_from_directory, Response
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
@@ -52,6 +53,45 @@ def allowed_file(filename):
|
||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
|
||||
|
||||
def update_dataset_info_json(dataset_name=None, actual_filename=None, remove_filename=None):
|
||||
"""更新 datasets/dataset_info.json 文件
|
||||
|
||||
Args:
|
||||
dataset_name: 数据集名称(用于作为 key)
|
||||
actual_filename: 实际保存的文件名(含时间戳前缀),用于 file_name
|
||||
remove_filename: 要移除的文件名,为None表示不移除
|
||||
"""
|
||||
info_path = os.path.join(DATASET_FOLDER, 'dataset_info.json')
|
||||
|
||||
# 读取现有配置
|
||||
dataset_info = {}
|
||||
if os.path.exists(info_path):
|
||||
try:
|
||||
with open(info_path, 'r', encoding='utf-8') as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception as e:
|
||||
print(f"读取 dataset_info.json 失败: {e}")
|
||||
|
||||
# 移除旧条目(根据移除的文件名)
|
||||
if remove_filename:
|
||||
key = os.path.splitext(remove_filename)[0]
|
||||
if key in dataset_info:
|
||||
del dataset_info[key]
|
||||
print(f"从 dataset_info.json 移除: {key}")
|
||||
|
||||
# 添加新条目
|
||||
if dataset_name and actual_filename:
|
||||
dataset_info[dataset_name] = {"file_name": actual_filename}
|
||||
print(f"更新 dataset_info.json: {dataset_name} -> {actual_filename}")
|
||||
|
||||
# 写入文件
|
||||
try:
|
||||
with open(info_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(dataset_info, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"写入 dataset_info.json 失败: {e}")
|
||||
|
||||
|
||||
def generic_get_by_id(table_name, id_val):
|
||||
"""通用按ID查询"""
|
||||
conn = get_db_connection()
|
||||
@@ -149,17 +189,40 @@ def delete_dataset(id):
|
||||
"""删除数据集"""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
# 获取文件路径列表
|
||||
cursor.execute("SELECT file_path FROM dataset_files WHERE dataset_id = %s", (id,))
|
||||
|
||||
# 获取数据集名称(用于从 dataset_info.json 移除条目)
|
||||
cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (id,))
|
||||
dataset_result = cursor.fetchone()
|
||||
dataset_name = dataset_result['name'] if dataset_result else None
|
||||
|
||||
# 获取文件信息列表(包含原始文件名)
|
||||
cursor.execute("SELECT file_name, file_path FROM dataset_files WHERE dataset_id = %s", (id,))
|
||||
files = cursor.fetchall()
|
||||
# 删除文件
|
||||
|
||||
for f in files:
|
||||
file_path = f.get('file_path')
|
||||
if file_path and os.path.exists(file_path):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except Exception as e:
|
||||
print(f"删除文件失败: {file_path}, {e}")
|
||||
# 尝试多个可能的路径
|
||||
paths_to_try = []
|
||||
if file_path:
|
||||
paths_to_try.append(file_path)
|
||||
# 尝试 PROJECT_ROOT 相对路径
|
||||
rel_path = file_path.replace('/app/base', PROJECT_ROOT, 1) if file_path.startswith('/app/base') else None
|
||||
if rel_path:
|
||||
paths_to_try.append(rel_path)
|
||||
|
||||
for path in paths_to_try:
|
||||
if path and os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
print(f"已删除文件: {path}")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"删除文件失败: {path}, {e}")
|
||||
|
||||
# 使用数据集名称从 dataset_info.json 移除条目
|
||||
if dataset_name:
|
||||
update_dataset_info_json(remove_filename=dataset_name)
|
||||
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM dataset_files WHERE dataset_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM dataset_manage WHERE id = %s", (id,))
|
||||
@@ -218,6 +281,12 @@ def upload_dataset_file(dataset_id):
|
||||
file.save(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
# 获取数据集名称用于作为 dataset_info.json 的 key
|
||||
dataset_name = dataset.get('name') if dataset else None
|
||||
|
||||
# 更新 dataset_info.json,使用数据集名称作为 key,实际保存的文件名作为 file_name
|
||||
update_dataset_info_json(dataset_name=dataset_name, actual_filename=new_filename)
|
||||
|
||||
# 获取文件扩展名(安全处理无扩展名的情况)
|
||||
parts = filename.rsplit('.', 1)
|
||||
ext = parts[1].lower() if len(parts) > 1 else 'unknown'
|
||||
@@ -304,8 +373,8 @@ def delete_dataset_file(file_id):
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取文件信息
|
||||
cursor.execute("SELECT dataset_id, file_path FROM dataset_files WHERE id = %s", (file_id,))
|
||||
# 获取文件信息(包含 dataset_id)
|
||||
cursor.execute("SELECT dataset_id, file_name, file_path FROM dataset_files WHERE id = %s", (file_id,))
|
||||
file_info = cursor.fetchone()
|
||||
|
||||
if not file_info:
|
||||
@@ -313,6 +382,13 @@ def delete_dataset_file(file_id):
|
||||
conn.close()
|
||||
return jsonify({'code': 1, 'message': '文件不存在'})
|
||||
|
||||
dataset_id = file_info['dataset_id']
|
||||
|
||||
# 获取数据集名称(用于从 dataset_info.json 移除条目)
|
||||
cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (dataset_id,))
|
||||
dataset_result = cursor.fetchone()
|
||||
dataset_name = dataset_result['name'] if dataset_result else None
|
||||
|
||||
# 删除物理文件
|
||||
file_path = file_info['file_path']
|
||||
if file_path and os.path.exists(file_path):
|
||||
@@ -324,8 +400,11 @@ def delete_dataset_file(file_id):
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM dataset_files WHERE id = %s", (file_id,))
|
||||
|
||||
# 使用数据集名称从 dataset_info.json 移除条目
|
||||
if dataset_name:
|
||||
update_dataset_info_json(remove_filename=dataset_name)
|
||||
|
||||
# 更新数据集的文件数量和大小
|
||||
dataset_id = file_info['dataset_id']
|
||||
cursor.execute("SELECT COUNT(*) as count, SUM(file_size) as total_size FROM dataset_files WHERE dataset_id = %s", (dataset_id,))
|
||||
result = cursor.fetchone()
|
||||
file_count = result['count'] or 0
|
||||
|
||||
443
src/api/fine_tune.py
Normal file
443
src/api/fine_tune.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
精调训练 API 路由
|
||||
调用 llamafactory-cli 执行训练任务
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from flask import Blueprint, request, jsonify
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
train_logger = logging.getLogger('train') # 专门的训练日志 logger,输出到 train.log
|
||||
|
||||
# 创建蓝图
|
||||
fine_tune_bp = Blueprint('fine_tune', __name__, url_prefix='/api/fine-tune')
|
||||
|
||||
|
||||
# 训练类型映射
|
||||
TRAIN_TYPE_MAP = {
|
||||
'SFT': 'sft',
|
||||
'DPO': 'dpo',
|
||||
'CPT': 'cpt'
|
||||
}
|
||||
|
||||
# 训练方法映射
|
||||
FINETUNING_TYPE_MAP = {
|
||||
'lora': 'lora',
|
||||
'full': 'full'
|
||||
}
|
||||
|
||||
|
||||
@fine_tune_bp.route('/start', methods=['POST'])
|
||||
def start_training():
|
||||
"""启动 llamafactory 训练任务"""
|
||||
try:
|
||||
data = request.json
|
||||
train_logger.info(f"[TRAIN] ========== 开始训练任务 ==========")
|
||||
train_logger.info(f"[TRAIN] 收到启动训练请求: base_model={data.get('base_model')}, train_dataset_id={data.get('train_dataset_id')}")
|
||||
|
||||
# 必填参数验证
|
||||
required_fields = ['base_model', 'template', 'train_dataset_id']
|
||||
for field in required_fields:
|
||||
if not data.get(field):
|
||||
return jsonify({'code': 1, 'message': f'缺少必要参数: {field}'})
|
||||
|
||||
# 获取模型信息
|
||||
model_path = data.get('base_model')
|
||||
# 尝试转换为整数
|
||||
try:
|
||||
model_id = int(model_path) if str(model_path).isdigit() else None
|
||||
except (ValueError, TypeError):
|
||||
model_id = None
|
||||
|
||||
if model_id:
|
||||
# 如果是 model_id,需要获取模型路径
|
||||
from .model_manage import get_db_connection
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT id, name, path FROM model_manage WHERE id = %s", (model_id,))
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
logger.info(f"模型查询结果: {result}")
|
||||
if result and result.get('path'):
|
||||
model_path = result['path']
|
||||
logger.info(f"从数据库获取的模型路径: {model_path}")
|
||||
else:
|
||||
return jsonify({'code': 1, 'message': '模型不存在或路径为空'})
|
||||
elif not model_path:
|
||||
return jsonify({'code': 1, 'message': f'模型路径为空'})
|
||||
|
||||
train_logger.info(f"[TRAIN] 模型路径: {model_path}")
|
||||
|
||||
# 设置工作目录为 llamafactory 目录
|
||||
llamafactory_dir = '/app/src/llamafactory'
|
||||
|
||||
# 处理数据集文件:将数据集复制到 llamafactory 的 datasets 目录
|
||||
dataset_id = data.get('train_dataset_id')
|
||||
try:
|
||||
dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
|
||||
except (ValueError, TypeError):
|
||||
dataset_id_int = None
|
||||
|
||||
llamafactory_datasets_dir = os.path.join(llamafactory_dir, 'datasets')
|
||||
os.makedirs(llamafactory_datasets_dir, exist_ok=True)
|
||||
|
||||
# 获取数据集名称(用于 --dataset 参数)
|
||||
dataset_key = None
|
||||
if dataset_id_int:
|
||||
from .datasets import get_db_connection as get_dataset_conn
|
||||
conn = get_dataset_conn()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT dm.name FROM dataset_manage dm WHERE dm.id = %s", (dataset_id_int,))
|
||||
dataset_result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
dataset_key = dataset_result['name'] if dataset_result else None
|
||||
|
||||
if dataset_key:
|
||||
# 从 dataset_info.json 读取实际文件名
|
||||
src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
|
||||
actual_file_name = None
|
||||
if os.path.exists(src_info_json):
|
||||
import json as json_lib
|
||||
with open(src_info_json, 'r', encoding='utf-8') as f:
|
||||
dataset_info = json_lib.load(f)
|
||||
if dataset_key in dataset_info:
|
||||
actual_file_name = dataset_info[dataset_key].get('file_name')
|
||||
train_logger.info(f"[TRAIN] 从 dataset_info.json 获取文件名: {dataset_key} -> {actual_file_name}")
|
||||
|
||||
# 复制数据集文件到 llamafactory 目录
|
||||
if actual_file_name:
|
||||
src_file = os.path.join('/app/base', 'datasets', actual_file_name)
|
||||
dst_file = os.path.join(llamafactory_datasets_dir, actual_file_name)
|
||||
if os.path.exists(src_file):
|
||||
import shutil
|
||||
shutil.copy2(src_file, dst_file)
|
||||
train_logger.info(f"[TRAIN] 复制数据集文件: {src_file} -> {dst_file}")
|
||||
else:
|
||||
train_logger.warning(f"[TRAIN] 数据集文件不存在: {src_file}")
|
||||
|
||||
# 复制 dataset_info.json 到 llamafactory datasets 目录
|
||||
src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
|
||||
dst_info_json = os.path.join(llamafactory_datasets_dir, 'dataset_info.json')
|
||||
try:
|
||||
if os.path.exists(src_info_json):
|
||||
shutil.copy2(src_info_json, dst_info_json)
|
||||
train_logger.info(f"[TRAIN] 已复制 dataset_info.json 到 llamafactory 目录")
|
||||
else:
|
||||
train_logger.warning(f"[TRAIN] dataset_info.json 不存在: {src_info_json}")
|
||||
except Exception as e:
|
||||
train_logger.warning(f"[TRAIN] 复制 dataset_info.json 失败: {e}")
|
||||
|
||||
# 获取选中的 GPU 索引
|
||||
gpus = data.get('gpus', [])
|
||||
if gpus:
|
||||
gpu_ids = [gpu.get('id', '').replace('gpu', '') for gpu in gpus]
|
||||
gpu_ids = [g for g in gpu_ids if g.isdigit()]
|
||||
cuda_devices = ','.join(gpu_ids)
|
||||
else:
|
||||
cuda_devices = '0'
|
||||
|
||||
# 设置环境变量
|
||||
env = os.environ.copy()
|
||||
env['CUDA_VISIBLE_DEVICES'] = cuda_devices
|
||||
env['TF_CPP_MIN_LOG_LEVEL'] = '2' # 减少 TensorFlow 日志
|
||||
|
||||
# 构建 llamafactory-cli 命令(传入数据集名称用于 --dataset 参数)
|
||||
cmd = build_train_command(data, model_path, dataset_key)
|
||||
cmd_str = ' '.join(cmd)
|
||||
train_logger.info(f"[TRAIN] 执行训练命令: {cmd_str}")
|
||||
|
||||
# 在返回的命令中显示 GPU 配置
|
||||
cmd_str_with_gpu = f"CUDA_VISIBLE_DEVICES={cuda_devices} {cmd_str}"
|
||||
|
||||
# 生成训练日志文件路径(按日期分目录)
|
||||
from datetime import datetime
|
||||
today = datetime.now().strftime('%Y-%m-%d')
|
||||
task_id_str = str(data.get('task_id', 'unknown'))
|
||||
log_dir = os.path.join(llamafactory_dir, 'logs', today)
|
||||
train_output_log = os.path.join(log_dir, f'train_{task_id_str}.log')
|
||||
|
||||
# 确保日志目录存在
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
train_logger.info(f"[TRAIN] 启动训练进程...")
|
||||
|
||||
# 使用线程在后台运行训练进程
|
||||
def run_training():
|
||||
with open(train_output_log, 'w', encoding='utf-8') as log_file:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=llamafactory_dir,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=env
|
||||
)
|
||||
train_logger.info(f"[TRAIN] 训练进程 PID: {process.pid}")
|
||||
# 等待进程完成
|
||||
process.wait()
|
||||
train_logger.info(f"[TRAIN] 训练进程已结束,退出码: {process.returncode}")
|
||||
|
||||
# 更新任务状态
|
||||
final_status = 'completed' if process.returncode == 0 else 'failed'
|
||||
update_fine_tune_status(data.get('task_id'), final_status, process.pid)
|
||||
|
||||
# 启动后台线程
|
||||
training_thread = threading.Thread(target=run_training, daemon=True)
|
||||
training_thread.start()
|
||||
|
||||
# 立即返回,不等待进程完成
|
||||
pid = None # 此时还不知道实际 PID,稍后可从日志获取
|
||||
train_logger.info(f"[TRAIN] 训练任务已在后台启动")
|
||||
train_logger.info(f"[TRAIN] 训练日志输出到: {train_output_log}")
|
||||
|
||||
# 更新任务状态为运行中
|
||||
update_fine_tune_status(data.get('task_id'), 'running', 0)
|
||||
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'message': f'训练任务已启动 (GPU: {cuda_devices})',
|
||||
'data': {
|
||||
'task_id': data.get('task_id'),
|
||||
'gpu_ids': cuda_devices,
|
||||
'command': cmd_str_with_gpu,
|
||||
'log_file': train_output_log
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
train_logger.error(f"[TRAIN] 启动训练任务失败: {e}")
|
||||
train_logger.error(f"[TRAIN] 详细错误: {traceback.format_exc()}")
|
||||
return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
def build_train_command(data, model_path, dataset_name=None):
|
||||
"""构建 llamafactory-cli train 命令"""
|
||||
# llamafactory-cli 路径(已在系统 PATH 中)
|
||||
cmd = ['llamafactory-cli', 'train']
|
||||
|
||||
# 训练阶段
|
||||
train_type = data.get('train_type', 'SFT')
|
||||
cmd.extend(['--stage', TRAIN_TYPE_MAP.get(train_type, 'sft')])
|
||||
cmd.append('--do_train')
|
||||
|
||||
# 模型路径
|
||||
cmd.extend(['--model_name_or_path', model_path])
|
||||
|
||||
# 数据集 - 使用数据集名称(dataset_manage.name),不是实际文件名
|
||||
if dataset_name:
|
||||
cmd.extend(['--dataset', dataset_name])
|
||||
train_logger.info(f"[TRAIN] 使用数据集名称: {dataset_name}")
|
||||
else:
|
||||
# 回退到原有逻辑
|
||||
dataset_id = data.get('train_dataset_id')
|
||||
try:
|
||||
dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
|
||||
except (ValueError, TypeError):
|
||||
dataset_id_int = None
|
||||
|
||||
if dataset_id_int:
|
||||
dataset_name = get_dataset_name(dataset_id_int)
|
||||
train_logger.info(f"[TRAIN] 从数据库获取的数据集名称: {dataset_name}")
|
||||
else:
|
||||
dataset_name = dataset_id
|
||||
cmd.extend(['--dataset', dataset_name])
|
||||
|
||||
# 数据集目录
|
||||
cmd.extend(['--dataset_dir', './datasets']) # llamafactory 工作目录下的 datasets 目录
|
||||
|
||||
# 模板
|
||||
template = data.get('template')
|
||||
cmd.extend(['--template', template])
|
||||
|
||||
# 训练方法
|
||||
train_method = data.get('train_method', 'lora')
|
||||
cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')])
|
||||
|
||||
# 输出目录
|
||||
output_dir = data.get('output_model_name', f"./saves/{template}/{train_method}")
|
||||
if not output_dir.startswith('./'):
|
||||
output_dir = f"./saves/{output_dir}"
|
||||
cmd.extend(['--output_dir', output_dir])
|
||||
|
||||
# 常用参数
|
||||
cmd.extend([
|
||||
'--overwrite_cache',
|
||||
'--overwrite_output_dir',
|
||||
'--cutoff_len', str(data.get('max_length', 512)),
|
||||
'--preprocessing_num_workers', '16',
|
||||
'--per_device_train_batch_size', str(data.get('batch_size', 1)),
|
||||
'--per_device_eval_batch_size', '1',
|
||||
'--gradient_accumulation_steps', str(data.get('gradient_accumulation_steps', 8)),
|
||||
'--lr_scheduler_type', data.get('lr_scheduler_type', 'cosine'),
|
||||
'--logging_steps', '50',
|
||||
'--warmup_steps', str(data.get('warmup_steps', 20)),
|
||||
'--save_steps', '100',
|
||||
'--eval_steps', str(data.get('eval_steps', 100)),
|
||||
])
|
||||
|
||||
# 学习率
|
||||
cmd.extend(['--learning_rate', str(data.get('learning_rate', 0.0001))])
|
||||
|
||||
# 训练轮数
|
||||
cmd.extend(['--num_train_epochs', str(data.get('n_epochs', 1.0))])
|
||||
|
||||
# 验证集比例
|
||||
val_ratio = data.get('valid_ratio', 0)
|
||||
if val_ratio > 0:
|
||||
cmd.extend(['--val_size', str(val_ratio / 100)])
|
||||
|
||||
# 最大样本数
|
||||
if data.get('max_samples'):
|
||||
cmd.extend(['--max_samples', str(data.get('max_samples'))])
|
||||
|
||||
# 其他选项
|
||||
if data.get('plot_loss'):
|
||||
cmd.append('--plot_loss')
|
||||
|
||||
if data.get('fp16'):
|
||||
cmd.append('--fp16')
|
||||
|
||||
if data.get('load_best_model_at_end'):
|
||||
cmd.append('--load_best_model_at_end')
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def get_dataset_name(dataset_id):
|
||||
"""根据数据集 ID 获取数据集名称"""
|
||||
try:
|
||||
from .datasets import get_db_connection
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT id, name FROM dataset_manage WHERE id = %s", (dataset_id,))
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
logger.info(f"数据集查询结果: {result}")
|
||||
if result and result.get('name'):
|
||||
return result['name']
|
||||
logger.warning(f"未找到数据集 ID={dataset_id},使用默认值")
|
||||
return 'default'
|
||||
except Exception as e:
|
||||
logger.error(f"查询数据集失败: {e}")
|
||||
return 'default'
|
||||
|
||||
|
||||
def update_fine_tune_status(task_id, status, pid=None):
|
||||
"""更新训练任务状态"""
|
||||
try:
|
||||
from .model_manage import get_db_connection
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
if status == 'running' and pid:
|
||||
cursor.execute(
|
||||
"UPDATE fine_tune SET status = %s, process_id = %s WHERE id = %s",
|
||||
(status, pid, task_id)
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"UPDATE fine_tune SET status = %s WHERE id = %s",
|
||||
(status, task_id)
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logger.error(f"更新任务状态失败: {e}")
|
||||
|
||||
|
||||
@fine_tune_bp.route('/stop/<int:task_id>', methods=['POST'])
|
||||
def stop_training(task_id):
|
||||
"""停止训练任务"""
|
||||
try:
|
||||
from .model_manage import get_db_connection
|
||||
|
||||
# 获取进程 ID
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT process_id FROM fine_tune WHERE id = %s", (task_id,))
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
if result and result.get('process_id'):
|
||||
pid = result['process_id']
|
||||
try:
|
||||
# 尝试终止进程
|
||||
import signal
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
logger.info(f"已终止训练进程 PID: {pid}")
|
||||
except ProcessLookupError:
|
||||
logger.warning(f"进程 {pid} 不存在")
|
||||
except PermissionError:
|
||||
logger.error(f"没有权限终止进程 {pid}")
|
||||
|
||||
# 更新状态
|
||||
update_fine_tune_status(task_id, 'stopped')
|
||||
|
||||
return jsonify({'code': 0, 'message': '训练任务已停止'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"停止训练任务失败: {e}")
|
||||
return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
@fine_tune_bp.route('/status/<int:task_id>', methods=['GET'])
|
||||
def get_training_status(task_id):
|
||||
"""获取训练任务状态"""
|
||||
try:
|
||||
from .model_manage import get_db_connection
|
||||
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT id, name, status, progress, process_id FROM fine_tune WHERE id = %s",
|
||||
(task_id,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
if result:
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'data': {
|
||||
'task_id': result['id'],
|
||||
'name': result['name'],
|
||||
'status': result['status'],
|
||||
'progress': result['progress'],
|
||||
'pid': result.get('process_id')
|
||||
}
|
||||
})
|
||||
else:
|
||||
return jsonify({'code': 1, 'message': '任务不存在'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取任务状态失败: {e}")
|
||||
return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
def get_db_connection():
|
||||
"""获取数据库连接"""
|
||||
import pymysql
|
||||
import yaml
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
|
||||
|
||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||
CONFIG = yaml.safe_load(f)
|
||||
|
||||
db_config = CONFIG['database']
|
||||
return pymysql.connect(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['username'],
|
||||
password=db_config['password'],
|
||||
database=db_config['name'],
|
||||
charset=db_config.get('charset', 'utf8mb4'),
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
@@ -6,8 +6,12 @@ import pymysql
|
||||
import yaml
|
||||
from flask import Blueprint, request, jsonify
|
||||
|
||||
# 获取项目根目录
|
||||
# 获取项目根目录 - 优先使用环境变量,否则从文件路径计算
|
||||
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# 如果 PROJECT_ROOT 是 /app 或 /app/src/llamafactory,则使用挂载路径
|
||||
if PROJECT_ROOT in ('/app', '/app/src/llamafactory'):
|
||||
PROJECT_ROOT = MOUNT_BASE
|
||||
|
||||
# 创建蓝图
|
||||
model_manage_bp = Blueprint('model_manage', __name__, url_prefix='/api/model-manage')
|
||||
|
||||
60
src/main.py
60
src/main.py
@@ -86,6 +86,15 @@ def setup_logger(name='app'):
|
||||
datefmt='%H:%M:%S'
|
||||
))
|
||||
|
||||
# 5. 训练日志处理器 - 专门记录训练输出
|
||||
train_log_path = os.path.join(log_dir, 'train.log')
|
||||
train_handler = RotatingFileHandler(train_log_path, maxBytes=100*1024*1024, backupCount=5, encoding='utf-8')
|
||||
train_handler.setLevel(logging.INFO)
|
||||
train_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
))
|
||||
|
||||
# 添加处理器到 logger
|
||||
logger.addHandler(all_handler)
|
||||
logger.addHandler(error_handler)
|
||||
@@ -98,6 +107,13 @@ def setup_logger(name='app'):
|
||||
request_logger.addHandler(request_handler)
|
||||
request_logger.addHandler(console_handler)
|
||||
|
||||
# 为训练日志创建单独的 logger
|
||||
train_logger = logging.getLogger('train')
|
||||
train_logger.setLevel(logging.INFO)
|
||||
train_logger.handlers.clear()
|
||||
train_logger.addHandler(train_handler)
|
||||
train_logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@@ -137,6 +153,7 @@ def init_database():
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
base_model VARCHAR(255),
|
||||
template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等',
|
||||
train_type VARCHAR(50),
|
||||
train_method VARCHAR(50),
|
||||
gpus JSON COMMENT 'GPU硬件选择,支持多卡训练',
|
||||
@@ -144,6 +161,7 @@ def init_database():
|
||||
valid_split VARCHAR(50),
|
||||
valid_ratio INT DEFAULT 10,
|
||||
output_model_name VARCHAR(255),
|
||||
process_id INT COMMENT '训练进程ID',
|
||||
status VARCHAR(50) DEFAULT 'pending',
|
||||
progress INT DEFAULT 0,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
@@ -305,6 +323,44 @@ def init_database():
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加 template 列
|
||||
try:
|
||||
cursor.execute("ALTER TABLE fine_tune ADD COLUMN template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等'")
|
||||
logger.debug("fine_tune 表添加 template 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加 process_id 列
|
||||
try:
|
||||
cursor.execute("ALTER TABLE fine_tune ADD COLUMN process_id INT COMMENT '训练进程ID'")
|
||||
logger.debug("fine_tune 表添加 process_id 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加训练相关列
|
||||
columns_to_add = [
|
||||
("train_dataset_id", "INT COMMENT '训练数据集ID'"),
|
||||
("valid_dataset_id", "INT COMMENT '验证数据集ID'"),
|
||||
("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"),
|
||||
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"),
|
||||
("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"),
|
||||
("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"),
|
||||
("batch_size", "INT DEFAULT 1 COMMENT '批次大小'"),
|
||||
("learning_rate", "FLOAT DEFAULT 0.0001 COMMENT '学习率'"),
|
||||
("n_epochs", "FLOAT DEFAULT 1.0 COMMENT '训练轮数'"),
|
||||
("max_length", "INT DEFAULT 512 COMMENT '最大长度'"),
|
||||
("lora_alpha", "VARCHAR(10) DEFAULT '32' COMMENT 'LoRA alpha'"),
|
||||
("lora_rank", "VARCHAR(10) DEFAULT '8' COMMENT 'LoRA rank'"),
|
||||
("lora_dropout", "FLOAT DEFAULT 0.1 COMMENT 'LoRA dropout'"),
|
||||
("valid_ratio", "INT DEFAULT 10 COMMENT '验证集比例'"),
|
||||
]
|
||||
for col_name, col_def in columns_to_add:
|
||||
try:
|
||||
cursor.execute(f"ALTER TABLE fine_tune ADD COLUMN {col_name} {col_def}")
|
||||
logger.debug(f"fine_tune 表添加 {col_name} 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 插入默认管理员用户
|
||||
cursor.execute("SELECT * FROM users WHERE username = 'admin'")
|
||||
if not cursor.fetchone():
|
||||
@@ -323,8 +379,8 @@ def init_database():
|
||||
app = Flask(__name__)
|
||||
app.config['SECRET_KEY'] = CONFIG['secret_key']
|
||||
app.config['CORS_HEADERS'] = 'Content-Type'
|
||||
# 使用字符串形式的 origins
|
||||
CORS(app, origins="*", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"], supports_credentials=False)
|
||||
# 允许所有来源
|
||||
CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])
|
||||
|
||||
# 注册蓝图
|
||||
register_blueprints(app)
|
||||
|
||||
Reference in New Issue
Block a user