1. 修改了合并模型导出模型的逻辑

2. 修改了一些冗余的bug
3. 页面上表格的调整
This commit is contained in:
2026-01-29 23:10:21 +08:00
parent 0f98d67e41
commit 03b6071856
10 changed files with 1008 additions and 460 deletions

View File

@@ -4,6 +4,7 @@
import os
import pymysql
import yaml
import json
import requests
import concurrent.futures
import subprocess
@@ -211,6 +212,200 @@ def model_chat_batch():
return jsonify({'code': 0, 'data': results})
@model_chat_bp.route('/trained/preload', methods=['POST'])
def preload_trained_model():
"""预加载已训练模型(使用 llamafactory"""
import sys as sys_module
import pymysql
import yaml
import logging
logger = logging.getLogger(__name__)
data = request.json
model_name = data.get('model_name') # 模型名称
train_method = data.get('train_method', 'lora') # 训练方法: lora, full
base_model_path = data.get('base_model_path') # 前端传递的基座模型路径
if not model_name:
return jsonify({'code': 1, 'message': '缺少模型名称'})
logger.info(f"[PRELOAD] 开始预加载模型: {model_name}, 方法: {train_method}")
logger.info(f"[PRELOAD] 前端传递的基座模型路径: {base_model_path}")
# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
except Exception as e:
return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'})
# 优先使用前端传递的基座模型路径,否则从数据库查询
if not base_model_path:
try:
db_config = CONFIG['database']
conn = pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
cursor = conn.cursor()
# 优先从训练任务表查询基座模型
logger.info(f"[PRELOAD] 尝试从fine_tune表查询模型: {model_name}")
cursor.execute("""
SELECT base_model, output_model_name FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
logger.info(f"[PRELOAD] fine_tune查询结果: {ft_result}")
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
logger.info(f"[PRELOAD] base_model_val: {base_model_val}")
# 如果是数字ID查询模型管理表获取路径
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
logger.info(f"[PRELOAD] model_manage查询结果(数字ID): {model_result}")
if model_result:
base_model_path = model_result.get('path')
else:
# 直接是路径
base_model_path = base_model_val
# 如果训练任务表没找到,尝试从模型管理表按名称查询
if not base_model_path:
logger.info(f"[PRELOAD] 尝试从model_manage表查询...")
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
logger.info(f"[PRELOAD] model_manage查询结果: {model_result}")
if model_result:
base_model_path = model_result.get('path')
conn.close()
if not base_model_path:
logger.error(f"[PRELOAD] 未找到模型 {model_name} 的基座模型配置")
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
except Exception as e:
logger.error(f"[PRELOAD] 查询模型配置失败: {e}")
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
else:
logger.info(f"[PRELOAD] 使用前端传递的基座模型路径: {base_model_path}")
# 训练后的模型路径
trained_model_path = f"/app/base/saves/{train_method}/{model_name}"
# 检查路径是否存在兼容Windows和Linux
if not os.path.exists(trained_model_path):
logger.warning(f"[PRELOAD] 训练模型路径不存在: {trained_model_path}")
# 尝试查找适配器文件来确定模型是否存在
adapter_path = os.path.join(trained_model_path, 'adapter_model.bin')
safetensors_path = os.path.join(trained_model_path, 'model.safetensors')
if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
# 预热消息 - 使用一个简单的问候语来加载模型
work_dir = '/app/base'
warmup_messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello'}]
# 根据训练方法选择 finetuning_type
finetuning_type = 'lora' if train_method == 'lora' else 'full'
# 构建 llamafactory 预热脚本
preload_script = f'''
import sys
import json
import logging
logging.basicConfig(level=logging.WARNING)
from llmtuner import ChatModel
def main():
chat_model = ChatModel({{
"model_name_or_path": "{base_model_path}",
"adapter_name_or_path": "{trained_model_path}",
"template": "llama3",
"finetuning_type": "{finetuning_type}",
"temperature": 0.1,
"max_new_tokens": 1
}})
messages = {json.dumps(warmup_messages, ensure_ascii=False)}
# 执行推理以加载模型
response = chat_model.chat(messages)
print("Model loaded successfully")
if __name__ == "__main__":
main()
'''
try:
work_dir = '/app/base'
script_path = os.path.join(work_dir, 'temp_preload.py')
with open(script_path, 'w', encoding='utf-8') as f:
f.write(preload_script)
# 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory
import sys as sys_module
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
# 尝试查找 llamafactory 目录并添加到 PYTHONPATH
llm_factory_paths = ['/app/base', '/app', '/app/base/src/llamafactory']
for path in llm_factory_paths:
if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
env['PYTHONPATH'] = path
logger.info(f"[PRELOAD] 设置 PYTHONPATH={path}")
break
logger.info(f"[PRELOAD] 执行预热脚本...")
result = subprocess.run(
[sys_module.executable, 'temp_preload.py'],
capture_output=True,
text=True,
timeout=180,
cwd=work_dir,
env=env
)
os.remove(script_path)
logger.info(f"[PRELOAD] 命令返回码: {result.returncode}")
logger.info(f"[PRELOAD] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
logger.info(f"[PRELOAD] stderr: {result.stderr[:500] if result.stderr else 'empty'}")
if result.returncode == 0:
return jsonify({
'code': 0,
'message': '模型预加载成功',
'data': {
'model_name': model_name,
'train_method': train_method,
'base_model': base_model_path
}
})
else:
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
if not error_msg:
error_msg = f'命令执行失败,返回码: {result.returncode}'
return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'})
except subprocess.TimeoutExpired:
logger.error("[PRELOAD] 预加载超时")
return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'})
except Exception as e:
logger.error(f"[PRELOAD] 预加载异常: {str(e)}")
return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'})
@model_chat_bp.route('/trained', methods=['POST'])
def chat_trained_model():
"""使用已训练模型进行对话推理"""
@@ -220,6 +415,7 @@ def chat_trained_model():
data = request.json
model_name = data.get('model_name') # 模型名称
train_method = data.get('train_method', 'lora') # 训练方法: lora, full
base_model_path = data.get('base_model_path') # 前端传递的基座模型路径
system_prompt = data.get('system_prompt', '')
user_question = data.get('user_question')
temperature = data.get('temperature', 0.7)
@@ -236,39 +432,64 @@ def chat_trained_model():
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
# 获取基座模型路径 - 从数据库查询 model_name 对应的路径
try:
db_config = CONFIG['database']
conn = pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
cursor = conn.cursor()
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
conn.close()
# 优先使用前端传递的基座模型路径,否则从数据库查询
if not base_model_path:
try:
db_config = CONFIG['database']
conn = pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
cursor = conn.cursor()
if not model_result or not model_result.get('path'):
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'})
# 优先从训练任务表查询基座模型
cursor.execute("""
SELECT base_model, output_model_name FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
base_model_path = model_result['path']
except Exception as e:
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
# 如果是数字ID查询模型管理表获取路径
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
if model_result:
base_model_path = model_result.get('path')
else:
# 直接是路径
base_model_path = base_model_val
# 如果训练任务表没找到,尝试从模型管理表按名称查询
if not base_model_path:
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
if model_result:
base_model_path = model_result.get('path')
conn.close()
if not base_model_path:
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
except Exception as e:
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
# 训练后的模型路径
trained_model_path = f"/app/base/saves/{train_method}/{model_name}"
# 检查路径是否存在兼容Windows和Linux
if not os.path.exists(trained_model_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
# 构建 llamafactory-cli chat 命令
work_dir = '/app/base'
llamafactory_dir = '/app/base'
adapter_path = os.path.join(trained_model_path, 'adapter_model.bin')
safetensors_path = os.path.join(trained_model_path, 'model.safetensors')
if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
# 准备消息
messages = []
@@ -276,55 +497,76 @@ def chat_trained_model():
messages.append({'role': 'system', 'content': system_prompt})
messages.append({'role': 'user', 'content': user_question})
# 将消息转换为 JSON 字符串
messages_json = json.dumps(messages, ensure_ascii=False)
# 构建 llamafactory 推理脚本
inference_script = f'''
import sys
import json
import logging
logging.basicConfig(level=logging.WARNING)
# 构建 llamafactory-cli chat 命令
full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}'
from llmtuner import ChatModel
def main():
chat_model = ChatModel({{
"model_name_or_path": "{base_model_path}",
"adapter_name_or_path": "{trained_model_path}",
"template": "llama3",
"finetuning_type": "lora",
"temperature": {temperature},
"max_new_tokens": {max_tokens}
}})
messages = {json.dumps(messages, ensure_ascii=False)}
response = chat_model.chat(messages)
print(response)
if __name__ == "__main__":
main()
'''
# 写入临时脚本
work_dir = '/app/base'
script_path = os.path.join(work_dir, 'temp_inference.py')
try:
# 执行命令
with open(script_path, 'w', encoding='utf-8') as f:
f.write(inference_script)
# 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory
import sys as sys_module
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
for path in ['/app/base', '/app', '/app/base/src/llamafactory']:
if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
env['PYTHONPATH'] = path
break
# 执行推理脚本
result = subprocess.run(
full_cmd,
shell=True,
[sys_module.executable, 'temp_inference.py'],
capture_output=True,
text=True,
timeout=120,
cwd=work_dir
timeout=180,
cwd=work_dir,
env=env
)
output = result.stdout
error = result.stderr
# 清理临时脚本
os.remove(script_path)
# 解析输出提取assistant回复
# llamafactory-cli chat 输出格式通常是:
# <|im_start|>assistant
# xxx
# <|im_end|>
assistant_content = ''
if '<|im_start|>assistant' in output:
parts = output.split('<|im_start|>assistant')
if len(parts) > 1:
content_part = parts[1].split('<|im_end|>')[0].strip()
# 移除可能存在的换行前缀
content_part = content_part.lstrip('\n').strip()
assistant_content = content_part
elif result.returncode == 0:
# 如果没有特殊标记,尝试提取最后一部分作为回复
lines = output.strip().split('\n')
assistant_content = '\n'.join(lines).strip()
if result.returncode == 0:
assistant_content = result.stdout.strip()
return jsonify({
'code': 0,
'data': {
'model_name': model_name,
'train_method': train_method,
'response': assistant_content
}
})
else:
return jsonify({'code': 1, 'message': f'推理失败: {error or output}'})
return jsonify({
'code': 0,
'data': {
'model_name': model_name,
'train_method': train_method,
'response': assistant_content
}
})
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'})
except subprocess.TimeoutExpired:
return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})

