1. 增加了合并权重
2. 修改了一些列表展示的bug
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import os
|
||||
import pymysql
|
||||
import yaml
|
||||
import json
|
||||
import requests
|
||||
import concurrent.futures
|
||||
import subprocess
|
||||
@@ -211,6 +212,200 @@ def model_chat_batch():
|
||||
return jsonify({'code': 0, 'data': results})
|
||||
|
||||
|
||||
@model_chat_bp.route('/trained/preload', methods=['POST'])
|
||||
def preload_trained_model():
|
||||
"""预加载已训练模型(使用 llamafactory)"""
|
||||
import sys as sys_module
|
||||
import pymysql
|
||||
import yaml
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
data = request.json
|
||||
model_name = data.get('model_name') # 模型名称
|
||||
train_method = data.get('train_method', 'lora') # 训练方法: lora, full
|
||||
base_model_path = data.get('base_model_path') # 前端传递的基座模型路径
|
||||
|
||||
if not model_name:
|
||||
return jsonify({'code': 1, 'message': '缺少模型名称'})
|
||||
|
||||
logger.info(f"[PRELOAD] 开始预加载模型: {model_name}, 方法: {train_method}")
|
||||
logger.info(f"[PRELOAD] 前端传递的基座模型路径: {base_model_path}")
|
||||
|
||||
# 获取项目根目录
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
|
||||
try:
|
||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||
CONFIG = yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'})
|
||||
|
||||
# 优先使用前端传递的基座模型路径,否则从数据库查询
|
||||
if not base_model_path:
|
||||
try:
|
||||
db_config = CONFIG['database']
|
||||
conn = pymysql.connect(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['username'],
|
||||
password=db_config['password'],
|
||||
database=db_config['name'],
|
||||
charset=db_config.get('charset', 'utf8mb4'),
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 优先从训练任务表查询基座模型
|
||||
logger.info(f"[PRELOAD] 尝试从fine_tune表查询模型: {model_name}")
|
||||
cursor.execute("""
|
||||
SELECT base_model, output_model_name FROM fine_tune
|
||||
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
|
||||
LIMIT 1
|
||||
""", (f'%/{model_name}', f'%{model_name}%'))
|
||||
ft_result = cursor.fetchone()
|
||||
logger.info(f"[PRELOAD] fine_tune查询结果: {ft_result}")
|
||||
|
||||
if ft_result and ft_result.get('base_model'):
|
||||
base_model_val = ft_result['base_model']
|
||||
logger.info(f"[PRELOAD] base_model_val: {base_model_val}")
|
||||
# 如果是数字ID,查询模型管理表获取路径
|
||||
if str(base_model_val).isdigit():
|
||||
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
|
||||
model_result = cursor.fetchone()
|
||||
logger.info(f"[PRELOAD] model_manage查询结果(数字ID): {model_result}")
|
||||
if model_result:
|
||||
base_model_path = model_result.get('path')
|
||||
else:
|
||||
# 直接是路径
|
||||
base_model_path = base_model_val
|
||||
|
||||
# 如果训练任务表没找到,尝试从模型管理表按名称查询
|
||||
if not base_model_path:
|
||||
logger.info(f"[PRELOAD] 尝试从model_manage表查询...")
|
||||
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
|
||||
model_result = cursor.fetchone()
|
||||
logger.info(f"[PRELOAD] model_manage查询结果: {model_result}")
|
||||
if model_result:
|
||||
base_model_path = model_result.get('path')
|
||||
|
||||
conn.close()
|
||||
|
||||
if not base_model_path:
|
||||
logger.error(f"[PRELOAD] 未找到模型 {model_name} 的基座模型配置")
|
||||
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
|
||||
except Exception as e:
|
||||
logger.error(f"[PRELOAD] 查询模型配置失败: {e}")
|
||||
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
|
||||
else:
|
||||
logger.info(f"[PRELOAD] 使用前端传递的基座模型路径: {base_model_path}")
|
||||
|
||||
# 训练后的模型路径
|
||||
trained_model_path = f"/app/base/saves/{train_method}/{model_name}"
|
||||
|
||||
# 检查路径是否存在(兼容Windows和Linux)
|
||||
if not os.path.exists(trained_model_path):
|
||||
logger.warning(f"[PRELOAD] 训练模型路径不存在: {trained_model_path}")
|
||||
# 尝试查找适配器文件来确定模型是否存在
|
||||
adapter_path = os.path.join(trained_model_path, 'adapter_model.bin')
|
||||
safetensors_path = os.path.join(trained_model_path, 'model.safetensors')
|
||||
if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path):
|
||||
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
|
||||
|
||||
# 预热消息 - 使用一个简单的问候语来加载模型
|
||||
work_dir = '/app/base'
|
||||
warmup_messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello'}]
|
||||
|
||||
# 根据训练方法选择 finetuning_type
|
||||
finetuning_type = 'lora' if train_method == 'lora' else 'full'
|
||||
|
||||
# 构建 llamafactory 预热脚本
|
||||
preload_script = f'''
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
from llmtuner import ChatModel
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel({{
|
||||
"model_name_or_path": "{base_model_path}",
|
||||
"adapter_name_or_path": "{trained_model_path}",
|
||||
"template": "llama3",
|
||||
"finetuning_type": "{finetuning_type}",
|
||||
"temperature": 0.1,
|
||||
"max_new_tokens": 1
|
||||
}})
|
||||
|
||||
messages = {json.dumps(warmup_messages, ensure_ascii=False)}
|
||||
|
||||
# 执行推理以加载模型
|
||||
response = chat_model.chat(messages)
|
||||
print("Model loaded successfully")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
'''
|
||||
|
||||
try:
|
||||
work_dir = '/app/base'
|
||||
script_path = os.path.join(work_dir, 'temp_preload.py')
|
||||
|
||||
with open(script_path, 'w', encoding='utf-8') as f:
|
||||
f.write(preload_script)
|
||||
|
||||
# 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory
|
||||
import sys as sys_module
|
||||
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
|
||||
# 尝试查找 llamafactory 目录并添加到 PYTHONPATH
|
||||
llm_factory_paths = ['/app/base', '/app', '/app/base/src/llamafactory']
|
||||
for path in llm_factory_paths:
|
||||
if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
|
||||
env['PYTHONPATH'] = path
|
||||
logger.info(f"[PRELOAD] 设置 PYTHONPATH={path}")
|
||||
break
|
||||
|
||||
logger.info(f"[PRELOAD] 执行预热脚本...")
|
||||
result = subprocess.run(
|
||||
[sys_module.executable, 'temp_preload.py'],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=180,
|
||||
cwd=work_dir,
|
||||
env=env
|
||||
)
|
||||
|
||||
os.remove(script_path)
|
||||
|
||||
logger.info(f"[PRELOAD] 命令返回码: {result.returncode}")
|
||||
logger.info(f"[PRELOAD] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
|
||||
logger.info(f"[PRELOAD] stderr: {result.stderr[:500] if result.stderr else 'empty'}")
|
||||
|
||||
if result.returncode == 0:
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'message': '模型预加载成功',
|
||||
'data': {
|
||||
'model_name': model_name,
|
||||
'train_method': train_method,
|
||||
'base_model': base_model_path
|
||||
}
|
||||
})
|
||||
else:
|
||||
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
|
||||
if not error_msg:
|
||||
error_msg = f'命令执行失败,返回码: {result.returncode}'
|
||||
return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("[PRELOAD] 预加载超时")
|
||||
return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'})
|
||||
except Exception as e:
|
||||
logger.error(f"[PRELOAD] 预加载异常: {str(e)}")
|
||||
return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'})
|
||||
|
||||
|
||||
@model_chat_bp.route('/trained', methods=['POST'])
|
||||
def chat_trained_model():
|
||||
"""使用已训练模型进行对话推理"""
|
||||
@@ -220,6 +415,7 @@ def chat_trained_model():
|
||||
data = request.json
|
||||
model_name = data.get('model_name') # 模型名称
|
||||
train_method = data.get('train_method', 'lora') # 训练方法: lora, full
|
||||
base_model_path = data.get('base_model_path') # 前端传递的基座模型路径
|
||||
system_prompt = data.get('system_prompt', '')
|
||||
user_question = data.get('user_question')
|
||||
temperature = data.get('temperature', 0.7)
|
||||
@@ -236,39 +432,64 @@ def chat_trained_model():
|
||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||
CONFIG = yaml.safe_load(f)
|
||||
|
||||
# 获取基座模型路径 - 从数据库查询 model_name 对应的路径
|
||||
try:
|
||||
db_config = CONFIG['database']
|
||||
conn = pymysql.connect(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['username'],
|
||||
password=db_config['password'],
|
||||
database=db_config['name'],
|
||||
charset=db_config.get('charset', 'utf8mb4'),
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
|
||||
model_result = cursor.fetchone()
|
||||
conn.close()
|
||||
# 优先使用前端传递的基座模型路径,否则从数据库查询
|
||||
if not base_model_path:
|
||||
try:
|
||||
db_config = CONFIG['database']
|
||||
conn = pymysql.connect(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['username'],
|
||||
password=db_config['password'],
|
||||
database=db_config['name'],
|
||||
charset=db_config.get('charset', 'utf8mb4'),
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
cursor = conn.cursor()
|
||||
|
||||
if not model_result or not model_result.get('path'):
|
||||
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'})
|
||||
# 优先从训练任务表查询基座模型
|
||||
cursor.execute("""
|
||||
SELECT base_model, output_model_name FROM fine_tune
|
||||
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
|
||||
LIMIT 1
|
||||
""", (f'%/{model_name}', f'%{model_name}%'))
|
||||
ft_result = cursor.fetchone()
|
||||
|
||||
base_model_path = model_result['path']
|
||||
except Exception as e:
|
||||
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
|
||||
if ft_result and ft_result.get('base_model'):
|
||||
base_model_val = ft_result['base_model']
|
||||
# 如果是数字ID,查询模型管理表获取路径
|
||||
if str(base_model_val).isdigit():
|
||||
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
|
||||
model_result = cursor.fetchone()
|
||||
if model_result:
|
||||
base_model_path = model_result.get('path')
|
||||
else:
|
||||
# 直接是路径
|
||||
base_model_path = base_model_val
|
||||
|
||||
# 如果训练任务表没找到,尝试从模型管理表按名称查询
|
||||
if not base_model_path:
|
||||
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
|
||||
model_result = cursor.fetchone()
|
||||
if model_result:
|
||||
base_model_path = model_result.get('path')
|
||||
|
||||
conn.close()
|
||||
|
||||
if not base_model_path:
|
||||
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
|
||||
except Exception as e:
|
||||
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
|
||||
|
||||
# 训练后的模型路径
|
||||
trained_model_path = f"/app/base/saves/{train_method}/{model_name}"
|
||||
|
||||
# 检查路径是否存在(兼容Windows和Linux)
|
||||
if not os.path.exists(trained_model_path):
|
||||
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
|
||||
|
||||
# 构建 llamafactory-cli chat 命令
|
||||
work_dir = '/app/base'
|
||||
llamafactory_dir = '/app/base'
|
||||
adapter_path = os.path.join(trained_model_path, 'adapter_model.bin')
|
||||
safetensors_path = os.path.join(trained_model_path, 'model.safetensors')
|
||||
if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path):
|
||||
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
|
||||
|
||||
# 准备消息
|
||||
messages = []
|
||||
@@ -276,55 +497,76 @@ def chat_trained_model():
|
||||
messages.append({'role': 'system', 'content': system_prompt})
|
||||
messages.append({'role': 'user', 'content': user_question})
|
||||
|
||||
# 将消息转换为 JSON 字符串
|
||||
messages_json = json.dumps(messages, ensure_ascii=False)
|
||||
# 构建 llamafactory 推理脚本
|
||||
inference_script = f'''
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
# 构建 llamafactory-cli chat 命令
|
||||
full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}'
|
||||
from llmtuner import ChatModel
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel({{
|
||||
"model_name_or_path": "{base_model_path}",
|
||||
"adapter_name_or_path": "{trained_model_path}",
|
||||
"template": "llama3",
|
||||
"finetuning_type": "lora",
|
||||
"temperature": {temperature},
|
||||
"max_new_tokens": {max_tokens}
|
||||
}})
|
||||
|
||||
messages = {json.dumps(messages, ensure_ascii=False)}
|
||||
|
||||
response = chat_model.chat(messages)
|
||||
print(response)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
'''
|
||||
|
||||
# 写入临时脚本
|
||||
work_dir = '/app/base'
|
||||
script_path = os.path.join(work_dir, 'temp_inference.py')
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
with open(script_path, 'w', encoding='utf-8') as f:
|
||||
f.write(inference_script)
|
||||
|
||||
# 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory
|
||||
import sys as sys_module
|
||||
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
|
||||
for path in ['/app/base', '/app', '/app/base/src/llamafactory']:
|
||||
if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
|
||||
env['PYTHONPATH'] = path
|
||||
break
|
||||
|
||||
# 执行推理脚本
|
||||
result = subprocess.run(
|
||||
full_cmd,
|
||||
shell=True,
|
||||
[sys_module.executable, 'temp_inference.py'],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
cwd=work_dir
|
||||
timeout=180,
|
||||
cwd=work_dir,
|
||||
env=env
|
||||
)
|
||||
|
||||
output = result.stdout
|
||||
error = result.stderr
|
||||
# 清理临时脚本
|
||||
os.remove(script_path)
|
||||
|
||||
# 解析输出,提取assistant回复
|
||||
# llamafactory-cli chat 输出格式通常是:
|
||||
# <|im_start|>assistant
|
||||
# xxx
|
||||
# <|im_end|>
|
||||
|
||||
assistant_content = ''
|
||||
if '<|im_start|>assistant' in output:
|
||||
parts = output.split('<|im_start|>assistant')
|
||||
if len(parts) > 1:
|
||||
content_part = parts[1].split('<|im_end|>')[0].strip()
|
||||
# 移除可能存在的换行前缀
|
||||
content_part = content_part.lstrip('\n').strip()
|
||||
assistant_content = content_part
|
||||
elif result.returncode == 0:
|
||||
# 如果没有特殊标记,尝试提取最后一部分作为回复
|
||||
lines = output.strip().split('\n')
|
||||
assistant_content = '\n'.join(lines).strip()
|
||||
if result.returncode == 0:
|
||||
assistant_content = result.stdout.strip()
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'data': {
|
||||
'model_name': model_name,
|
||||
'train_method': train_method,
|
||||
'response': assistant_content
|
||||
}
|
||||
})
|
||||
else:
|
||||
return jsonify({'code': 1, 'message': f'推理失败: {error or output}'})
|
||||
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'data': {
|
||||
'model_name': model_name,
|
||||
'train_method': train_method,
|
||||
'response': assistant_content
|
||||
}
|
||||
})
|
||||
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
|
||||
return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
|
||||
|
||||
Reference in New Issue
Block a user