1. 修改了一些bug
2. 做了一些调整,比如启动脚本,支持了tenmsorboard
This commit is contained in:
@@ -167,7 +167,9 @@ def start_training():
|
||||
train_method = data.get('train_method', 'lora')
|
||||
|
||||
# 创建输出目录(如果不存在)
|
||||
output_model_name = data.get('output_model_name', f"{template}/{train_method}")
|
||||
# 路径格式: /app/base/saves/{train_method}/{output_model_name}
|
||||
output_model_name = data.get('output_model_name', template)
|
||||
output_model_name = f"{train_method}/{output_model_name}"
|
||||
if not output_model_name.startswith('/'):
|
||||
output_model_name = f"/app/base/saves/{output_model_name}"
|
||||
output_dir = output_model_name
|
||||
@@ -278,7 +280,9 @@ def build_train_command(data, model_path, dataset_name=None):
|
||||
cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')])
|
||||
|
||||
# 输出目录(确保是绝对路径)
|
||||
output_model_name = data.get('output_model_name', f"{template}/{train_method}")
|
||||
# 路径格式: /app/base/saves/{train_method}/{output_model_name}
|
||||
output_model_name = data.get('output_model_name', template)
|
||||
output_model_name = f"{train_method}/{output_model_name}"
|
||||
if not output_model_name.startswith('/'):
|
||||
output_model_name = f"/app/base/saves/{output_model_name}"
|
||||
output_dir = output_model_name
|
||||
@@ -417,7 +421,12 @@ def delete_training_task(task_id):
|
||||
# 获取任务信息(用于删除日志文件)
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT name, process_id FROM fine_tune WHERE id = %s", (task_id,))
|
||||
# 尝试获取所有字段,如果tensorboard_log_dir不存在会报错
|
||||
try:
|
||||
cursor.execute("SELECT name, process_id, tensorboard_log_dir FROM fine_tune WHERE id = %s", (task_id,))
|
||||
except:
|
||||
# 如果列不存在,只获取基本字段
|
||||
cursor.execute("SELECT name, process_id FROM fine_tune WHERE id = %s", (task_id,))
|
||||
task_result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -425,6 +434,7 @@ def delete_training_task(task_id):
|
||||
return jsonify({'code': 1, 'message': '任务不存在'})
|
||||
|
||||
task_name = task_result.get('name', 'unknown')
|
||||
tensorboard_log_dir = task_result.get('tensorboard_log_dir', '') if 'tensorboard_log_dir' in task_result else ''
|
||||
|
||||
# 删除日志文件 (logs/{日期}/{task_id}_{task_name}.log)
|
||||
try:
|
||||
@@ -444,6 +454,17 @@ def delete_training_task(task_id):
|
||||
except Exception as log_err:
|
||||
logger.warning(f"删除日志文件失败: {log_err}")
|
||||
|
||||
# 删除 TensorBoard 进程(如果存在)
|
||||
global tensorboard_process
|
||||
if tensorboard_process and tensorboard_process.poll() is None:
|
||||
try:
|
||||
import signal
|
||||
os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM)
|
||||
tensorboard_process = None
|
||||
logger.info(f"已停止 TensorBoard 进程")
|
||||
except Exception as tb_err:
|
||||
logger.warning(f"停止 TensorBoard 失败: {tb_err}")
|
||||
|
||||
# 删除数据库中的任务记录
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -204,44 +204,67 @@ def get_trained_models():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
# 使用 /app/base/saves 目录(容器内路径)
|
||||
saves_base_path = '/app/base/saves'
|
||||
# 本地开发时的备用路径
|
||||
local_saves_path = os.path.join(PROJECT_ROOT, 'saves')
|
||||
# 多个可能的路径
|
||||
potential_paths = [
|
||||
'/app/base/saves', # 容器内路径
|
||||
os.path.join(PROJECT_ROOT, 'saves'), # 本地开发路径
|
||||
os.path.join(os.path.dirname(os.path.dirname(PROJECT_ROOT)), 'YG_FT_Base', 'saves'), # 上级目录
|
||||
]
|
||||
|
||||
# 选择存在的路径
|
||||
base_path = saves_base_path if os.path.exists(saves_base_path) else local_saves_path
|
||||
base_path = None
|
||||
for path in potential_paths:
|
||||
logger.info(f"[DEBUG] 检查路径: {path}, exists: {os.path.exists(path)}")
|
||||
if os.path.exists(path):
|
||||
base_path = path
|
||||
break
|
||||
|
||||
logger.info(f"[DEBUG] 已训练模型目录: {base_path}, exists: {os.path.exists(base_path)}")
|
||||
logger.info(f"[DEBUG] 最终使用的路径: {base_path}")
|
||||
|
||||
models = []
|
||||
if os.path.exists(base_path):
|
||||
for item in os.listdir(base_path):
|
||||
item_path = os.path.join(base_path, item)
|
||||
if os.path.isdir(item_path):
|
||||
# 检查是否是模板目录(包含训练方法的子目录)
|
||||
sub_items = []
|
||||
if os.path.exists(item_path):
|
||||
for sub_item in os.listdir(item_path):
|
||||
sub_path = os.path.join(item_path, sub_item)
|
||||
if os.path.isdir(sub_path):
|
||||
# 检查是否包含模型文件(adapter_model.bin 或 pytorch_model.bin 等)
|
||||
has_model = False
|
||||
for f in os.listdir(sub_path):
|
||||
if f.endswith('.bin') or f.endswith('.safetensors'):
|
||||
has_model = True
|
||||
break
|
||||
if has_model:
|
||||
sub_items.append({
|
||||
'name': sub_item,
|
||||
'path': sub_path
|
||||
})
|
||||
if base_path and os.path.exists(base_path):
|
||||
logger.info(f"[DEBUG] 遍历目录: {base_path}")
|
||||
try:
|
||||
# 路径结构: /app/base/saves/{train_method}/{model_name}/
|
||||
# train_method: lora, full, qlora, dpo, cpt 等
|
||||
|
||||
models.append({
|
||||
'name': item,
|
||||
'path': item_path,
|
||||
'train_methods': sub_items
|
||||
})
|
||||
for train_method in os.listdir(base_path):
|
||||
train_method_path = os.path.join(base_path, train_method)
|
||||
if not os.path.isdir(train_method_path):
|
||||
continue
|
||||
|
||||
logger.info(f"[DEBUG] 检查训练方法目录: {train_method}")
|
||||
model_count = 0
|
||||
|
||||
# 遍历模型文件夹
|
||||
for model_name in os.listdir(train_method_path):
|
||||
model_path = os.path.join(train_method_path, model_name)
|
||||
if not os.path.isdir(model_path):
|
||||
continue
|
||||
|
||||
# 检查是否有模型文件
|
||||
try:
|
||||
files = os.listdir(model_path)
|
||||
logger.info(f"[DEBUG] {train_method}/{model_name} 文件: {files[:5]}...")
|
||||
has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files)
|
||||
|
||||
if has_model:
|
||||
logger.info(f"[DEBUG] 找到模型: {train_method}/{model_name}")
|
||||
models.append({
|
||||
'name': model_name,
|
||||
'path': model_path,
|
||||
'train_methods': [{
|
||||
'name': train_method,
|
||||
'path': model_path
|
||||
}]
|
||||
})
|
||||
model_count += 1
|
||||
except Exception as file_err:
|
||||
logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}")
|
||||
|
||||
logger.info(f"[DEBUG] {train_method} 找到 {model_count} 个模型")
|
||||
|
||||
except Exception as list_err:
|
||||
logger.error(f"[DEBUG] 遍历目录失败: {list_err}")
|
||||
|
||||
logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型")
|
||||
|
||||
@@ -249,7 +272,7 @@ def get_trained_models():
|
||||
'code': 0,
|
||||
'data': {
|
||||
'models': models,
|
||||
'base_path': base_path
|
||||
'base_path': base_path or ''
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
|
||||
@@ -165,6 +165,7 @@ def init_database():
|
||||
valid_ratio INT DEFAULT 10,
|
||||
output_model_name VARCHAR(255),
|
||||
process_id INT COMMENT '训练进程ID',
|
||||
tensorboard_log_dir VARCHAR(255) COMMENT 'TensorBoard日志目录',
|
||||
status VARCHAR(50) DEFAULT 'pending',
|
||||
progress INT DEFAULT 0,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
Reference in New Issue
Block a user