模型开始训练界面以及查看日志功能完善

This commit is contained in:
2026-01-29 10:36:59 +08:00
parent a560d24e2f
commit e9e0e21e47
11 changed files with 2485 additions and 179 deletions

View File

@@ -33,6 +33,9 @@ def load_config():
CONFIG = load_config()
# 训练日志路径
TRAINING_LOGS_DIR = CONFIG.get('training_logs_path', '/app/base/training_logs')
# ============ 日志系统配置 ============
LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
@@ -339,9 +342,10 @@ def init_database():
# 为 fine_tune 表添加训练相关列
columns_to_add = [
("description", "TEXT COMMENT '任务描述'"),
("train_dataset_id", "INT COMMENT '训练数据集ID'"),
("valid_dataset_id", "INT COMMENT '验证数据集ID'"),
("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"),
("save_steps", "INT DEFAULT 100 COMMENT '保存步数'"),
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"),
("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"),
("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"),
@@ -379,8 +383,18 @@ def init_database():
app = Flask(__name__)
app.config['SECRET_KEY'] = CONFIG['secret_key']
app.config['CORS_HEADERS'] = 'Content-Type'
# 允许所有来源
CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])
# 允许所有来源 - 支持跨域请求
CORS(app, resources={
r"/api/*": {
"origins": "*",
"methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
"allow_headers": ["Content-Type", "Authorization", "X-Requested-With"],
"expose_headers": ["Content-Length", "Content-Range"],
"supports_credentials": False,
"max_age": 86400 # 缓存预检请求结果 24 小时
}
}, vary_header=True)
# 注册蓝图
register_blueprints(app)
@@ -674,6 +688,168 @@ def get_fine_tune():
return jsonify({'code': 0, 'data': generic_get_all('fine_tune')})
@app.route('/api/fine-tune/<int:id>', methods=['GET'])
def get_fine_tune_by_id(id):
"""获取单个训练任务详情"""
try:
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT * FROM fine_tune WHERE id = %s", (id,))
task = cursor.fetchone()
if not task:
cursor.close()
conn.close()
return jsonify({'code': 1, 'message': '任务不存在'})
# 获取列名并转换为字典get_db_connection已使用DictCursortask已是字典
if isinstance(task, dict):
task_dict = task
else:
columns = [desc[0] for desc in cursor.description]
task_dict = dict(zip(columns, task))
cursor.close()
conn.close()
# 处理 datetime 序列化
for key, value in task_dict.items():
if isinstance(value, datetime):
task_dict[key] = value.strftime('%Y-%m-%d %H:%M:%S')
return jsonify({'code': 0, 'data': task_dict})
except Exception as e:
return jsonify({'code': 1, 'message': str(e)})
@app.route('/api/fine-tune/progress/<int:id>', methods=['GET'])
def get_fine_tune_progress(id):
"""获取训练任务的进度(通过解析日志文件)"""
try:
# 获取任务信息
conn = get_db_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT id, process_id, name, status FROM fine_tune WHERE id = %s", (id,))
task = cursor.fetchone()
conn.close()
if not task:
return jsonify({'code': 1, 'message': '任务不存在'})
process_id = task.get('process_id')
task_name = task.get('name', '')
if not process_id:
return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}})
# 查找日志文件 - 优先使用容器路径,如果不存在则使用本地路径
training_logs_dir = TRAINING_LOGS_DIR
if not os.path.exists(training_logs_dir):
training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs')
if not os.path.exists(training_logs_dir):
return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}})
log_file = None
# 优先按 process_id 查找
for file_name in os.listdir(training_logs_dir):
if file_name.endswith('.log') and file_name.startswith(f'{process_id}_'):
log_file = os.path.join(training_logs_dir, file_name)
break
# 如果没找到,尝试按任务名称查找
if not log_file and task_name:
for file_name in os.listdir(training_logs_dir):
if file_name.endswith('.log') and task_name in file_name:
log_file = os.path.join(training_logs_dir, file_name)
break
if not log_file or not os.path.exists(log_file):
return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}})
# 读取日志文件内容
try:
with open(log_file, 'r', encoding='utf-8') as f:
content = f.read()
except Exception as e:
return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}})
# 解析进度
progress = 0
step_info = ''
speed_info = ''
eta_info = ''
import re
# 处理 Windows 格式的日志(\r 覆盖行),将 \r 替换为换行
content = content.replace('\r', '\n')
# 日志格式: " 3%|▎ | 1/33 [00:09<05:10, 9.69s/it]"
# 或: " 30%|███ | 10/33 [01:22<03:00, 7.86s/it]"
# 匹配 "数字%|进度条| step/total [elapsed<eta, speed]"
progress_pattern = re.compile(r'(\d+)%\s*[\|▌▊█\s]+\s*\|\s*(\d+)/(\d+)\s*\[(\d+):?(\d+)<(\d+):?(\d+),\s*([\d.]+\s*(?:it/s|s/it))\s*\]')
# 按行分割并从后往前搜索
lines = content.split('\n')
for line in reversed(lines):
line = line.strip()
match = progress_pattern.search(line)
if match:
progress = int(match.group(1))
current_step = match.group(2)
total_steps = match.group(3)
elapsed_min = match.group(4)
elapsed_sec = match.group(5)
eta_min = match.group(6)
eta_sec = match.group(7)
speed = match.group(8).strip()
step_info = f'{current_step}/{total_steps}'
eta_info = f'{eta_min}:{eta_sec}'
speed_info = speed
break
# 如果没有找到进度格式,尝试其他格式
if progress == 0:
for line in reversed(lines):
if 'Running training' in line or 'running training' in line:
# 训练刚开始
break
# 尝试匹配简化格式
simple_match = re.search(r'(\d+)%\s*\|\s*(\d+)/(\d+)', line)
if simple_match:
progress = int(simple_match.group(1))
step_info = f'{simple_match.group(2)}/{simple_match.group(3)}'
break
# 检查训练是否完成
status = task.get('status', 'unknown')
for line in reversed(lines):
if 'Training completed' in line or '训练完成' in line:
status = 'completed'
progress = 100
break
if 'error' in line.lower() or 'failed' in line.lower() or 'Error' in line:
if 'KeyboardInterrupt' not in line:
status = 'failed'
break
return jsonify({
'code': 0,
'data': {
'progress': progress,
'status': status,
'step': step_info,
'speed': speed_info,
'eta': eta_info
}
})
except Exception as e:
return jsonify({'code': 1, 'message': f'获取进度失败: {str(e)}'})
@app.route('/api/fine-tune', methods=['POST'])
def create_fine_tune():
data = request.json
@@ -690,6 +866,39 @@ def update_fine_tune(id):
@app.route('/api/fine-tune/<int:id>', methods=['DELETE'])
def delete_fine_tune(id):
# 删除前获取任务信息(用于删除日志文件)
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT process_id, name FROM fine_tune WHERE id = %s", (id,))
task_info = cursor.fetchone()
conn.close()
# 删除相关的日志文件
if task_info and task_info.get('process_id'):
from datetime import datetime
process_id = task_info['process_id']
task_name = task_info.get('name', 'unknown')
# 优先使用容器路径,如果不存在则使用本地路径
training_logs_dir = TRAINING_LOGS_DIR
if not os.path.exists(training_logs_dir):
training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs')
try:
if os.path.exists(training_logs_dir):
for file_name in os.listdir(training_logs_dir):
# 查找以 PID 开头的日志文件
if file_name.endswith('.log') and file_name.startswith(f'{process_id}_'):
log_file = os.path.join(training_logs_dir, file_name)
try:
os.remove(log_file)
print(f"[INFO] 已删除日志文件: {log_file}")
except Exception as e:
print(f"[WARN] 删除日志文件失败: {log_file}, 错误: {e}")
except Exception as e:
print(f"[WARN] 查找或删除日志文件时出错: {e}")
# 删除数据库记录
generic_delete('fine_tune', id)
return jsonify({'code': 0, 'message': '删除成功'})