模型开始训练界面以及查看日志功能完善
This commit is contained in:
215
src/main.py
215
src/main.py
@@ -33,6 +33,9 @@ def load_config():
|
||||
|
||||
# Load the application configuration once, at import time.
CONFIG = load_config()

# Directory holding the training logs. The 'training_logs_path' config key
# overrides the built-in default path.
TRAINING_LOGS_DIR = CONFIG.get(
    'training_logs_path',
    '/app/base/training_logs',
)

# ============ Logging system configuration ============
LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
|
||||
|
||||
@@ -339,9 +342,10 @@ def init_database():
|
||||
|
||||
# 为 fine_tune 表添加训练相关列
|
||||
columns_to_add = [
|
||||
("description", "TEXT COMMENT '任务描述'"),
|
||||
("train_dataset_id", "INT COMMENT '训练数据集ID'"),
|
||||
("valid_dataset_id", "INT COMMENT '验证数据集ID'"),
|
||||
("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"),
|
||||
("save_steps", "INT DEFAULT 100 COMMENT '保存步数'"),
|
||||
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"),
|
||||
("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"),
|
||||
("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"),
|
||||
@@ -379,8 +383,18 @@ def init_database():
|
||||
app = Flask(__name__)
app.config['SECRET_KEY'] = CONFIG['secret_key']
app.config['CORS_HEADERS'] = 'Content-Type'

# Allow every origin on the /api/* routes so cross-origin front-ends can call
# the API. CORS is initialised exactly once: the earlier one-line
# CORS(app, ...) call duplicated this configuration with fewer options and
# has been dropped in favour of this fuller one.
CORS(
    app,
    resources={
        r"/api/*": {
            "origins": "*",
            "methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
            "allow_headers": ["Content-Type", "Authorization", "X-Requested-With"],
            "expose_headers": ["Content-Length", "Content-Range"],
            "supports_credentials": False,
            "max_age": 86400,  # cache pre-flight responses for 24 hours
        },
    },
    vary_header=True,
)

# Register blueprints
register_blueprints(app)
|
||||
@@ -674,6 +688,168 @@ def get_fine_tune():
|
||||
return jsonify({'code': 0, 'data': generic_get_all('fine_tune')})
|
||||
|
||||
|
||||
@app.route('/api/fine-tune/<int:id>', methods=['GET'])
def get_fine_tune_by_id(id):
    """Return the details of a single fine-tune task as a JSON response.

    Args:
        id: Primary key of the ``fine_tune`` row, taken from the URL.

    Returns:
        ``{'code': 0, 'data': <row as dict>}`` when the task exists,
        ``{'code': 1, 'message': ...}`` when it does not exist or when any
        exception is raised while querying. ``datetime`` values in the row
        are rendered as ``'%Y-%m-%d %H:%M:%S'`` strings.
    """
    try:
        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT * FROM fine_tune WHERE id = %s", (id,))
                task = cursor.fetchone()
                # Rows are expected to be dicts already (the connection
                # presumably uses a dict cursor - verify against
                # get_db_connection); tuple rows are converted using the
                # column names while the cursor is still open.
                if task is not None and not isinstance(task, dict):
                    columns = [desc[0] for desc in cursor.description]
                    task = dict(zip(columns, task))
            finally:
                # Always release DB resources, including when the query raises.
                cursor.close()
        finally:
            conn.close()

        if not task:
            return jsonify({'code': 1, 'message': '任务不存在'})

        # datetime objects are formatted explicitly so the JSON output has a
        # fixed, predictable timestamp format.
        task_dict = {
            key: value.strftime('%Y-%m-%d %H:%M:%S') if isinstance(value, datetime) else value
            for key, value in task.items()
        }
        return jsonify({'code': 0, 'data': task_dict})
    except Exception as e:
        return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
def _empty_progress_payload(status):
    """Build the ``data`` payload returned when no progress can be determined."""
    return {'progress': 0, 'status': status, 'step': '', 'speed': '', 'eta': ''}


def _find_training_log(logs_dir, process_id, task_name):
    """Return the path of the task's ``.log`` file inside *logs_dir*, or ``None``.

    A file named ``<process_id>_*.log`` is preferred; otherwise the first
    ``.log`` file whose name contains *task_name* is used.
    """
    log_names = [name for name in os.listdir(logs_dir) if name.endswith('.log')]
    pid_prefix = f'{process_id}_'
    chosen = next((name for name in log_names if name.startswith(pid_prefix)), None)
    if chosen is None and task_name:
        chosen = next((name for name in log_names if task_name in name), None)
    return os.path.join(logs_dir, chosen) if chosen is not None else None


def _parse_training_progress(content, status):
    """Extract ``(progress, status, step, speed, eta)`` from raw log text.

    The log is scanned from the last line backwards for a progress-bar line
    such as::

        "  3%|▎         | 1/33 [00:09<05:10,  9.69s/it]"
        " 30%|███       | 10/33 [01:22<03:00,  7.86s/it]"

    *status* is the task status stored in the database; it is replaced by
    ``'completed'`` or ``'failed'`` when the log says so.
    """
    import re

    # "<pct>%|<bar>| <step>/<total> [<elapsed><<eta>, <speed>". The bar is
    # matched as "anything but a pipe" so partial block glyphs (e.g. "▎")
    # are accepted; the time fields accept H:MM:SS and "?" placeholders.
    full_bar = re.compile(
        r'(\d+)%\s*\|[^|]*\|\s*(\d+)/(\d+)\s*'
        r'\[([\d:]+)<([\d:?]+),\s*([\d.?]+\s*(?:it/s|s/it))'
    )
    # Reduced form without the timing section: "<pct>%|[<bar>|] <step>/<total>".
    short_bar = re.compile(r'(\d+)%\s*\|(?:[^|]*\|)?\s*(\d+)/(\d+)')

    # Progress bars are redrawn with "\r"; treat each redraw as its own line.
    lines = content.replace('\r', '\n').split('\n')

    progress = 0
    step_info = ''
    speed_info = ''
    eta_info = ''

    bar_found = False
    for line in reversed(lines):
        match = full_bar.search(line.strip())
        if match:
            progress = int(match.group(1))
            step_info = f'{match.group(2)}/{match.group(3)}'
            eta_info = match.group(5)
            speed_info = match.group(6).strip()
            bar_found = True
            break

    if not bar_found:
        for line in reversed(lines):
            if 'Running training' in line or 'running training' in line:
                # Training has only just started: nothing to report yet.
                break
            simple_match = short_bar.search(line)
            if simple_match:
                progress = int(simple_match.group(1))
                step_info = f'{simple_match.group(2)}/{simple_match.group(3)}'
                break

    # The most recent line mentioning completion or an error decides the status.
    for line in reversed(lines):
        if 'Training completed' in line or '训练完成' in line:
            status = 'completed'
            progress = 100
            break
        lowered = line.lower()
        if 'error' in lowered or 'failed' in lowered:
            # A KeyboardInterrupt line does not mark the task as failed,
            # but it still ends the scan.
            if 'KeyboardInterrupt' not in line:
                status = 'failed'
            break

    return progress, status, step_info, speed_info, eta_info


@app.route('/api/fine-tune/progress/<int:id>', methods=['GET'])
def get_fine_tune_progress(id):
    """Report the progress of a fine-tune task by parsing its training log.

    Args:
        id: Primary key of the ``fine_tune`` row, taken from the URL.

    Returns:
        ``{'code': 0, 'data': {'progress', 'status', 'step', 'speed', 'eta'}}``
        on success. ``progress`` is 0 with empty strings when the task has no
        process id, no log directory/file exists, or the log cannot be read.
        ``{'code': 1, 'message': ...}`` when the task does not exist or an
        unexpected exception occurs.
    """
    try:
        # Fetch the task row.
        conn = get_db_connection()
        try:
            # A plain cursor is used, as in the other fine-tune handlers;
            # tuple rows are converted to a dict so both row styles work.
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT id, process_id, name, status FROM fine_tune WHERE id = %s", (id,))
                task = cursor.fetchone()
                if task is not None and not isinstance(task, dict):
                    columns = [desc[0] for desc in cursor.description]
                    task = dict(zip(columns, task))
            finally:
                cursor.close()
        finally:
            conn.close()

        if not task:
            return jsonify({'code': 1, 'message': '任务不存在'})

        db_status = task.get('status', 'unknown')
        process_id = task.get('process_id')
        task_name = task.get('name', '')

        if not process_id:
            return jsonify({'code': 0, 'data': _empty_progress_payload(db_status)})

        # Prefer the configured log directory; fall back to the project-local one.
        training_logs_dir = TRAINING_LOGS_DIR
        if not os.path.exists(training_logs_dir):
            training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs')
        if not os.path.exists(training_logs_dir):
            return jsonify({'code': 0, 'data': _empty_progress_payload(db_status)})

        log_file = _find_training_log(training_logs_dir, process_id, task_name)
        if not log_file or not os.path.exists(log_file):
            return jsonify({'code': 0, 'data': _empty_progress_payload(db_status)})

        try:
            # errors='replace' keeps a stray undecodable byte from hiding
            # the whole log.
            with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
                content = f.read()
        except OSError:
            return jsonify({'code': 0, 'data': _empty_progress_payload(db_status)})

        progress, status, step_info, speed_info, eta_info = _parse_training_progress(content, db_status)

        return jsonify({
            'code': 0,
            'data': {
                'progress': progress,
                'status': status,
                'step': step_info,
                'speed': speed_info,
                'eta': eta_info
            }
        })
    except Exception as e:
        return jsonify({'code': 1, 'message': f'获取进度失败: {str(e)}'})
|
||||
|
||||
|
||||
@app.route('/api/fine-tune', methods=['POST'])
|
||||
def create_fine_tune():
|
||||
data = request.json
|
||||
@@ -690,6 +866,39 @@ def update_fine_tune(id):
|
||||
|
||||
@app.route('/api/fine-tune/<int:id>', methods=['DELETE'])
def delete_fine_tune(id):
    """Delete a fine-tune task together with its training log files.

    The task's ``process_id`` is looked up first so that every
    ``<process_id>_*.log`` file in the training-log directory can be removed.
    Log removal is best-effort: failures are printed and do not stop the
    database row from being deleted.

    Args:
        id: Primary key of the ``fine_tune`` row, taken from the URL.

    Returns:
        ``{'code': 0, 'message': '删除成功'}`` as a JSON response.
    """
    # Read the task info before the row disappears (needed to find its logs).
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT process_id, name FROM fine_tune WHERE id = %s", (id,))
            task_info = cursor.fetchone()
            # Accept tuple rows as well as dict rows.
            if task_info is not None and not isinstance(task_info, dict):
                columns = [desc[0] for desc in cursor.description]
                task_info = dict(zip(columns, task_info))
        finally:
            cursor.close()
    finally:
        # Release the connection even when the lookup raises.
        conn.close()

    process_id = task_info.get('process_id') if task_info else None

    # Remove the related log files.
    if process_id:
        # Prefer the configured log directory; fall back to the project-local one.
        training_logs_dir = TRAINING_LOGS_DIR
        if not os.path.exists(training_logs_dir):
            training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs')

        try:
            if os.path.exists(training_logs_dir):
                pid_prefix = f'{process_id}_'
                for file_name in os.listdir(training_logs_dir):
                    # Only log files whose name starts with the PID belong to this task.
                    if not (file_name.endswith('.log') and file_name.startswith(pid_prefix)):
                        continue
                    log_file = os.path.join(training_logs_dir, file_name)
                    try:
                        os.remove(log_file)
                        print(f"[INFO] 已删除日志文件: {log_file}")
                    except OSError as e:
                        print(f"[WARN] 删除日志文件失败: {log_file}, 错误: {e}")
        except Exception as e:
            # Best-effort cleanup: never block the record deletion below.
            print(f"[WARN] 查找或删除日志文件时出错: {e}")

    # Delete the database record.
    generic_delete('fine_tune', id)
    return jsonify({'code': 0, 'message': '删除成功'})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user