模型开始训练界面以及查看日志功能完善

This commit is contained in:
2026-01-29 10:36:59 +08:00
parent a560d24e2f
commit e9e0e21e47
11 changed files with 2485 additions and 179 deletions

View File

@@ -3,16 +3,31 @@
调用 llamafactory-cli 执行训练任务
"""
import os
import sys
import subprocess
import json
import threading
import time
import signal
import yaml
from flask import Blueprint, request, jsonify
import logging
# 添加项目根目录到路径
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, PROJECT_ROOT)
# 加载配置
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
logger = logging.getLogger(__name__)
train_logger = logging.getLogger('train') # 专门的训练日志 logger输出到 train.log
# 从配置获取训练日志路径
TRAINING_LOGS_DIR = CONFIG.get('training_logs_path', '/app/base/training_logs')
# 创建蓝图
fine_tune_bp = Blueprint('fine_tune', __name__, url_prefix='/api/fine-tune')
@@ -72,21 +87,21 @@ def start_training():
train_logger.info(f"[TRAIN] 模型路径: {model_path}")
# 设置工作目录 llamafactory 目录
llamafactory_dir = '/app/src/llamafactory'
# 设置工作目录 llamafactory 目录
work_dir = '/app/base'
llamafactory_dir = '/app/base'
# 处理数据集文件:将数据集复制到 llamafactory 的 datasets 目录
# 数据集目录直接使用 /app/base/datasets(不再复制)
datasets_dir = '/app/base/datasets'
# 获取数据集名称(用于 --dataset 参数)
dataset_key = None
dataset_id = data.get('train_dataset_id')
try:
dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
except (ValueError, TypeError):
dataset_id_int = None
llamafactory_datasets_dir = os.path.join(llamafactory_dir, 'datasets')
os.makedirs(llamafactory_datasets_dir, exist_ok=True)
# 获取数据集名称(用于 --dataset 参数)
dataset_key = None
if dataset_id_int:
from .datasets import get_db_connection as get_dataset_conn
conn = get_dataset_conn()
@@ -94,43 +109,8 @@ def start_training():
cursor.execute("SELECT dm.name FROM dataset_manage dm WHERE dm.id = %s", (dataset_id_int,))
dataset_result = cursor.fetchone()
conn.close()
dataset_key = dataset_result['name'] if dataset_result else None
if dataset_key:
# 从 dataset_info.json 读取实际文件名
src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
actual_file_name = None
if os.path.exists(src_info_json):
import json as json_lib
with open(src_info_json, 'r', encoding='utf-8') as f:
dataset_info = json_lib.load(f)
if dataset_key in dataset_info:
actual_file_name = dataset_info[dataset_key].get('file_name')
train_logger.info(f"[TRAIN] 从 dataset_info.json 获取文件名: {dataset_key} -> {actual_file_name}")
# 复制数据集文件到 llamafactory 目录
if actual_file_name:
src_file = os.path.join('/app/base', 'datasets', actual_file_name)
dst_file = os.path.join(llamafactory_datasets_dir, actual_file_name)
if os.path.exists(src_file):
import shutil
shutil.copy2(src_file, dst_file)
train_logger.info(f"[TRAIN] 复制数据集文件: {src_file} -> {dst_file}")
else:
train_logger.warning(f"[TRAIN] 数据集文件不存在: {src_file}")
# 复制 dataset_info.json 到 llamafactory datasets 目录
src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
dst_info_json = os.path.join(llamafactory_datasets_dir, 'dataset_info.json')
try:
if os.path.exists(src_info_json):
shutil.copy2(src_info_json, dst_info_json)
train_logger.info(f"[TRAIN] 已复制 dataset_info.json 到 llamafactory 目录")
else:
train_logger.warning(f"[TRAIN] dataset_info.json 不存在: {src_info_json}")
except Exception as e:
train_logger.warning(f"[TRAIN] 复制 dataset_info.json 失败: {e}")
train_logger.info(f"[TRAIN] 数据集名称: {dataset_key}")
# 获取选中的 GPU 索引
gpus = data.get('gpus', [])
@@ -145,6 +125,9 @@ def start_training():
env = os.environ.copy()
env['CUDA_VISIBLE_DEVICES'] = cuda_devices
env['TF_CPP_MIN_LOG_LEVEL'] = '2' # 减少 TensorFlow 日志
env['LLAMAFACTORY_DIR'] = '/app/base' # 指定 llamafactory 根目录
env['PYTHONUNBUFFERED'] = '1' # 强制 Python 不缓冲输出,实时写入日志
env['TRANSFORMERS_VERBOSITY'] = 'INFO' # 设置 transformers 日志级别
# 构建 llamafactory-cli 命令(传入数据集名称用于 --dataset 参数)
cmd = build_train_command(data, model_path, dataset_key)
@@ -154,57 +137,93 @@ def start_training():
# 在返回的命令中显示 GPU 配置
cmd_str_with_gpu = f"CUDA_VISIBLE_DEVICES={cuda_devices} {cmd_str}"
# 生成训练日志文件路径(日期目录)
# 生成训练日志文件路径(存储在 logs 目录下的日期目录
from datetime import datetime
today = datetime.now().strftime('%Y-%m-%d')
task_id_str = str(data.get('task_id', 'unknown'))
log_dir = os.path.join(llamafactory_dir, 'logs', today)
train_output_log = os.path.join(log_dir, f'train_{task_id_str}.log')
now_str = datetime.now().strftime('%Y%m%d_%H%M%S') # 时间戳用于排序
task_id = data.get('task_id', 'unknown')
task_name = data.get('name', 'unknown')
# 工作目录设为 /app/base而非 llamafactory 目录)
work_dir = '/app/base'
# 使用 logs 目录下的日期子目录
training_logs_dir = os.path.join('/app/base/logs', today)
os.makedirs(training_logs_dir, exist_ok=True)
# 确保日志目录存在
os.makedirs(log_dir, exist_ok=True)
# 日志文件路径: logs/{日期}/{task_id}_{task_name}.log
log_file = os.path.join(training_logs_dir, f'{task_id}_{task_name}.log')
train_logger.info(f"[TRAIN] 启动训练进程...")
# 用于存储实际进程 PID
actual_pid = None
final_log_path = log_file
# 使用线程在后台运行训练进程
def run_training():
with open(train_output_log, 'w', encoding='utf-8') as log_file:
nonlocal actual_pid, final_log_path
# 从 data 中获取 template 和 train_method与 build_train_command 保持一致)
template = data.get('template', 'default')
train_method = data.get('train_method', 'lora')
# 创建输出目录(如果不存在)
output_model_name = data.get('output_model_name', f"{template}/{train_method}")
if not output_model_name.startswith('/'):
output_model_name = f"/app/base/saves/{output_model_name}"
output_dir = output_model_name
os.makedirs(output_dir, exist_ok=True)
train_logger.info(f"[TRAIN] 输出目录: {output_dir}")
train_logger.info(f"[TRAIN] 完整训练命令: {' '.join(cmd)}")
with open(log_file, 'w', encoding='utf-8') as f:
# 设置 cwd 为 /app但通过 LLAMAFACTORY_DIR 环境变量指定 llamafactory 位置
process = subprocess.Popen(
cmd,
cwd=llamafactory_dir,
stdout=log_file,
cwd=work_dir,
stdout=f,
stderr=subprocess.STDOUT,
env=env
)
train_logger.info(f"[TRAIN] 训练进程 PID: {process.pid}")
actual_pid = process.pid
train_logger.info(f"[TRAIN] 训练进程 PID: {actual_pid}")
train_logger.info(f"[TRAIN] 日志文件: {log_file}")
# 更新数据库中的 PID立即更新方便停止任务
update_fine_tune_status(task_id, 'running', actual_pid)
# 等待进程完成
process.wait()
train_logger.info(f"[TRAIN] 训练进程已结束,退出码: {process.returncode}")
# 更新任务状态
final_status = 'completed' if process.returncode == 0 else 'failed'
update_fine_tune_status(data.get('task_id'), final_status, process.pid)
update_fine_tune_status(task_id, final_status, actual_pid)
# 启动后台线程
training_thread = threading.Thread(target=run_training, daemon=True)
training_thread.start()
# 立即返回,不等待进程完成
pid = None # 此时还不知道实际 PID稍后可从日志获取
train_logger.info(f"[TRAIN] 训练任务已在后台启动")
train_logger.info(f"[TRAIN] 训练日志输出到: {train_output_log}")
# 等待 PID 并更新到数据库
for i in range(10): # 最多等待1秒
time.sleep(0.1)
if actual_pid:
break
# 更新任务状态为运行中
update_fine_tune_status(data.get('task_id'), 'running', 0)
# 立即返回,不等待进程完成
train_logger.info(f"[TRAIN] 训练任务已在后台启动PID: {actual_pid}")
train_logger.info(f"[TRAIN] 训练日志输出到: {log_file}")
return jsonify({
'code': 0,
'message': f'训练任务已启动 (GPU: {cuda_devices})',
'data': {
'task_id': data.get('task_id'),
'task_id': task_id,
'pid': actual_pid,
'gpu_ids': cuda_devices,
'command': cmd_str_with_gpu,
'log_file': train_output_log
'log_file': log_file,
'training_logs_dir': training_logs_dir
}
})
@@ -258,10 +277,11 @@ def build_train_command(data, model_path, dataset_name=None):
train_method = data.get('train_method', 'lora')
cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')])
# 输出目录
output_dir = data.get('output_model_name', f"./saves/{template}/{train_method}")
if not output_dir.startswith('./'):
output_dir = f"./saves/{output_dir}"
# 输出目录(确保是绝对路径)
output_model_name = data.get('output_model_name', f"{template}/{train_method}")
if not output_model_name.startswith('/'):
output_model_name = f"/app/base/saves/{output_model_name}"
output_dir = output_model_name
cmd.extend(['--output_dir', output_dir])
# 常用参数
@@ -274,10 +294,11 @@ def build_train_command(data, model_path, dataset_name=None):
'--per_device_eval_batch_size', '1',
'--gradient_accumulation_steps', str(data.get('gradient_accumulation_steps', 8)),
'--lr_scheduler_type', data.get('lr_scheduler_type', 'cosine'),
'--logging_steps', '50',
'--logging_steps', '5',
'--warmup_steps', str(data.get('warmup_steps', 20)),
'--save_steps', '100',
'--eval_steps', str(data.get('eval_steps', 100)),
'--save_steps', str(data.get('save_steps', 100)),
'--log_level', 'info', # 设置日志级别为 info
'--log_level_replica', 'info', # 设置副本日志级别
])
# 学习率
@@ -295,9 +316,10 @@ def build_train_command(data, model_path, dataset_name=None):
if data.get('max_samples'):
cmd.extend(['--max_samples', str(data.get('max_samples'))])
# 启用 TensorBoard 日志(用于可视化训练曲线)
cmd.append('--plot_loss')
# 其他选项
if data.get('plot_loss'):
cmd.append('--plot_loss')
if data.get('fp16'):
cmd.append('--fp16')
@@ -386,6 +408,57 @@ def stop_training(task_id):
return jsonify({'code': 1, 'message': str(e)})
@fine_tune_bp.route('/<int:task_id>', methods=['DELETE'])
def delete_training_task(task_id):
"""删除训练任务及对应的日志文件"""
try:
from .model_manage import get_db_connection
# 获取任务信息(用于删除日志文件)
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT name, process_id FROM fine_tune WHERE id = %s", (task_id,))
task_result = cursor.fetchone()
conn.close()
if not task_result:
return jsonify({'code': 1, 'message': '任务不存在'})
task_name = task_result.get('name', 'unknown')
# 删除日志文件 (logs/{日期}/{task_id}_{task_name}.log)
try:
from datetime import datetime
today = datetime.now().strftime('%Y-%m-%d')
# 可能的日志文件路径
log_paths = [
f'/app/base/logs/{today}/{task_id}_{task_name}.log',
f'/app/base/logs/{task_id}_{task_name}.log',
]
for log_path in log_paths:
if os.path.exists(log_path):
os.remove(log_path)
logger.info(f"已删除日志文件: {log_path}")
except Exception as log_err:
logger.warning(f"删除日志文件失败: {log_err}")
# 删除数据库中的任务记录
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("DELETE FROM fine_tune WHERE id = %s", (task_id,))
conn.commit()
conn.close()
logger.info(f"已删除训练任务 {task_id}: {task_name}")
return jsonify({'code': 0, 'message': '删除成功'})
except Exception as e:
logger.error(f"删除训练任务失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
@fine_tune_bp.route('/status/<int:task_id>', methods=['GET'])
def get_training_status(task_id):
"""获取训练任务状态"""
@@ -402,12 +475,25 @@ def get_training_status(task_id):
conn.close()
if result:
# 检查 PID 是否仍在运行
actual_status = result['status']
pid = result.get('process_id')
if pid and actual_status == 'running':
try:
# 检查进程是否存在
os.kill(pid, 0)
# 进程仍在运行
actual_status = 'running'
except (OSError, ProcessLookupError):
# 进程已结束,尝试更新状态
actual_status = 'completed' # 假设完成(实际可能失败)
return jsonify({
'code': 0,
'data': {
'task_id': result['id'],
'name': result['name'],
'status': result['status'],
'status': actual_status,
'progress': result['progress'],
'pid': result.get('process_id')
}
@@ -420,6 +506,254 @@ def get_training_status(task_id):
return jsonify({'code': 1, 'message': str(e)})
@fine_tune_bp.route('/check-pid/<int:pid>', methods=['GET'])
def check_pid_status(pid):
"""检查 PID 是否仍在运行"""
try:
if pid <= 0:
return jsonify({
'code': 0,
'data': {
'exists': False,
'message': '无效的 PID'
}
})
try:
# 发送信号 0 来检查进程是否存在(不会实际终止进程)
os.kill(pid, 0)
return jsonify({
'code': 0,
'data': {
'exists': True,
'message': '进程仍在运行'
}
})
except (OSError, ProcessLookupError):
# 进程不存在
return jsonify({
'code': 0,
'data': {
'exists': False,
'message': '进程已结束'
}
})
except Exception as e:
logger.error(f"检查 PID 状态失败: {e}")
return jsonify({
'code': 0,
'data': {
'exists': False,
'message': f'检查失败: {str(e)}'
}
})
@fine_tune_bp.route('/log/<int:task_id>', methods=['GET'])
def get_training_log(task_id):
"""获取训练任务日志内容(支持实时读取)"""
try:
from .model_manage import get_db_connection
# 获取任务信息和进程ID
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
"SELECT name, process_id, status FROM fine_tune WHERE id = %s",
(task_id,)
)
result = cursor.fetchone()
conn.close()
if not result:
return jsonify({'code': 1, 'message': '任务不存在'})
process_id = result.get('process_id')
task_name = result['name']
status = result['status']
if not process_id:
return jsonify({'code': 1, 'message': '任务尚未启动'})
# 构建日志文件路径 - 新格式: logs/{日期}/{task_id}_{task_name}.log
from datetime import datetime
today = datetime.now().strftime('%Y-%m-%d')
training_logs_dir = os.path.join('/app/base/logs', today)
# 查找日志文件 (新格式: {task_id}_{task_name}.log)
log_file = os.path.join(training_logs_dir, f'{task_id}_{task_name}.log')
if not os.path.exists(log_file):
# 如果没找到,返回空日志
return jsonify({
'code': 0,
'data': {
'content': '',
'status': status,
'message': '日志文件尚未创建'
}
})
# 读取日志文件内容
try:
with open(log_file, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
return jsonify({
'code': 0,
'data': {
'content': content,
'status': status,
'log_file': log_file
}
})
except Exception as e:
return jsonify({
'code': 0,
'data': {
'content': '',
'status': status,
'message': f'读取日志失败: {str(e)}'
}
})
except Exception as e:
logger.error(f"获取训练日志失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
import re
@fine_tune_bp.route('/progress/<int:task_id>', methods=['GET'])
def get_training_progress(task_id):
"""获取训练任务进度(从日志中解析 llamafactory 的进度信息)"""
try:
from .model_manage import get_db_connection
# 获取任务信息和进程ID
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
"SELECT name, process_id, status FROM fine_tune WHERE id = %s",
(task_id,)
)
result = cursor.fetchone()
conn.close()
if not result:
return jsonify({'code': 1, 'message': '任务不存在'})
process_id = result.get('process_id')
task_name = result['name']
status = result['status']
if not process_id:
return jsonify({
'code': 0,
'data': {
'progress': 0,
'step': '',
'eta': '',
'speed': '',
'status': status,
'message': '任务尚未启动'
}
})
# 构建日志文件路径 - 新格式: logs/{日期}/{task_id}_{task_name}.log
from datetime import datetime
today = datetime.now().strftime('%Y-%m-%d')
training_logs_dir = os.path.join('/app/base/logs', today)
# 查找日志文件 (新格式: {task_id}_{task_name}.log)
log_file = os.path.join(training_logs_dir, f'{task_id}_{task_name}.log')
# TensorBoard 日志目录(使用默认值)
tensorboard_log_dir = '/app/base/saves'
if not os.path.exists(log_file):
return jsonify({
'code': 0,
'data': {
'step': '',
'elapsed': '',
'eta': '',
'speed': '',
'status': status,
'message': '日志文件尚未创建',
'tensorboard_url': ''
}
})
# 读取日志文件最后部分,解析进度信息
try:
with open(log_file, 'r', encoding='utf-8', errors='ignore') as f:
# 读取最后 10KB 内容
f.seek(0, 2) # 跳到文件末尾
file_size = f.tell()
read_size = min(10240, file_size)
f.seek(max(0, file_size - read_size))
content = f.read()
# 匹配 llamafactory 进度格式: 52%|█████▏ | 17/33 [02:16<02:08, 8.04s/it]
progress_pattern = r'\s*(\d+)%\|[█░▌▋▒█▏▎▏▐▀■□▪▫‣▶➜➡→]+\s*\|\s*(\d+)/(\d+)\s+\[(\d+):(\d+)<(\d+):(\d+),\s*([\d.]+)s/it\]'
match = re.search(progress_pattern, content)
step_info = ''
elapsed = ''
eta = ''
speed = ''
message = '等待训练开始'
if match:
current_step = int(match.group(2))
total_steps = int(match.group(3))
elapsed_min = int(match.group(4))
elapsed_sec = int(match.group(5))
eta_min = int(match.group(6))
eta_sec = int(match.group(7))
speed_val = float(match.group(8))
step_info = f'{current_step}/{total_steps}'
elapsed = f'{elapsed_min:02d}:{elapsed_sec:02d}'
eta = f'{eta_min:02d}:{eta_sec:02d}'
speed = f'{speed_val}s/it'
message = '训练进行中'
return jsonify({
'code': 0,
'data': {
'step': step_info,
'elapsed': elapsed,
'eta': eta,
'speed': speed,
'status': status,
'message': message,
'tensorboard_log_dir': tensorboard_log_dir,
'tensorboard_url': ''
}
})
except Exception as e:
return jsonify({
'code': 0,
'data': {
'step': '',
'elapsed': '',
'eta': '',
'speed': '',
'status': status,
'message': f'读取进度失败: {str(e)}',
'tensorboard_log_dir': tensorboard_log_dir,
'tensorboard_url': ''
}
})
except Exception as e:
logger.error(f"获取训练进度失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
def get_db_connection():
"""获取数据库连接"""
import pymysql
@@ -441,3 +775,129 @@ def get_db_connection():
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
@fine_tune_bp.route('/check-name', methods=['GET'])
def check_task_name():
"""检查任务名称是否重复"""
try:
name = request.args.get('name', '').strip()
if not name:
return jsonify({'code': 1, 'message': '任务名称不能为空'})
# 验证任务名称格式:只能包含英文、数字、下划线
import re
if not re.match(r'^[a-zA-Z0-9_]+$', name):
return jsonify({'code': 1, 'message': '任务名称只能包含英文、数字和下划线'})
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT id FROM fine_tune WHERE name = %s", (name,))
result = cursor.fetchone()
conn.close()
if result:
return jsonify({
'code': 0,
'data': {
'exists': True,
'message': '任务名称已存在'
}
})
else:
return jsonify({
'code': 0,
'data': {
'exists': False,
'message': '任务名称可用'
}
})
except Exception as e:
logger.error(f"检查任务名称失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
# TensorBoard 服务进程
tensorboard_process = None
@fine_tune_bp.route('/tensorboard/start', methods=['POST'])
def start_tensorboard():
"""启动 TensorBoard 服务"""
global tensorboard_process
try:
import subprocess
import os
# 检查是否已有进程在运行
if tensorboard_process and tensorboard_process.poll() is None:
return jsonify({
'code': 0,
'data': {
'url': 'http://10.10.10.177:6006',
'status': 'already_running',
'message': 'TensorBoard 服务已运行'
}
})
# 获取日志目录
log_dir = '/app/base/saves'
# 检查目录是否存在
if not os.path.exists(log_dir):
return jsonify({'code': 1, 'message': f'日志目录不存在: {log_dir}'})
# 启动 TensorBoard后台运行
cmd = ['tensorboard', '--logdir', log_dir, '--port', '6006', '--bind_all']
tensorboard_process = subprocess.Popen(
cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
preexec_fn=os.setsid
)
logger.info(f"TensorBoard 服务已启动: {cmd}")
return jsonify({
'code': 0,
'data': {
'url': 'http://10.10.10.177:6006',
'status': 'started',
'message': 'TensorBoard 服务已启动'
}
})
except Exception as e:
logger.error(f"启动 TensorBoard 失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
@fine_tune_bp.route('/tensorboard/stop', methods=['POST'])
def stop_tensorboard():
"""停止 TensorBoard 服务"""
global tensorboard_process
try:
import subprocess
import signal
if tensorboard_process and tensorboard_process.poll() is None:
# 使用 os.killpg 终止进程组
try:
os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM)
except Exception:
pass
tensorboard_process = None
logger.info("TensorBoard 服务已停止")
return jsonify({
'code': 0,
'data': {
'status': 'stopped',
'message': 'TensorBoard 服务已停止'
}
})
except Exception as e:
logger.error(f"停止 TensorBoard 失败: {e}")
return jsonify({'code': 1, 'message': str(e)})

View File

@@ -169,3 +169,231 @@ def get_log_content():
})
except Exception as e:
return jsonify({'code': 1, 'message': f'读取日志文件失败: {str(e)}'})
# ============ 训练日志相关 API ============
# 训练日志保存在 logs/{日期} 目录下
TRAINING_LOGS_BASE_DIR = '/app/base/logs'
# 本地开发时的备用路径Windows
LOCAL_TRAINING_LOGS_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
# 添加调试日志
logs_logger.info(f"[DEBUG] TRAINING_LOGS_BASE_DIR: {TRAINING_LOGS_BASE_DIR}")
logs_logger.info(f"[DEBUG] LOCAL_TRAINING_LOGS_BASE_DIR: {LOCAL_TRAINING_LOGS_BASE_DIR}")
@logs_bp.route('/training-log-files', methods=['GET'])
def get_training_log_files():
"""获取训练日志文件列表 - 从 logs/{日期} 目录下的 .log 文件"""
try:
# 确定基础目录
logs_base_dir = TRAINING_LOGS_BASE_DIR
if not os.path.exists(logs_base_dir):
logs_base_dir = LOCAL_TRAINING_LOGS_BASE_DIR
logs_logger.info(f"[DEBUG] logs_base_dir: {logs_base_dir}, exists: {os.path.exists(logs_base_dir)}")
if not os.path.exists(logs_base_dir):
return jsonify({'code': 0, 'data': []})
# 遍历所有日期目录,收集训练日志文件
log_files = []
date_dirs = []
try:
# 获取所有日期目录(格式: YYYY-MM-DD
for item in os.listdir(logs_base_dir):
item_path = os.path.join(logs_base_dir, item)
if os.path.isdir(item_path):
# 验证是否为日期目录
try:
datetime.strptime(item, '%Y-%m-%d')
date_dirs.append(item)
except ValueError:
pass
except Exception as list_err:
logs_logger.error(f"[DEBUG] Failed to list base directory: {list_err}")
return jsonify({'code': 0, 'data': []})
# 按日期排序(最新的在前面)
date_dirs.sort(reverse=True)
logs_logger.info(f"[DEBUG] Date directories: {date_dirs}")
# 遍历每个日期目录,查找 .log 文件
for date_dir in date_dirs:
date_full_path = os.path.join(logs_base_dir, date_dir)
try:
files = os.listdir(date_full_path)
except Exception as list_err:
logs_logger.warning(f"[DEBUG] Failed to list {date_full_path}: {list_err}")
continue
for file_name in files:
if not file_name.endswith('.log'):
continue
file_path = os.path.join(date_full_path, file_name)
try:
size = os.path.getsize(file_path)
except Exception as size_err:
logs_logger.warning(f"[DEBUG] Failed to get size of {file_path}: {size_err}")
continue
# 文件名格式: {task_id}_{task_name}.log
# 例如: 889_testing.log
parts = file_name.replace('.log', '').split('_', 1)
if len(parts) >= 2:
task_id = parts[0]
task_name = parts[1]
try:
dt = datetime.strptime(date_dir, '%Y-%m-%d')
# 使用日期目录的时间作为排序键
sort_key = dt.timestamp()
display_date = date_dir
except:
sort_key = 0
display_date = date_dir
else:
task_id = 'unknown'
task_name = file_name.replace('.log', '')
sort_key = 0
display_date = date_dir
# 构建相对路径 (日期/文件名)
relative_path = f"{date_dir}/{file_name}"
log_files.append({
'name': task_name,
'file': relative_path,
'task_id': task_id,
'date': display_date,
'size': format_file_size(size),
'sort_key': sort_key
})
# 按时间戳排序(最新的在前面)
log_files.sort(key=lambda x: x['sort_key'] if x['sort_key'] else 0, reverse=True)
logs_logger.info(f"[DEBUG] Found {len(log_files)} training log files")
return jsonify({'code': 0, 'data': log_files})
except Exception as e:
logs_logger.error(f"[DEBUG] 获取训练日志列表失败: {e}")
return jsonify({'code': 1, 'message': f'获取训练日志列表失败: {str(e)}'})
@logs_bp.route('/training-log-content', methods=['GET'])
def get_training_log_content():
"""获取训练日志文件内容 - 从 logs/{日期}/ 目录"""
file_name = request.args.get('file')
if not file_name:
return jsonify({'code': 1, 'message': '缺少文件参数'})
logs_logger.info(f"[DEBUG] ============ get_training_log_content ============")
logs_logger.info(f"[DEBUG] file: {file_name}")
# 防止目录遍历攻击
file_name = file_name.replace('..', '').replace('//', '/')
# file 格式: 日期/文件名,例如: 2026-01-28/889_testing.log
# 解析日期和文件名
parts = file_name.split('/')
if len(parts) < 2:
return jsonify({'code': 1, 'message': '无效的文件路径格式'})
date_dir = parts[0]
log_file_name = '/'.join(parts[1:])
# 验证日期格式
try:
datetime.strptime(date_dir, '%Y-%m-%d')
except ValueError:
return jsonify({'code': 1, 'message': '无效的日期格式'})
# 确定基础目录
container_base_dir = TRAINING_LOGS_BASE_DIR # /app/base/logs
local_base_dir = LOCAL_TRAINING_LOGS_BASE_DIR # 项目目录下的 logs
container_full_path = os.path.join(container_base_dir, date_dir, log_file_name)
local_full_path = os.path.join(local_base_dir, date_dir, log_file_name)
logs_logger.info(f"[DEBUG] container_base_dir: {container_base_dir}, exists: {os.path.exists(container_base_dir)}")
logs_logger.info(f"[DEBUG] local_base_dir: {local_base_dir}, exists: {os.path.exists(local_base_dir)}")
logs_logger.info(f"[DEBUG] container_full_path: {container_full_path}, exists: {os.path.exists(container_full_path)}")
logs_logger.info(f"[DEBUG] local_full_path: {local_full_path}, exists: {os.path.exists(local_full_path)}")
# 选择最终路径
full_path = None
if os.path.exists(container_full_path):
full_path = container_full_path
logs_logger.info(f"[DEBUG] Using container path")
elif os.path.exists(local_full_path):
full_path = local_full_path
logs_logger.info(f"[DEBUG] Using local path")
else:
logs_logger.error(f"[DEBUG] File not found: {file_name}")
return jsonify({'code': 1, 'message': f'日志文件不存在: {file_name}'})
logs_logger.info(f"[DEBUG] Final full_path: {full_path}")
# 尝试直接读取文件
try:
max_size = 10 * 1024 * 1024
content = ''
read_success = False
try:
with open(full_path, 'rb') as f:
f.seek(0, 2)
size = f.tell()
f.seek(0)
if size > max_size:
f.seek(size - max_size)
content = '... (日志文件较大,已显示最后 10MB 内容) ...\n\n' + f.read().decode('utf-8', errors='ignore')
else:
content = f.read().decode('utf-8', errors='ignore')
read_success = True
except (PermissionError, OSError) as pe:
logs_logger.warning(f"[DEBUG] 直接读取失败: {pe},尝试共享模式读取")
import mmap
try:
with open(full_path, 'rb') as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
try:
f.seek(0, 2)
size = f.tell()
if size > max_size:
content = '... (日志文件较大,已显示最后 10MB 内容) ...\n\n' + \
mm[-max_size:].decode('utf-8', errors='ignore')
else:
content = mm[:].decode('utf-8', errors='ignore')
read_success = True
finally:
mm.close()
except Exception as e2:
logs_logger.error(f"[DEBUG] 共享模式读取失败: {e2}")
return jsonify({
'code': 2,
'message': f'日志文件正在被训练进程占用,训练结束后可查看完整内容',
'data': {
'file': log_file_name,
'size': format_file_size(0),
'content': ''
}
})
if read_success:
return jsonify({
'code': 0,
'data': {
'file': log_file_name,
'size': format_file_size(size),
'content': content
}
})
except Exception as e:
logs_logger.error(f"[DEBUG] 读取日志文件失败: {e}")
return jsonify({'code': 1, 'message': f'读取日志文件失败: {str(e)}'})

View File

@@ -193,3 +193,65 @@ def get_local_models():
except Exception as e:
logger.error(f"获取本地模型列表失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
# ============ 已训练模型列表接口 ============
@model_manage_bp.route('/trained-models', methods=['GET'])
def get_trained_models():
"""获取已训练模型列表(从/app/base/saves目录"""
import logging
logger = logging.getLogger(__name__)
try:
# 使用 /app/base/saves 目录(容器内路径)
saves_base_path = '/app/base/saves'
# 本地开发时的备用路径
local_saves_path = os.path.join(PROJECT_ROOT, 'saves')
# 选择存在的路径
base_path = saves_base_path if os.path.exists(saves_base_path) else local_saves_path
logger.info(f"[DEBUG] 已训练模型目录: {base_path}, exists: {os.path.exists(base_path)}")
models = []
if os.path.exists(base_path):
for item in os.listdir(base_path):
item_path = os.path.join(base_path, item)
if os.path.isdir(item_path):
# 检查是否是模板目录(包含训练方法的子目录)
sub_items = []
if os.path.exists(item_path):
for sub_item in os.listdir(item_path):
sub_path = os.path.join(item_path, sub_item)
if os.path.isdir(sub_path):
# 检查是否包含模型文件adapter_model.bin 或 pytorch_model.bin 等)
has_model = False
for f in os.listdir(sub_path):
if f.endswith('.bin') or f.endswith('.safetensors'):
has_model = True
break
if has_model:
sub_items.append({
'name': sub_item,
'path': sub_path
})
models.append({
'name': item,
'path': item_path,
'train_methods': sub_items
})
logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型")
return jsonify({
'code': 0,
'data': {
'models': models,
'base_path': base_path
}
})
except Exception as e:
logger.error(f"获取已训练模型列表失败: {e}")
return jsonify({'code': 1, 'message': str(e)})

View File

@@ -33,6 +33,9 @@ def load_config():
CONFIG = load_config()
# 训练日志路径
TRAINING_LOGS_DIR = CONFIG.get('training_logs_path', '/app/base/training_logs')
# ============ 日志系统配置 ============
LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
@@ -339,9 +342,10 @@ def init_database():
# 为 fine_tune 表添加训练相关列
columns_to_add = [
("description", "TEXT COMMENT '任务描述'"),
("train_dataset_id", "INT COMMENT '训练数据集ID'"),
("valid_dataset_id", "INT COMMENT '验证数据集ID'"),
("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"),
("save_steps", "INT DEFAULT 100 COMMENT '保存步数'"),
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"),
("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"),
("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"),
@@ -379,8 +383,18 @@ def init_database():
app = Flask(__name__)
app.config['SECRET_KEY'] = CONFIG['secret_key']
app.config['CORS_HEADERS'] = 'Content-Type'
# 允许所有来源
CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])
# 允许所有来源 - 支持跨域请求
CORS(app, resources={
r"/api/*": {
"origins": "*",
"methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
"allow_headers": ["Content-Type", "Authorization", "X-Requested-With"],
"expose_headers": ["Content-Length", "Content-Range"],
"supports_credentials": False,
"max_age": 86400 # 缓存预检请求结果 24 小时
}
}, vary_header=True)
# 注册蓝图
register_blueprints(app)
@@ -674,6 +688,168 @@ def get_fine_tune():
return jsonify({'code': 0, 'data': generic_get_all('fine_tune')})
@app.route('/api/fine-tune/<int:id>', methods=['GET'])
def get_fine_tune_by_id(id):
"""获取单个训练任务详情"""
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT * FROM fine_tune WHERE id = %s", (id,))
task = cursor.fetchone()
if not task:
cursor.close()
conn.close()
return jsonify({'code': 1, 'message': '任务不存在'})
# 获取列名并转换为字典get_db_connection已使用DictCursortask已是字典
if isinstance(task, dict):
task_dict = task
else:
columns = [desc[0] for desc in cursor.description]
task_dict = dict(zip(columns, task))
cursor.close()
conn.close()
# 处理 datetime 序列化
for key, value in task_dict.items():
if isinstance(value, datetime):
task_dict[key] = value.strftime('%Y-%m-%d %H:%M:%S')
return jsonify({'code': 0, 'data': task_dict})
except Exception as e:
return jsonify({'code': 1, 'message': str(e)})
@app.route('/api/fine-tune/progress/<int:id>', methods=['GET'])
def get_fine_tune_progress(id):
"""获取训练任务的进度(通过解析日志文件)"""
try:
# 获取任务信息
conn = get_db_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT id, process_id, name, status FROM fine_tune WHERE id = %s", (id,))
task = cursor.fetchone()
conn.close()
if not task:
return jsonify({'code': 1, 'message': '任务不存在'})
process_id = task.get('process_id')
task_name = task.get('name', '')
if not process_id:
return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}})
# 查找日志文件 - 优先使用容器路径,如果不存在则使用本地路径
training_logs_dir = TRAINING_LOGS_DIR
if not os.path.exists(training_logs_dir):
training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs')
if not os.path.exists(training_logs_dir):
return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}})
log_file = None
# 优先按 process_id 查找
for file_name in os.listdir(training_logs_dir):
if file_name.endswith('.log') and file_name.startswith(f'{process_id}_'):
log_file = os.path.join(training_logs_dir, file_name)
break
# 如果没找到,尝试按任务名称查找
if not log_file and task_name:
for file_name in os.listdir(training_logs_dir):
if file_name.endswith('.log') and task_name in file_name:
log_file = os.path.join(training_logs_dir, file_name)
break
if not log_file or not os.path.exists(log_file):
return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}})
# 读取日志文件内容
try:
with open(log_file, 'r', encoding='utf-8') as f:
content = f.read()
except Exception as e:
return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}})
# 解析进度
progress = 0
step_info = ''
speed_info = ''
eta_info = ''
import re
# 处理 Windows 格式的日志(\r 覆盖行),将 \r 替换为换行
content = content.replace('\r', '\n')
# 日志格式: " 3%|▎ | 1/33 [00:09<05:10, 9.69s/it]"
# 或: " 30%|███ | 10/33 [01:22<03:00, 7.86s/it]"
# 匹配 "数字%|进度条| step/total [elapsed<eta, speed]"
progress_pattern = re.compile(r'(\d+)%\s*[\|▌▊█\s]+\s*\|\s*(\d+)/(\d+)\s*\[(\d+):?(\d+)<(\d+):?(\d+),\s*([\d.]+\s*(?:it/s|s/it))\s*\]')
# 按行分割并从后往前搜索
lines = content.split('\n')
for line in reversed(lines):
line = line.strip()
match = progress_pattern.search(line)
if match:
progress = int(match.group(1))
current_step = match.group(2)
total_steps = match.group(3)
elapsed_min = match.group(4)
elapsed_sec = match.group(5)
eta_min = match.group(6)
eta_sec = match.group(7)
speed = match.group(8).strip()
step_info = f'{current_step}/{total_steps}'
eta_info = f'{eta_min}:{eta_sec}'
speed_info = speed
break
# 如果没有找到进度格式,尝试其他格式
if progress == 0:
for line in reversed(lines):
if 'Running training' in line or 'running training' in line:
# 训练刚开始
break
# 尝试匹配简化格式
simple_match = re.search(r'(\d+)%\s*\|\s*(\d+)/(\d+)', line)
if simple_match:
progress = int(simple_match.group(1))
step_info = f'{simple_match.group(2)}/{simple_match.group(3)}'
break
# 检查训练是否完成
status = task.get('status', 'unknown')
for line in reversed(lines):
if 'Training completed' in line or '训练完成' in line:
status = 'completed'
progress = 100
break
if 'error' in line.lower() or 'failed' in line.lower() or 'Error' in line:
if 'KeyboardInterrupt' not in line:
status = 'failed'
break
return jsonify({
'code': 0,
'data': {
'progress': progress,
'status': status,
'step': step_info,
'speed': speed_info,
'eta': eta_info
}
})
except Exception as e:
return jsonify({'code': 1, 'message': f'获取进度失败: {str(e)}'})
@app.route('/api/fine-tune', methods=['POST'])
def create_fine_tune():
data = request.json
@@ -690,6 +866,39 @@ def update_fine_tune(id):
@app.route('/api/fine-tune/<int:id>', methods=['DELETE'])
def delete_fine_tune(id):
# 删除前获取任务信息(用于删除日志文件)
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT process_id, name FROM fine_tune WHERE id = %s", (id,))
task_info = cursor.fetchone()
conn.close()
# 删除相关的日志文件
if task_info and task_info.get('process_id'):
from datetime import datetime
process_id = task_info['process_id']
task_name = task_info.get('name', 'unknown')
# 优先使用容器路径,如果不存在则使用本地路径
training_logs_dir = TRAINING_LOGS_DIR
if not os.path.exists(training_logs_dir):
training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs')
try:
if os.path.exists(training_logs_dir):
for file_name in os.listdir(training_logs_dir):
# 查找以 PID 开头的日志文件
if file_name.endswith('.log') and file_name.startswith(f'{process_id}_'):
log_file = os.path.join(training_logs_dir, file_name)
try:
os.remove(log_file)
print(f"[INFO] 已删除日志文件: {log_file}")
except Exception as e:
print(f"[WARN] 删除日志文件失败: {log_file}, 错误: {e}")
except Exception as e:
print(f"[WARN] 查找或删除日志文件时出错: {e}")
# 删除数据库记录
generic_delete('fine_tune', id)
return jsonify({'code': 0, 'message': '删除成功'})