View File

@@ -47,24 +47,32 @@ def generic_get_all(table_name, order_by='create_time DESC'):
def get_model_path_by_name(model_name):
"""根据模型名称查询模型路径(用于获取基座模型路径)"""
import logging
logger = logging.getLogger(__name__)
logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}")
try:
conn = get_db_connection()
cursor = conn.cursor()
# 优先从训练任务表查询基座模型
logger.info(f"[DEBUG get_model_path_by_name] 尝试从fine_tune表查询...")
cursor.execute("""
SELECT base_model FROM fine_tune
SELECT base_model, output_model_name FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
logger.info(f"[DEBUG get_model_path_by_name] fine_tune查询结果: {ft_result}")
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
logger.info(f"[DEBUG get_model_path_by_name] base_model_val: {base_model_val}")
# 如果是数字ID查询模型管理表获取路径
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果(数字ID): {model_result}")
if model_result:
cursor.close()
conn.close()
@@ -76,12 +84,15 @@ def get_model_path_by_name(model_name):
return base_model_val
# 如果训练任务表没找到,尝试从模型管理表按名称查询
logger.info(f"[DEBUG get_model_path_by_name] 尝试从model_manage表查询...")
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
result = cursor.fetchone()
logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果: {result}")
cursor.close()
conn.close()
if result:
return result.get('path')
logger.info(f"[DEBUG get_model_path_by_name] 未找到任何匹配返回None")
return None
except Exception as e:
logger.error(f"[ERROR] 查询模型路径失败: {e}")
@@ -377,6 +388,16 @@ def get_trained_models():
logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型")
# 检查每个模型是否已合并或正在合并
local_trained_path = os.path.join(PROJECT_ROOT, 'local_trained_models')
for model in models:
model_name = model['name']
merged_path = os.path.join(local_trained_path, model_name)
lock_file = os.path.join(local_trained_path, f'.merging_{model_name}.lock')
model['merged'] = os.path.exists(merged_path)
model['merging'] = os.path.exists(lock_file)
logger.info(f"[DEBUG] 模型 {model_name} 已合并: {model['merged']}, 正在合并: {model['merging']}")
return jsonify({
'code': 0,
'data': {
@@ -387,3 +408,264 @@ def get_trained_models():
except Exception as e:
logger.error(f"获取已训练模型列表失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
# ============ 合并权重接口 ============
@model_manage_bp.route('/merge', methods=['POST'])
def merge_model():
"""合并模型权重将LoRA适配器合并到基座模型"""
import subprocess
import sys
import logging
logger = logging.getLogger(__name__)
data = request.json
model_name = data.get('model_name') # 模型名称
train_method = data.get('train_method', 'lora') # 训练方法
base_model_path = data.get('base_model_path') # 基座模型路径
if not model_name:
return jsonify({'code': 1, 'message': '缺少模型名称'})
logger.info(f"[MERGE] 开始合并模型: {model_name}, 方法: {train_method}")
# 如果没有提供基座模型路径,从数据库查询
if not base_model_path:
try:
conn = get_db_connection()
cursor = conn.cursor()
# 优先从训练任务表查询
cursor.execute("""
SELECT base_model FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
if model_result:
base_model_path = model_result.get('path')
else:
base_model_path = base_model_val
# 如果没找到,尝试从模型管理表按名称查询
if not base_model_path:
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
if model_result:
base_model_path = model_result.get('path')
conn.close()
if not base_model_path:
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
except Exception as e:
logger.error(f"[MERGE] 查询模型配置失败: {e}")
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
# 训练后的模型路径LoRA适配器
adapter_path = f"/app/base/saves/{train_method}/{model_name}"
# 检查路径是否存在
if not os.path.exists(adapter_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {adapter_path}'})
# 合并后的输出路径
output_path = f"/app/base/local_trained_models/{model_name}"
# 合并状态锁文件
lock_file = f"/app/base/local_trained_models/.merging_{model_name}.lock"
# 创建输出目录
os.makedirs(output_path, exist_ok=True)
# 创建锁文件表示正在合并中
try:
with open(lock_file, 'w') as f:
f.write('merging')
work_dir = '/app/base'
# 设置环境变量
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
# 使用 llamafactory-cli export 命令(假设已在系统 PATH 中,与训练命令一致)
cli_cmd = ['llamafactory-cli', 'export']
# 检查 llamafactory-cli 是否存在
try:
# 尝试使用 which 命令Linux/Mac
subprocess.run(['which', 'llamafactory-cli'], capture_output=True, check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
# Windows 上没有 which 命令,直接尝试执行
logger.info("[MERGE] which 命令不可用,直接尝试执行 llamafactory-cli")
# 构建完整命令参数
export_args = [
'--model_name_or_path', base_model_path,
'--adapter_name_or_path', adapter_path,
'--export_dir', output_path
]
logger.info(f"[MERGE] 执行合并命令: {' '.join(cli_cmd)} {' '.join(export_args)}")
# 直接执行 llamafactory-cli export 命令
result = subprocess.run(
cli_cmd + export_args,
capture_output=True,
text=True,
timeout=600,
cwd=work_dir or '/app/base',
env=env
)
logger.info(f"[MERGE] 命令返回码: {result.returncode}")
logger.info(f"[MERGE] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
logger.info(f"[MERGE] stderr: {result.stderr[:500] if result.stderr else 'empty'}")
# 等待输出目录完全创建
import time
max_wait = 5 # 最多等待5秒
waited = 0
while not os.path.exists(output_path) and waited < max_wait:
time.sleep(0.5)
waited += 0.5
# 无论成功失败,都删除锁文件
if os.path.exists(lock_file):
os.remove(lock_file)
if result.returncode == 0:
# 确保目录存在才返回成功
if os.path.exists(output_path):
return jsonify({
'code': 0,
'message': f'模型权重已成功合并到 {output_path}',
'data': {
'model_name': model_name,
'output_path': output_path
}
})
else:
return jsonify({'code': 1, 'message': '合并失败:输出目录未创建'})
else:
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
if not error_msg:
error_msg = f'命令执行失败,返回码: {result.returncode}'
return jsonify({'code': 1, 'message': f'合并失败: {error_msg}'})
except subprocess.TimeoutExpired:
logger.error("[MERGE] 合并超时")
# 删除锁文件
if os.path.exists(lock_file):
os.remove(lock_file)
return jsonify({'code': 1, 'message': '合并超时,请稍后重试'})
except Exception as e:
logger.error(f"[MERGE] 合并异常: {str(e)}")
return jsonify({'code': 1, 'message': f'合并异常: {str(e)}'})
# ============ 删除已训练模型接口 ============
@model_manage_bp.route('/trained-models/<model_name>', methods=['DELETE'])
def delete_trained_model(model_name):
"""删除已训练模型从local_trained_models目录"""
import shutil
import logging
logger = logging.getLogger(__name__)
try:
# 删除 local_trained_models 目录下的模型
model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
if not os.path.exists(model_path):
return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'})
# 删除目录
shutil.rmtree(model_path)
logger.info(f"[DELETE] 已删除模型: {model_path}")
return jsonify({'code': 0, 'message': '删除成功'})
except Exception as e:
logger.error(f"[DELETE] 删除模型失败: {str(e)}")
return jsonify({'code': 1, 'message': f'删除失败: {str(e)}'})
# ============ 导出已训练模型接口 ============
@model_manage_bp.route('/trained-models/<model_name>/export', methods=['GET'])
def export_trained_model(model_name):
"""导出已训练模型打包成zip下载"""
import shutil
import logging
from flask import send_file
logger = logging.getLogger(__name__)
try:
# 优先从 local_trained_models 目录查找(合并后的模型)
model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
# 如果本地模型目录不存在,尝试从 saves 目录查找(未合并的模型)
if not os.path.exists(model_path):
# 查找 saves 目录下的模型
saves_path = os.path.join(PROJECT_ROOT, 'saves')
train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']
for method in train_methods:
potential_path = os.path.join(saves_path, method, model_name)
if os.path.exists(potential_path):
model_path = potential_path
logger.info(f"[EXPORT] 从 saves/{method} 目录找到模型: {model_path}")
break
# 如果还是找不到,返回错误
if not os.path.exists(model_path):
return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'})
# 创建临时 zip 文件
zip_path = os.path.join(PROJECT_ROOT, 'temp_exports')
os.makedirs(zip_path, exist_ok=True)
zip_file = os.path.join(zip_path, f'{model_name}.zip')
# 如果已存在先删除
if os.path.exists(zip_file):
os.remove(zip_file)
# 打包成 zip
shutil.make_archive(zip_file[:-4], 'zip', model_path)
logger.info(f"[EXPORT] 已打包模型: {zip_file}")
# 发送文件给前端
response = send_file(
zip_file,
as_attachment=True,
download_name=f'{model_name}.zip',
mimetype='application/zip'
)
# 注册回调,删除临时文件
def cleanup():
try:
if os.path.exists(zip_file):
os.remove(zip_file)
logger.info(f"[EXPORT] 已清理临时文件: {zip_file}")
except:
pass
# 使用 after_request 清理
@response.call_on_close
def cleanup_after_request():
cleanup()
return response
except Exception as e:
logger.error(f"[EXPORT] 导出模型失败: {str(e)}")
return jsonify({'code': 1, 'message': f'导出失败: {str(e)}'})