1. 增加了合并权重
2. 修改了一些列表展示的bug
This commit is contained in:
@@ -47,24 +47,32 @@ def generic_get_all(table_name, order_by='create_time DESC'):
|
||||
|
||||
def get_model_path_by_name(model_name):
|
||||
"""根据模型名称查询模型路径(用于获取基座模型路径)"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}")
|
||||
|
||||
try:
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 优先从训练任务表查询基座模型
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 尝试从fine_tune表查询...")
|
||||
cursor.execute("""
|
||||
SELECT base_model FROM fine_tune
|
||||
SELECT base_model, output_model_name FROM fine_tune
|
||||
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
|
||||
LIMIT 1
|
||||
""", (f'%/{model_name}', f'%{model_name}%'))
|
||||
ft_result = cursor.fetchone()
|
||||
logger.info(f"[DEBUG get_model_path_by_name] fine_tune查询结果: {ft_result}")
|
||||
|
||||
if ft_result and ft_result.get('base_model'):
|
||||
base_model_val = ft_result['base_model']
|
||||
logger.info(f"[DEBUG get_model_path_by_name] base_model_val: {base_model_val}")
|
||||
# 如果是数字ID,查询模型管理表获取路径
|
||||
if str(base_model_val).isdigit():
|
||||
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
|
||||
model_result = cursor.fetchone()
|
||||
logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果(数字ID): {model_result}")
|
||||
if model_result:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
@@ -76,12 +84,15 @@ def get_model_path_by_name(model_name):
|
||||
return base_model_val
|
||||
|
||||
# 如果训练任务表没找到,尝试从模型管理表按名称查询
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 尝试从model_manage表查询...")
|
||||
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
|
||||
result = cursor.fetchone()
|
||||
logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果: {result}")
|
||||
cursor.close()
|
||||
conn.close()
|
||||
if result:
|
||||
return result.get('path')
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 未找到任何匹配,返回None")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[ERROR] 查询模型路径失败: {e}")
|
||||
@@ -387,3 +398,138 @@ def get_trained_models():
|
||||
except Exception as e:
|
||||
logger.error(f"获取已训练模型列表失败: {e}")
|
||||
return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
# ============ 合并权重接口 ============
|
||||
|
||||
def _lookup_base_model_path(model_name):
    """Look up the base-model path recorded for a fine-tuned model.

    The ``fine_tune`` table is consulted first. Its ``base_model`` column
    holds either a numeric ``model_manage`` id (resolved to that row's
    ``path``) or the path itself. When that yields nothing, ``model_manage``
    is searched by exact name.

    Args:
        model_name: Name of the fine-tuned model.

    Returns:
        The base-model path, or None when no record matches.

    Raises:
        Whatever the database driver raises; the connection and cursor are
        closed before the exception propagates.
    """
    # NOTE(review): rows are read with ``.get``, so this assumes the
    # connection is configured with a dict-style cursor -- confirm.
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            base_model_path = None

            # Prefer the training-task table.
            cursor.execute("""
                SELECT base_model FROM fine_tune
                WHERE output_model_name LIKE %s OR output_model_name LIKE %s
                LIMIT 1
            """, (f'%/{model_name}', f'%{model_name}%'))
            ft_result = cursor.fetchone()

            if ft_result and ft_result.get('base_model'):
                base_model_val = ft_result['base_model']
                if str(base_model_val).isdigit():
                    # A numeric value is an id into model_manage.
                    cursor.execute(
                        "SELECT path FROM model_manage WHERE id = %s LIMIT 1",
                        (base_model_val,),
                    )
                    model_result = cursor.fetchone()
                    if model_result:
                        base_model_path = model_result.get('path')
                else:
                    base_model_path = base_model_val

            # Fall back to a lookup by name in the model-management table.
            if not base_model_path:
                cursor.execute(
                    "SELECT path FROM model_manage WHERE name = %s LIMIT 1",
                    (model_name,),
                )
                model_result = cursor.fetchone()
                if model_result:
                    base_model_path = model_result.get('path')

            return base_model_path
        finally:
            cursor.close()
    finally:
        # Always release the connection, including when a query fails.
        conn.close()


def _is_strictly_inside(root, candidate):
    """Return True when ``candidate`` normalizes to a path below ``root``.

    Both paths are normalized lexically (``..`` and duplicate separators
    are collapsed) before comparison; ``root`` itself does not count.
    """
    normalized_root = os.path.normpath(root)
    normalized_candidate = os.path.normpath(candidate)
    return (
        normalized_candidate != normalized_root
        and normalized_candidate.startswith(normalized_root + os.sep)
    )


@model_manage_bp.route('/merge', methods=['POST'])
def merge_model():
    """Merge model weights (fold a LoRA adapter into its base model).

    Expects a JSON object in the request body with:
        model_name: Name of the trained model (required).
        train_method: Training method; defaults to ``'lora'``.
        base_model_path: Base-model path; looked up in the database when
            it is missing or empty.

    The adapter is read from ``/app/base/saves/<train_method>/<model_name>``
    and the merged model is written to
    ``/app/base/local_trained_models/<model_name>`` by running
    ``llamafactory-cli export`` with a 600-second timeout.

    Returns:
        A JSON response. ``code`` is 0 on success (with ``data`` holding
        ``model_name`` and ``output_path``) and 1 on any failure, with
        ``message`` describing the outcome.
    """
    import logging
    import shutil
    import subprocess

    logger = logging.getLogger(__name__)

    # A missing, malformed or non-object body is treated as an empty
    # payload so the caller gets the normal validation error below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_name = data.get('model_name')
    train_method = data.get('train_method', 'lora')
    base_model_path = data.get('base_model_path')

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})

    # Both values come from the request and are interpolated into
    # filesystem paths; reject anything that escapes the expected roots.
    adapter_path = f"/app/base/saves/{train_method}/{model_name}"
    output_path = f"/app/base/local_trained_models/{model_name}"
    if not (
        _is_strictly_inside('/app/base/saves', adapter_path)
        and _is_strictly_inside('/app/base/local_trained_models', output_path)
    ):
        return jsonify({'code': 1, 'message': '模型名称或训练方法不合法'})

    logger.info(f"[MERGE] 开始合并模型: {model_name}, 方法: {train_method}")

    # Without an explicit base-model path, resolve it from the database.
    if not base_model_path:
        try:
            base_model_path = _lookup_base_model_path(model_name)
        except Exception as e:
            logger.error(f"[MERGE] 查询模型配置失败: {e}")
            return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})

        if not base_model_path:
            return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})

    # The trained adapter must already exist on disk.
    if not os.path.exists(adapter_path):
        return jsonify({'code': 1, 'message': f'训练模型不存在: {adapter_path}'})

    try:
        # Created inside the try so a filesystem error is reported through
        # the JSON error envelope instead of surfacing as an unhandled 500.
        os.makedirs(output_path, exist_ok=True)

        work_dir = '/app/base'
        env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}

        # llamafactory-cli is expected to be on PATH.
        cli_cmd = ['llamafactory-cli', 'export']

        # Portable presence check. A miss is only logged: the run below
        # will raise and be reported by the generic handler.
        if shutil.which('llamafactory-cli') is None:
            logger.warning("[MERGE] 未在 PATH 中找到 llamafactory-cli,仍尝试直接执行")

        export_args = [
            '--model_name_or_path', base_model_path,
            '--adapter_name_or_path', adapter_path,
            '--export_dir', output_path,
        ]

        logger.info(f"[MERGE] 执行合并命令: {' '.join(cli_cmd)} {' '.join(export_args)}")

        result = subprocess.run(
            cli_cmd + export_args,
            capture_output=True,
            text=True,
            timeout=600,
            cwd=work_dir,
            env=env,
        )

        logger.info(f"[MERGE] 命令返回码: {result.returncode}")
        logger.info(f"[MERGE] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
        logger.info(f"[MERGE] stderr: {result.stderr[:500] if result.stderr else 'empty'}")

        if result.returncode == 0:
            return jsonify({
                'code': 0,
                'message': f'模型权重已成功合并到 {output_path}',
                'data': {
                    'model_name': model_name,
                    'output_path': output_path,
                },
            })

        # Prefer stderr, then stdout, then a generic message.
        error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
        if not error_msg:
            error_msg = f'命令执行失败,返回码: {result.returncode}'
        return jsonify({'code': 1, 'message': f'合并失败: {error_msg}'})

    except subprocess.TimeoutExpired:
        logger.error("[MERGE] 合并超时")
        return jsonify({'code': 1, 'message': '合并超时,请稍后重试'})
    except Exception as e:
        logger.error(f"[MERGE] 合并异常: {str(e)}")
        return jsonify({'code': 1, 'message': f'合并异常: {str(e)}'})
|
||||
|
||||
Reference in New Issue
Block a user