1. 增加了合并权重
2. 修复了一些列表展示的bug
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import os
|
||||
import pymysql
|
||||
import yaml
|
||||
import json
|
||||
import requests
|
||||
import concurrent.futures
|
||||
import subprocess
|
||||
@@ -211,6 +212,200 @@ def model_chat_batch():
|
||||
return jsonify({'code': 0, 'data': results})
|
||||
|
||||
|
||||
# Working directory the warm-up subprocess runs in; the temporary script is
# written here as well so the script's own directory ends up on sys.path.
_PRELOAD_WORK_DIR = '/app/base'

# Fine-tuned outputs are addressed as <saves root>/<train_method>/<model_name>.
_PRELOAD_SAVES_ROOT = '/app/base/saves'

# Static warm-up script. All request-derived values arrive as JSON on stdin
# instead of being interpolated into the source text, so a crafted model name
# or path cannot inject code into the script.
_PRELOAD_SCRIPT = '''\
import json
import logging
import sys

logging.basicConfig(level=logging.WARNING)

from llmtuner import ChatModel


def main():
    payload = json.load(sys.stdin)
    chat_model = ChatModel(payload["args"])
    # Run one tiny inference so the model weights are actually loaded.
    chat_model.chat(payload["messages"])
    print("Model loaded successfully")


if __name__ == "__main__":
    main()
'''


def _preload_path_is_inside_saves(train_method, model_name):
    """Return True when <saves root>/<train_method>/<model_name> stays under the saves root."""
    root = os.path.normpath(_PRELOAD_SAVES_ROOT)
    candidate = os.path.normpath(os.path.join(root, str(train_method), str(model_name)))
    return candidate.startswith(root + os.sep)


def _preload_lookup_base_model(db_config, model_name, logger):
    """Look up the base-model path for ``model_name`` in the database.

    The ``fine_tune`` table is consulted first; its ``base_model`` column is
    treated as a ``model_manage`` id when it is all digits and as a path
    otherwise. If that yields nothing, ``model_manage`` is searched by name.

    Returns the path, or None when no record provides one. Database errors
    propagate to the caller; the connection is always closed.
    """
    import pymysql

    conn = pymysql.connect(
        host=db_config['host'],
        port=db_config['port'],
        user=db_config['username'],
        password=db_config['password'],
        database=db_config['name'],
        charset=db_config.get('charset', 'utf8mb4'),
        cursorclass=pymysql.cursors.DictCursor
    )
    try:
        cursor = conn.cursor()
        base_model_path = None

        logger.info("[PRELOAD] 尝试从fine_tune表查询模型: %s", model_name)
        cursor.execute("""
            SELECT base_model, output_model_name FROM fine_tune
            WHERE output_model_name LIKE %s OR output_model_name LIKE %s
            LIMIT 1
        """, (f'%/{model_name}', f'%{model_name}%'))
        ft_result = cursor.fetchone()
        logger.info("[PRELOAD] fine_tune查询结果: %s", ft_result)

        if ft_result and ft_result.get('base_model'):
            base_model_val = ft_result['base_model']
            logger.info("[PRELOAD] base_model_val: %s", base_model_val)
            if str(base_model_val).isdigit():
                # Numeric value: resolve it through the model_manage table.
                cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
                model_result = cursor.fetchone()
                logger.info("[PRELOAD] model_manage查询结果(数字ID): %s", model_result)
                if model_result:
                    base_model_path = model_result.get('path')
            else:
                # Non-numeric value is used as the path itself.
                base_model_path = base_model_val

        if not base_model_path:
            # Nothing usable from fine_tune: fall back to a lookup by name.
            logger.info("[PRELOAD] 尝试从model_manage表查询...")
            cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
            model_result = cursor.fetchone()
            logger.info("[PRELOAD] model_manage查询结果: %s", model_result)
            if model_result:
                base_model_path = model_result.get('path')

        return base_model_path
    finally:
        conn.close()


@model_chat_bp.route('/trained/preload', methods=['POST'])
def preload_trained_model():
    """Warm up a fine-tuned model by running one minimal inference in a subprocess.

    Request JSON fields:
        model_name: name of the trained model (required).
        train_method: training method, ``lora`` (default) or ``full``.
        base_model_path: base-model path supplied by the front end; when it
            is missing the path is looked up in the database.

    Returns a JSON response with ``code`` 0 and the resolved model details on
    success, or ``code`` 1 and a ``message`` describing the failure.

    The subprocess is waited on and exits when the script finishes, so the
    model is not kept resident in this server process.
    """
    import sys as sys_module
    import tempfile
    import yaml
    import logging
    logger = logging.getLogger(__name__)

    # A missing or non-object JSON body is treated like an empty request so
    # the caller gets the regular "missing model name" error instead of a crash.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    model_name = data.get('model_name')
    train_method = data.get('train_method', 'lora')
    base_model_path = data.get('base_model_path')

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})

    # Both values become path components below; refuse anything that would
    # resolve outside the saves directory (e.g. "../..").
    if not _preload_path_is_inside_saves(train_method, model_name):
        logger.error("[PRELOAD] 非法的模型名称或训练方法: %s, %s", model_name, train_method)
        return jsonify({'code': 1, 'message': '模型名称或训练方法不合法'})

    logger.info("[PRELOAD] 开始预加载模型: %s, 方法: %s", model_name, train_method)
    logger.info("[PRELOAD] 前端传递的基座模型路径: %s", base_model_path)

    if not base_model_path:
        # The configuration is only needed for the database lookup.
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        config_path = os.path.join(project_root, 'config.yaml')
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
        except Exception as e:
            return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'})

        try:
            base_model_path = _preload_lookup_base_model(config['database'], model_name, logger)
        except Exception as e:
            logger.error("[PRELOAD] 查询模型配置失败: %s", e)
            return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})

        if not base_model_path:
            logger.error("[PRELOAD] 未找到模型 %s 的基座模型配置", model_name)
            return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
    else:
        logger.info("[PRELOAD] 使用前端传递的基座模型路径: %s", base_model_path)

    trained_model_path = f"/app/base/saves/{train_method}/{model_name}"

    # Files inside a missing directory cannot exist either, so a missing
    # directory is a definite failure.
    if not os.path.exists(trained_model_path):
        logger.warning("[PRELOAD] 训练模型路径不存在: %s", trained_model_path)
        return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})

    work_dir = _PRELOAD_WORK_DIR
    finetuning_type = 'lora' if train_method == 'lora' else 'full'
    payload = {
        'args': {
            'model_name_or_path': base_model_path,
            'adapter_name_or_path': trained_model_path,
            'template': 'llama3',
            'finetuning_type': finetuning_type,
            'temperature': 0.1,
            'max_new_tokens': 1,
        },
        # A trivial greeting is enough to force the model to load.
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Hello'},
        ],
    }

    script_path = None
    try:
        # A unique file name keeps concurrent requests from overwriting or
        # deleting each other's script.
        fd, script_path = tempfile.mkstemp(prefix='temp_preload_', suffix='.py', dir=work_dir)
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            f.write(_PRELOAD_SCRIPT)

        env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
        # Put the first directory that contains the llmtuner package on
        # PYTHONPATH, keeping whatever was already configured.
        for path in ('/app/base', '/app', '/app/base/src/llamafactory'):
            if os.path.exists(os.path.join(path, 'llmtuner')):
                existing = env.get('PYTHONPATH')
                env['PYTHONPATH'] = (path + os.pathsep + existing) if existing else path
                logger.info("[PRELOAD] 设置 PYTHONPATH=%s", path)
                break

        logger.info("[PRELOAD] 执行预热脚本...")
        result = subprocess.run(
            [sys_module.executable, script_path],
            input=json.dumps(payload),
            capture_output=True,
            text=True,
            timeout=180,
            cwd=work_dir,
            env=env
        )

        logger.info("[PRELOAD] 命令返回码: %s", result.returncode)
        logger.info("[PRELOAD] stdout: %s", result.stdout[:500] if result.stdout else 'empty')
        logger.info("[PRELOAD] stderr: %s", result.stderr[:500] if result.stderr else 'empty')

        if result.returncode == 0:
            return jsonify({
                'code': 0,
                'message': '模型预加载成功',
                'data': {
                    'model_name': model_name,
                    'train_method': train_method,
                    'base_model': base_model_path
                }
            })

        error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
        if not error_msg:
            error_msg = f'命令执行失败,返回码: {result.returncode}'
        return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'})

    except subprocess.TimeoutExpired:
        logger.error("[PRELOAD] 预加载超时")
        return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'})
    except Exception as e:
        logger.error("[PRELOAD] 预加载异常: %s", e)
        return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'})
    finally:
        # Remove the script on every exit path, including timeouts.
        if script_path and os.path.exists(script_path):
            try:
                os.remove(script_path)
            except OSError as cleanup_error:
                logger.warning("[PRELOAD] 清理临时脚本失败: %s", cleanup_error)
|
||||
|
||||
|
||||
@model_chat_bp.route('/trained', methods=['POST'])
|
||||
def chat_trained_model():
|
||||
"""使用已训练模型进行对话推理"""
|
||||
@@ -220,6 +415,7 @@ def chat_trained_model():
|
||||
data = request.json
|
||||
model_name = data.get('model_name') # 模型名称
|
||||
train_method = data.get('train_method', 'lora') # 训练方法: lora, full
|
||||
base_model_path = data.get('base_model_path') # 前端传递的基座模型路径
|
||||
system_prompt = data.get('system_prompt', '')
|
||||
user_question = data.get('user_question')
|
||||
temperature = data.get('temperature', 0.7)
|
||||
@@ -236,39 +432,64 @@ def chat_trained_model():
|
||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||
CONFIG = yaml.safe_load(f)
|
||||
|
||||
# 获取基座模型路径 - 从数据库查询 model_name 对应的路径
|
||||
try:
|
||||
db_config = CONFIG['database']
|
||||
conn = pymysql.connect(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['username'],
|
||||
password=db_config['password'],
|
||||
database=db_config['name'],
|
||||
charset=db_config.get('charset', 'utf8mb4'),
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
|
||||
model_result = cursor.fetchone()
|
||||
conn.close()
|
||||
# 优先使用前端传递的基座模型路径,否则从数据库查询
|
||||
if not base_model_path:
|
||||
try:
|
||||
db_config = CONFIG['database']
|
||||
conn = pymysql.connect(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['username'],
|
||||
password=db_config['password'],
|
||||
database=db_config['name'],
|
||||
charset=db_config.get('charset', 'utf8mb4'),
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
cursor = conn.cursor()
|
||||
|
||||
if not model_result or not model_result.get('path'):
|
||||
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'})
|
||||
# 优先从训练任务表查询基座模型
|
||||
cursor.execute("""
|
||||
SELECT base_model, output_model_name FROM fine_tune
|
||||
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
|
||||
LIMIT 1
|
||||
""", (f'%/{model_name}', f'%{model_name}%'))
|
||||
ft_result = cursor.fetchone()
|
||||
|
||||
base_model_path = model_result['path']
|
||||
except Exception as e:
|
||||
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
|
||||
if ft_result and ft_result.get('base_model'):
|
||||
base_model_val = ft_result['base_model']
|
||||
# 如果是数字ID,查询模型管理表获取路径
|
||||
if str(base_model_val).isdigit():
|
||||
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
|
||||
model_result = cursor.fetchone()
|
||||
if model_result:
|
||||
base_model_path = model_result.get('path')
|
||||
else:
|
||||
# 直接是路径
|
||||
base_model_path = base_model_val
|
||||
|
||||
# 如果训练任务表没找到,尝试从模型管理表按名称查询
|
||||
if not base_model_path:
|
||||
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
|
||||
model_result = cursor.fetchone()
|
||||
if model_result:
|
||||
base_model_path = model_result.get('path')
|
||||
|
||||
conn.close()
|
||||
|
||||
if not base_model_path:
|
||||
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
|
||||
except Exception as e:
|
||||
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
|
||||
|
||||
# 训练后的模型路径
|
||||
trained_model_path = f"/app/base/saves/{train_method}/{model_name}"
|
||||
|
||||
# 检查路径是否存在(兼容Windows和Linux)
|
||||
if not os.path.exists(trained_model_path):
|
||||
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
|
||||
|
||||
# 构建 llamafactory-cli chat 命令
|
||||
work_dir = '/app/base'
|
||||
llamafactory_dir = '/app/base'
|
||||
adapter_path = os.path.join(trained_model_path, 'adapter_model.bin')
|
||||
safetensors_path = os.path.join(trained_model_path, 'model.safetensors')
|
||||
if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path):
|
||||
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
|
||||
|
||||
# 准备消息
|
||||
messages = []
|
||||
@@ -276,55 +497,76 @@ def chat_trained_model():
|
||||
messages.append({'role': 'system', 'content': system_prompt})
|
||||
messages.append({'role': 'user', 'content': user_question})
|
||||
|
||||
# 将消息转换为 JSON 字符串
|
||||
messages_json = json.dumps(messages, ensure_ascii=False)
|
||||
# 构建 llamafactory 推理脚本
|
||||
inference_script = f'''
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
# 构建 llamafactory-cli chat 命令
|
||||
full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}'
|
||||
from llmtuner import ChatModel
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel({{
|
||||
"model_name_or_path": "{base_model_path}",
|
||||
"adapter_name_or_path": "{trained_model_path}",
|
||||
"template": "llama3",
|
||||
"finetuning_type": "lora",
|
||||
"temperature": {temperature},
|
||||
"max_new_tokens": {max_tokens}
|
||||
}})
|
||||
|
||||
messages = {json.dumps(messages, ensure_ascii=False)}
|
||||
|
||||
response = chat_model.chat(messages)
|
||||
print(response)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
'''
|
||||
|
||||
# 写入临时脚本
|
||||
work_dir = '/app/base'
|
||||
script_path = os.path.join(work_dir, 'temp_inference.py')
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
with open(script_path, 'w', encoding='utf-8') as f:
|
||||
f.write(inference_script)
|
||||
|
||||
# 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory
|
||||
import sys as sys_module
|
||||
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
|
||||
for path in ['/app/base', '/app', '/app/base/src/llamafactory']:
|
||||
if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
|
||||
env['PYTHONPATH'] = path
|
||||
break
|
||||
|
||||
# 执行推理脚本
|
||||
result = subprocess.run(
|
||||
full_cmd,
|
||||
shell=True,
|
||||
[sys_module.executable, 'temp_inference.py'],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
cwd=work_dir
|
||||
timeout=180,
|
||||
cwd=work_dir,
|
||||
env=env
|
||||
)
|
||||
|
||||
output = result.stdout
|
||||
error = result.stderr
|
||||
# 清理临时脚本
|
||||
os.remove(script_path)
|
||||
|
||||
# 解析输出,提取assistant回复
|
||||
# llamafactory-cli chat 输出格式通常是:
|
||||
# <|im_start|>assistant
|
||||
# xxx
|
||||
# <|im_end|>
|
||||
|
||||
assistant_content = ''
|
||||
if '<|im_start|>assistant' in output:
|
||||
parts = output.split('<|im_start|>assistant')
|
||||
if len(parts) > 1:
|
||||
content_part = parts[1].split('<|im_end|>')[0].strip()
|
||||
# 移除可能存在的换行前缀
|
||||
content_part = content_part.lstrip('\n').strip()
|
||||
assistant_content = content_part
|
||||
elif result.returncode == 0:
|
||||
# 如果没有特殊标记,尝试提取最后一部分作为回复
|
||||
lines = output.strip().split('\n')
|
||||
assistant_content = '\n'.join(lines).strip()
|
||||
if result.returncode == 0:
|
||||
assistant_content = result.stdout.strip()
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'data': {
|
||||
'model_name': model_name,
|
||||
'train_method': train_method,
|
||||
'response': assistant_content
|
||||
}
|
||||
})
|
||||
else:
|
||||
return jsonify({'code': 1, 'message': f'推理失败: {error or output}'})
|
||||
|
||||
return jsonify({
|
||||
'code': 0,
|
||||
'data': {
|
||||
'model_name': model_name,
|
||||
'train_method': train_method,
|
||||
'response': assistant_content
|
||||
}
|
||||
})
|
||||
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
|
||||
return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
|
||||
|
||||
@@ -47,24 +47,32 @@ def generic_get_all(table_name, order_by='create_time DESC'):
|
||||
|
||||
def get_model_path_by_name(model_name):
|
||||
"""根据模型名称查询模型路径(用于获取基座模型路径)"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}")
|
||||
|
||||
try:
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 优先从训练任务表查询基座模型
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 尝试从fine_tune表查询...")
|
||||
cursor.execute("""
|
||||
SELECT base_model FROM fine_tune
|
||||
SELECT base_model, output_model_name FROM fine_tune
|
||||
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
|
||||
LIMIT 1
|
||||
""", (f'%/{model_name}', f'%{model_name}%'))
|
||||
ft_result = cursor.fetchone()
|
||||
logger.info(f"[DEBUG get_model_path_by_name] fine_tune查询结果: {ft_result}")
|
||||
|
||||
if ft_result and ft_result.get('base_model'):
|
||||
base_model_val = ft_result['base_model']
|
||||
logger.info(f"[DEBUG get_model_path_by_name] base_model_val: {base_model_val}")
|
||||
# 如果是数字ID,查询模型管理表获取路径
|
||||
if str(base_model_val).isdigit():
|
||||
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
|
||||
model_result = cursor.fetchone()
|
||||
logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果(数字ID): {model_result}")
|
||||
if model_result:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
@@ -76,12 +84,15 @@ def get_model_path_by_name(model_name):
|
||||
return base_model_val
|
||||
|
||||
# 如果训练任务表没找到,尝试从模型管理表按名称查询
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 尝试从model_manage表查询...")
|
||||
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
|
||||
result = cursor.fetchone()
|
||||
logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果: {result}")
|
||||
cursor.close()
|
||||
conn.close()
|
||||
if result:
|
||||
return result.get('path')
|
||||
logger.info(f"[DEBUG get_model_path_by_name] 未找到任何匹配,返回None")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[ERROR] 查询模型路径失败: {e}")
|
||||
@@ -387,3 +398,138 @@ def get_trained_models():
|
||||
except Exception as e:
|
||||
logger.error(f"获取已训练模型列表失败: {e}")
|
||||
return jsonify({'code': 1, 'message': str(e)})
|
||||
|
||||
|
||||
# ============ 合并权重接口 ============
|
||||
|
||||
# LoRA adapters are read from <saves root>/<train_method>/<model_name>;
# merged weights are exported to <output root>/<model_name>.
_MERGE_SAVES_ROOT = '/app/base/saves'
_MERGE_OUTPUT_ROOT = '/app/base/local_trained_models'


def _merge_path_is_inside(root, *parts):
    """Return True when ``root`` joined with ``parts`` stays strictly below ``root``."""
    base = os.path.normpath(root)
    candidate = os.path.normpath(os.path.join(base, *(str(part) for part in parts)))
    return candidate.startswith(base + os.sep)


def _merge_lookup_base_model(model_name):
    """Look up the base-model path for ``model_name`` in the database.

    The ``fine_tune`` table is consulted first; its ``base_model`` column is
    treated as a ``model_manage`` id when it is all digits and as a path
    otherwise. If that yields nothing, ``model_manage`` is searched by name.

    Returns the path, or None when no record provides one. Database errors
    propagate to the caller; the connection is always closed.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        base_model_path = None

        cursor.execute("""
            SELECT base_model FROM fine_tune
            WHERE output_model_name LIKE %s OR output_model_name LIKE %s
            LIMIT 1
        """, (f'%/{model_name}', f'%{model_name}%'))
        ft_result = cursor.fetchone()

        if ft_result and ft_result.get('base_model'):
            base_model_val = ft_result['base_model']
            if str(base_model_val).isdigit():
                cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
                model_result = cursor.fetchone()
                if model_result:
                    base_model_path = model_result.get('path')
            else:
                base_model_path = base_model_val

        if not base_model_path:
            # Nothing usable from fine_tune: fall back to a lookup by name.
            cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
            model_result = cursor.fetchone()
            if model_result:
                base_model_path = model_result.get('path')

        return base_model_path
    finally:
        conn.close()


@model_manage_bp.route('/merge', methods=['POST'])
def merge_model():
    """Merge a trained LoRA adapter into its base model via ``llamafactory-cli export``.

    Request JSON fields:
        model_name: name of the trained model (required).
        train_method: training method, defaults to ``lora``.
        base_model_path: base-model path; looked up in the database when missing.

    Returns a JSON response with ``code`` 0 and the output path on success,
    or ``code`` 1 and a ``message`` describing the failure.
    """
    import shutil
    import subprocess
    import logging
    logger = logging.getLogger(__name__)

    # A missing or non-object JSON body is treated like an empty request so
    # the caller gets the regular "missing model name" error instead of a crash.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    model_name = data.get('model_name')
    train_method = data.get('train_method', 'lora')
    base_model_path = data.get('base_model_path')

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})

    # Both values become path components (one of which is created on disk);
    # refuse anything that would resolve outside the expected roots.
    if not (_merge_path_is_inside(_MERGE_SAVES_ROOT, train_method, model_name)
            and _merge_path_is_inside(_MERGE_OUTPUT_ROOT, model_name)):
        logger.error("[MERGE] 非法的模型名称或训练方法: %s, %s", model_name, train_method)
        return jsonify({'code': 1, 'message': '模型名称或训练方法不合法'})

    logger.info("[MERGE] 开始合并模型: %s, 方法: %s", model_name, train_method)

    if not base_model_path:
        try:
            base_model_path = _merge_lookup_base_model(model_name)
        except Exception as e:
            logger.error("[MERGE] 查询模型配置失败: %s", e)
            return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
        if not base_model_path:
            return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})

    adapter_path = f"/app/base/saves/{train_method}/{model_name}"
    if not os.path.exists(adapter_path):
        return jsonify({'code': 1, 'message': f'训练模型不存在: {adapter_path}'})

    output_path = f"/app/base/local_trained_models/{model_name}"

    try:
        # Inside the try block so a filesystem error is reported as JSON.
        os.makedirs(output_path, exist_ok=True)

        work_dir = '/app/base'
        env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}

        # Portable PATH probe. A miss is only logged: the command is still
        # attempted and a genuine failure is reported by the handler below.
        if shutil.which('llamafactory-cli') is None:
            logger.info("[MERGE] which 命令不可用,直接尝试执行 llamafactory-cli")

        command = [
            'llamafactory-cli', 'export',
            '--model_name_or_path', str(base_model_path),
            '--adapter_name_or_path', adapter_path,
            '--export_dir', output_path,
        ]
        logger.info("[MERGE] 执行合并命令: %s", ' '.join(command))

        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=600,
            cwd=work_dir,
            env=env
        )

        logger.info("[MERGE] 命令返回码: %s", result.returncode)
        logger.info("[MERGE] stdout: %s", result.stdout[:500] if result.stdout else 'empty')
        logger.info("[MERGE] stderr: %s", result.stderr[:500] if result.stderr else 'empty')

        if result.returncode == 0:
            return jsonify({
                'code': 0,
                'message': f'模型权重已成功合并到 {output_path}',
                'data': {
                    'model_name': model_name,
                    'output_path': output_path
                }
            })

        error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
        if not error_msg:
            error_msg = f'命令执行失败,返回码: {result.returncode}'
        return jsonify({'code': 1, 'message': f'合并失败: {error_msg}'})

    except subprocess.TimeoutExpired:
        logger.error("[MERGE] 合并超时")
        return jsonify({'code': 1, 'message': '合并超时,请稍后重试'})
    except Exception as e:
        logger.error("[MERGE] 合并异常: %s", e)
        return jsonify({'code': 1, 'message': f'合并异常: {str(e)}'})
|
||||
|
||||
@@ -608,7 +608,7 @@
|
||||
'edit': '编辑',
|
||||
'compare': '开始对话',
|
||||
'chat': '对话',
|
||||
'view': '去推理'
|
||||
'view': '合并权重'
|
||||
};
|
||||
|
||||
// 训练进度缓存
|
||||
@@ -1108,11 +1108,11 @@
|
||||
function toggleSelectAll(checkbox, api) {
|
||||
// 使用保存的当前页面数据
|
||||
if (checkbox.checked) {
|
||||
// 全选当前页面的所有数据
|
||||
currentPageData.forEach(item => selectedItems.add(item.id));
|
||||
// 全选当前页面的所有数据(支持 name 或 id)
|
||||
currentPageData.forEach(item => selectedItems.add(item.name || item.id));
|
||||
} else {
|
||||
// 取消全选,移除当前页面所有数据的选中状态
|
||||
currentPageData.forEach(item => selectedItems.delete(item.id));
|
||||
currentPageData.forEach(item => selectedItems.delete(item.name || item.id));
|
||||
}
|
||||
refreshCurrentPage();
|
||||
}
|
||||
@@ -1266,14 +1266,16 @@
|
||||
|
||||
// 渲染表格页面
|
||||
function renderTablePage(config, data) {
|
||||
const createButton = config.hasCreate ? `
|
||||
<button onclick="showCreateModal('${config.api}')" class="bg-primary text-white px-3 py-1.5 rounded text-sm hover:bg-primary/90 transition-colors">
|
||||
<i class="fa fa-plus mr-1"></i>${config.createText}
|
||||
</button>
|
||||
` : '';
|
||||
// 定义表格列
|
||||
const columns = [
|
||||
{ title: '模型名称', key: 'name' },
|
||||
{ title: '训练方法', key: 'train_methods', render: (val) => val && val[0] ? val[0].name : '-' },
|
||||
{ title: '基座模型', key: 'base_model_path', render: (val) => `<span class="text-xs text-gray-500 truncate block" title="${val}">${val || '-'}</span>` },
|
||||
{ title: '创建时间', key: 'create_time', render: (val) => val ? new Date(val).toLocaleString('zh-CN') : '-' }
|
||||
];
|
||||
|
||||
// 搜索框(模型管理和数据集管理)
|
||||
const searchBox = (config.api === 'model-manage' || config.api === 'dataset-manage') ? `
|
||||
// 搜索框
|
||||
const searchBox = (config.api === 'model-manage' || config.api === 'model-manage/trained-models' || config.api === 'dataset-manage') ? `
|
||||
<div class="relative">
|
||||
<input type="text" id="tableSearchInput" placeholder="搜索${config.title}..."
|
||||
class="w-72 pl-9 pr-3 py-1.5 rounded border border-gray-300 text-sm focus:outline-none focus:border-primary focus:ring-1 focus:border-primary"
|
||||
@@ -1283,7 +1285,14 @@
|
||||
` : '';
|
||||
|
||||
// 是否支持多选(模型管理和数据集管理)
|
||||
const supportsMultiSelect = config.api === 'model-manage' || config.api === 'dataset-manage';
|
||||
const supportsMultiSelect = config.api === 'model-manage' || config.api === 'model-manage/trained-models' || config.api === 'dataset-manage';
|
||||
|
||||
// 创建按钮(根据API类型决定是否显示)
|
||||
const createButton = config.api === 'dataset-manage' ? `
|
||||
<button onclick="showCreateModal('${config.api}')" class="bg-primary text-white px-3 py-1.5 rounded text-sm hover:bg-primary/90 transition-colors">
|
||||
<i class="fa fa-plus mr-1"></i>创建数据集
|
||||
</button>
|
||||
` : '';
|
||||
|
||||
// 批量删除按钮(仅当有选中项时显示)
|
||||
const batchDeleteButton = supportsMultiSelect && selectedItems.size > 0 ? `
|
||||
@@ -1292,7 +1301,6 @@
|
||||
</button>
|
||||
` : '';
|
||||
|
||||
const columns = config.columns;
|
||||
const hasData = data && data.length > 0;
|
||||
|
||||
// 多选列头
|
||||
@@ -1334,12 +1342,12 @@
|
||||
</thead>
|
||||
<tbody>
|
||||
${hasData ? data.map(item => `
|
||||
<tr class="border-b border-gray-100 table-row-hover ${selectedItems.has(item.id) ? 'bg-blue-50' : ''}">
|
||||
<tr class="border-b border-gray-100 table-row-hover ${selectedItems.has(item.name || item.id) ? 'bg-blue-50' : ''}">
|
||||
${supportsMultiSelect ? `
|
||||
<td class="px-4 py-4 text-sm text-center">
|
||||
<input type="checkbox" class="w-4 h-4 text-primary rounded border-gray-300 cursor-pointer"
|
||||
${selectedItems.has(item.id) ? 'checked' : ''}
|
||||
onchange="toggleItemSelection(${item.id}, '${config.api}')">
|
||||
${selectedItems.has(item.name || item.id) ? 'checked' : ''}
|
||||
onchange="toggleItemSelection('${item.name || item.id}', '${config.api}')">
|
||||
</td>
|
||||
` : ''}
|
||||
${columns.map(col => `
|
||||
@@ -1349,33 +1357,15 @@
|
||||
`).join('')}
|
||||
<td class="px-4 py-4 text-sm text-center">
|
||||
<div class="flex justify-center space-x-2">
|
||||
${config.actions.map(action => {
|
||||
${['view', 'delete'].map(action => {
|
||||
let onclick = '';
|
||||
let btnClass = 'text-primary hover:text-primary/80';
|
||||
|
||||
// 对于 fine-tune 的停止按钮,检查状态
|
||||
if (action === 'stop' && config.api === 'fine-tune') {
|
||||
// 状态为 completed 或 failed 时隐藏停止按钮
|
||||
if (item.status === 'completed' || item.status === 'failed') {
|
||||
return '';
|
||||
}
|
||||
onclick = `stopItem(${item.id})`;
|
||||
btnClass = 'text-orange-500 hover:text-orange-600';
|
||||
if (action === 'view') {
|
||||
onclick = `viewTrainedModel('${item.name}', '${item.train_methods?.[0]?.name || 'lora'}', '${item.base_model_path || ''}')`;
|
||||
} else if (action === 'delete') {
|
||||
onclick = `deleteItem('${config.api}', ${item.id})`;
|
||||
onclick = `deleteItem('${config.api}', '${item.id}')`;
|
||||
btnClass = 'text-danger hover:text-danger/80';
|
||||
} else if (action === 'edit') {
|
||||
onclick = `editItem('${config.api}', ${item.id})`;
|
||||
} else if (action === 'preview' && config.api === 'dataset-manage') {
|
||||
onclick = `window.location.href = 'dataset-preview.html?id=${item.id}'`;
|
||||
} else if (action === 'download' && config.api === 'dataset-manage') {
|
||||
onclick = `downloadDataset('${item.id}')`;
|
||||
} else if (action === 'compare' && config.api === 'model-compare') {
|
||||
onclick = `startCompare(${item.id})`;
|
||||
} else if (action === 'logs' && config.api === 'fine-tune') {
|
||||
onclick = `navigateToTrainingLog(${item.id})`;
|
||||
} else if (action === 'view' && config.api === 'model-manage/trained-models') {
|
||||
onclick = `viewTrainedModel('${item.name}', '${item.train_methods?.[0]?.name || '-'}', '${item.path || ''}')`;
|
||||
} else {
|
||||
onclick = `showMessage('提示', '${actionLabels[action] || action}功能开发中...', 'info')`;
|
||||
}
|
||||
@@ -3164,10 +3154,55 @@
|
||||
document.body.style.overflow = '';
|
||||
}
|
||||
|
||||
// 查看已训练模型详情 - 跳转到推理页面
|
||||
window.viewTrainedModel = function(name, method, path) {
|
||||
// 跳转到推理测试页面(main.html在pages目录下,所以直接用文件名)
|
||||
window.location.href = `model-inference.html?model=${encodeURIComponent(name)}&method=${encodeURIComponent(method)}`;
|
||||
// Refresh the table by reloading whichever sidebar page is currently active.
// Must be defined before viewTrainedModel, which calls it.
window.loadTableData = function() {
    const currentNavLink = document.querySelector('.nav-link.sidebar-item-active');
    if (!currentNavLink) {
        return;
    }
    loadPage(currentNavLink.dataset.page);
};
|
||||
|
||||
// Merge a trained model's weights: POST to the merge endpoint while showing
// the loading modal, then report the outcome and refresh the list on success.
window.viewTrainedModel = async function(name, method, path) {
    const modal = document.getElementById('loadingModal');

    // Hide the loading modal if the page has one.
    const hideLoading = () => {
        if (modal) {
            modal.classList.add('hidden');
        }
    };

    // Show the loading modal with a merge-specific message.
    if (modal) {
        document.getElementById('loadingMessage').textContent = '正在合并模型权重,请稍候...';
        modal.classList.remove('hidden');
    }

    try {
        const payload = {
            model_name: name,
            train_method: method || 'lora',
            base_model_path: path
        };
        const response = await fetch(`${API_BASE}/model-manage/merge`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(payload)
        });
        const result = await response.json();

        hideLoading();

        if (result.code !== 0) {
            showMessage('失败', result.message || '合并失败', 'error');
            return;
        }

        showMessage('成功', '模型权重合并成功!', 'success');
        // Reload the model list after a successful merge.
        loadTableData();
    } catch (error) {
        console.error('[DEBUG] 合并失败:', error);
        hideLoading();
        showMessage('错误', '合并失败: ' + error.message, 'error');
    }
};
|
||||
|
||||
// 确认弹窗(两个按钮)- 使用 window 确保全局可访问
|
||||
@@ -3446,5 +3481,15 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 加载中弹窗 -->
|
||||
<div id="loadingModal" class="hidden fixed inset-0 bg-black/50 z-50 flex items-center justify-center">
|
||||
<div class="bg-white rounded-xl shadow-xl max-w-sm w-full mx-4 overflow-hidden transform transition-all">
|
||||
<div class="flex flex-col items-center justify-center min-h-[160px] py-6">
|
||||
<i class="fa fa-spinner fa-spin text-3xl text-primary mb-4"></i>
|
||||
<p id="loadingMessage" class="text-gray-600 text-sm">正在处理...</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -3,126 +3,71 @@
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>模型推理 - 远光软件微调平台</title>
|
||||
<script src="../lib/tailwindcss/tailwind.js"></script>
|
||||
<title>合并权重 - YG_FT</title>
|
||||
<link href="../../assets/libs/font-awesome-4.7.0/css/font-awesome.min.css" rel="stylesheet">
|
||||
<script src="../../assets/libs/tailwindcss.js"></script>
|
||||
<script>
|
||||
if (typeof console !== 'undefined' && console.warn) {
|
||||
const originalWarn = console.warn;
|
||||
console.warn = function(...args) {
|
||||
if (args[0] && args[0].includes && args[0].includes('cdn.tailwindcss.com')) {
|
||||
return;
|
||||
tailwind.config = {
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
primary: '#1890ff'
|
||||
}
|
||||
}
|
||||
originalWarn.apply(console, args);
|
||||
};
|
||||
}
|
||||
}
|
||||
</script>
|
||||
<link href="../lib/font-awesome/css/font-awesome.min.css" rel="stylesheet">
|
||||
<style>
|
||||
:root { --primary: #1890ff; --danger: #f5222d; --success: #52c41a; }
|
||||
.sidebar-section-title {
|
||||
padding: 0.5rem 1rem;
|
||||
font-size: 0.75rem;
|
||||
color: rgba(191, 203, 217, 0.7);
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
.nav-link:hover { background-color: rgba(0, 21, 41, 0.2); }
|
||||
.sidebar-item-active {
|
||||
background-color: rgba(24, 144, 255, 0.1);
|
||||
color: #1890ff;
|
||||
border-left: 4px solid #1890ff;
|
||||
background-color: #1890ff !important;
|
||||
color: white !important;
|
||||
}
|
||||
/* 侧边栏滑块动画 */
|
||||
.sidebar-slider {
|
||||
position: absolute;
|
||||
width: 4px;
|
||||
height: 0;
|
||||
background-color: #1890ff;
|
||||
border-radius: 0 2px 2px 0;
|
||||
transition: top 0.3s cubic-bezier(0.4, 0, 0.2, 1),
|
||||
height 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
pointer-events: none;
|
||||
z-index: 10;
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
|
||||
}
|
||||
.nav-item-wrapper {
|
||||
position: relative;
|
||||
}
|
||||
.nav-link {
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
}
|
||||
.chat-message { animation: fadeIn 0.3s ease; }
|
||||
@keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } }
|
||||
.typing-indicator span {
|
||||
display: inline-block;
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
background-color: #1890ff;
|
||||
border-radius: 50%;
|
||||
margin: 0 2px;
|
||||
animation: typing 1.4s infinite ease-in-out;
|
||||
}
|
||||
.typing-indicator span:nth-child(2) { animation-delay: 0.2s; }
|
||||
.typing-indicator span:nth-child(3) { animation-delay: 0.4s; }
|
||||
@keyframes typing { 0%, 60%, 100% { transform: translateY(0); } 30% { transform: translateY(-8px); } }
|
||||
.bg-primary { background-color: #1890ff; }
|
||||
.text-primary { color: #1890ff; }
|
||||
.border-primary { border-color: #1890ff; }
|
||||
</style>
|
||||
</head>
|
||||
<body class="antialiased bg-gray-100 flex h-screen overflow-hidden">
|
||||
<!-- 侧边导航 -->
|
||||
<aside class="w-64 text-[#bfcbd9] flex-shrink-0 hidden md:block flex flex-col h-full" style="background-color: #001529;">
|
||||
<!-- 平台LOGO区域 -->
|
||||
<div class="pt-5 pb-3 border-b border-[#001529]/30 flex items-center justify-center pl-2">
|
||||
<img src="../assets/logo/logo.png" alt="Logo" class="w-8 h-8 object-contain mr-2">
|
||||
<span class="text-white font-medium text-base">远光软件微调平台</span>
|
||||
<body class="bg-gray-100 h-screen flex overflow-hidden">
|
||||
<!-- 侧边栏 -->
|
||||
<aside class="w-56 bg-[#001529] text-white flex flex-col shrink-0">
|
||||
<div class="h-14 flex items-center px-4 border-b border-[#001529]/30">
|
||||
<i class="fa fa-cube text-primary text-xl"></i>
|
||||
<span class="ml-2 font-medium text-lg">YG_FT</span>
|
||||
</div>
|
||||
|
||||
<!-- 导航主区域 -->
|
||||
<nav class="flex-1 overflow-y-auto py-2 relative">
|
||||
<!-- 滑块指示器 -->
|
||||
<div class="sidebar-slider" id="sidebar-slider"></div>
|
||||
|
||||
<!-- 第一分区:模型服务 -->
|
||||
<div class="sidebar-section-title">模型服务</div>
|
||||
<div class="nav-item-wrapper">
|
||||
<a href="#" data-page="fine-tune" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-cogs w-5 text-center"></i>
|
||||
<span class="ml-2">模型调优</span>
|
||||
</a>
|
||||
</div>
|
||||
<nav class="flex-1 overflow-y-auto py-4">
|
||||
<!-- 第一分区:模型管理 -->
|
||||
<div class="sidebar-section-title">模型管理</div>
|
||||
<div class="nav-item-wrapper">
|
||||
<a href="main.html?page=my-models" data-page="my-models" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-database w-5 text-center"></i>
|
||||
<i class="fa fa-cube w-5 text-center"></i>
|
||||
<span class="ml-2">我的模型</span>
|
||||
</a>
|
||||
</div>
|
||||
<div class="nav-item-wrapper">
|
||||
<a href="#" data-page="model-eval" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-line-chart w-5 text-center"></i>
|
||||
<span class="ml-2">模型评测</span>
|
||||
<a href="main.html?page=model-create" data-page="model-create" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-plus w-5 text-center"></i>
|
||||
<span class="ml-2">添加模型</span>
|
||||
</a>
|
||||
</div>
|
||||
<div class="nav-item-wrapper">
|
||||
<a href="#" data-page="model-compare" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-server w-5 text-center"></i>
|
||||
<a href="main.html?page=model-compare" data-page="model-compare" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-clone w-5 text-center"></i>
|
||||
<span class="ml-2">模型对比</span>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<!-- 第二分区:资源管理 -->
|
||||
<div class="sidebar-section-title mt-6">资源管理</div>
|
||||
<!-- 第二分区:训练管理 -->
|
||||
<div class="sidebar-section-title mt-6">训练管理</div>
|
||||
<div class="nav-item-wrapper">
|
||||
<a href="#" data-page="model-manage" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-cube w-5 text-center"></i>
|
||||
<span class="ml-2">模型管理</span>
|
||||
<a href="main.html?page=fine-tune" data-page="fine-tune" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-magic w-5 text-center"></i>
|
||||
<span class="ml-2">训练任务</span>
|
||||
</a>
|
||||
</div>
|
||||
<div class="nav-item-wrapper">
|
||||
<a href="#" data-page="dataset-manage" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-file-text w-5 text-center"></i>
|
||||
<a href="main.html?page=data-manage" data-page="data-manage" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
|
||||
<i class="fa fa-database w-5 text-center"></i>
|
||||
<span class="ml-2">数据集管理</span>
|
||||
</a>
|
||||
</div>
|
||||
@@ -177,7 +122,7 @@
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- 对话区域 -->
|
||||
<!-- 主内容 -->
|
||||
<main class="flex-1 flex flex-col overflow-hidden bg-gray-50">
|
||||
<!-- 模型信息栏 -->
|
||||
<div class="bg-white border-b border-gray-200 px-6 py-3 flex items-center justify-between">
|
||||
@@ -191,57 +136,35 @@
|
||||
<span id="baseModel" class="text-sm text-gray-700">-</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center space-x-4">
|
||||
<div class="flex items-center">
|
||||
<label class="text-xs text-gray-400 mr-2">Temperature:</label>
|
||||
<input type="range" id="temperature" min="0" max="1" step="0.1" value="0.7" class="w-24">
|
||||
<span id="temperatureValue" class="ml-2 text-sm text-gray-600">0.7</span>
|
||||
</div>
|
||||
|
||||
<!-- 合并权重区域 -->
|
||||
<div id="mergeContainer" class="flex-1 overflow-y-auto p-6">
|
||||
<!-- 合并状态 -->
|
||||
<div id="mergeState" class="flex flex-col items-center justify-center h-full text-center py-12">
|
||||
<div class="w-16 h-16 rounded-full bg-blue-100 flex items-center justify-center mb-4">
|
||||
<i class="fa fa-compress text-2xl text-blue-500"></i>
|
||||
</div>
|
||||
<button onclick="clearChat()" class="text-xs text-gray-500 hover:text-gray-700">
|
||||
<i class="fa fa-trash-o mr-1"></i>清空对话
|
||||
<h3 class="text-lg font-medium text-gray-600 mb-2">合并模型权重</h3>
|
||||
<p class="text-sm text-gray-400 max-w-sm mb-4" id="mergeText">将LoRA适配器权重合并到基座模型</p>
|
||||
<button onclick="mergeWeights()" id="mergeBtn"
|
||||
class="px-6 py-2.5 bg-primary text-white rounded-lg hover:bg-primary/90 transition-colors flex items-center">
|
||||
<i class="fa fa-compress mr-2"></i>
|
||||
<span>开始合并</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 聊天消息区域 -->
|
||||
<div id="chatContainer" class="flex-1 overflow-y-auto p-6 space-y-4">
|
||||
<!-- 欢迎消息 -->
|
||||
<div class="chat-message flex justify-start">
|
||||
<div class="max-w-3xl bg-white rounded-lg shadow-sm p-4 border border-gray-200">
|
||||
<div class="flex items-start">
|
||||
<div class="w-8 h-8 rounded-full bg-primary flex items-center justify-center flex-shrink-0">
|
||||
<i class="fa fa-robot text-white text-sm"></i>
|
||||
</div>
|
||||
<div class="ml-3">
|
||||
<div class="text-xs text-gray-400 mb-1">AI 助手</div>
|
||||
<div class="text-sm text-gray-700 leading-relaxed">
|
||||
您好!我是基于该模型训练的AI助手。请在下方输入您的问题,我会尽力为您解答。
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 输入区域 -->
|
||||
<div class="bg-white border-t border-gray-200 p-4">
|
||||
<div class="max-w-4xl mx-auto">
|
||||
<div class="flex items-stretch gap-3">
|
||||
<div class="flex-1">
|
||||
<textarea id="userInput" placeholder="请输入您的问题..." rows="2"
|
||||
class="w-full px-4 py-2.5 border border-gray-300 rounded-lg focus:border-primary focus:outline-none resize-none text-sm h-full"
|
||||
onkeydown="if(event.key === 'Enter' && !event.shiftKey) { event.preventDefault(); sendMessage(); }"></textarea>
|
||||
</div>
|
||||
<button onclick="sendMessage()" id="sendBtn"
|
||||
class="px-5 py-2 bg-primary text-white rounded-lg hover:bg-primary/90 transition-colors flex items-center justify-center h-auto self-center">
|
||||
<i class="fa fa-paper-plane mr-1.5"></i>
|
||||
<span>发送</span>
|
||||
</button>
|
||||
</div>
|
||||
<div class="mt-2 text-xs text-gray-400 flex items-center">
|
||||
<i class="fa fa-info-circle mr-1"></i>
|
||||
按 Enter 发送,Shift+Enter 换行
|
||||
<!-- 合并结果 -->
|
||||
<div id="mergeResult" class="hidden flex flex-col items-center justify-center h-full text-center py-12">
|
||||
<div class="w-16 h-16 rounded-full bg-green-100 flex items-center justify-center mb-4">
|
||||
<i class="fa fa-check-circle text-2xl text-green-500"></i>
|
||||
</div>
|
||||
<h3 class="text-lg font-medium text-gray-600 mb-2">合并完成</h3>
|
||||
<p class="text-sm text-gray-400 max-w-sm mb-4" id="resultText"></p>
|
||||
<button onclick="location.reload()"
|
||||
class="px-6 py-2.5 bg-primary text-white rounded-lg hover:bg-primary/90 transition-colors">
|
||||
<i class="fa fa-refresh mr-2"></i>
|
||||
<span>返回</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
@@ -262,11 +185,9 @@
|
||||
trainMethod: '',
|
||||
baseModel: ''
|
||||
};
|
||||
let chatHistory = [];
|
||||
let isLoading = false;
|
||||
|
||||
// 页面初始化
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
document.addEventListener('DOMContentLoaded', async function() {
|
||||
// 解析URL参数
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
const modelName = urlParams.get('model');
|
||||
@@ -289,6 +210,11 @@
|
||||
document.getElementById('trainMethod').textContent = methodDisplay[currentModel.trainMethod] || currentModel.trainMethod;
|
||||
}
|
||||
|
||||
// 加载模型信息
|
||||
if (currentModel.name) {
|
||||
await loadModelInfo();
|
||||
}
|
||||
|
||||
// 绑定侧边栏导航点击事件
|
||||
document.querySelectorAll('.nav-link').forEach(link => {
|
||||
link.addEventListener('click', function(e) {
|
||||
@@ -306,217 +232,97 @@
|
||||
}
|
||||
});
|
||||
updateSidebarSlider();
|
||||
|
||||
// 绑定温度滑块
|
||||
const tempSlider = document.getElementById('temperature');
|
||||
const tempValue = document.getElementById('temperatureValue');
|
||||
tempSlider.addEventListener('input', function() {
|
||||
tempValue.textContent = this.value;
|
||||
});
|
||||
|
||||
// 聚焦输入框
|
||||
document.getElementById('userInput').focus();
|
||||
|
||||
console.log('[DEBUG] 模型推理页面初始化:', currentModel);
|
||||
});
|
||||
|
||||
// 发送消息
|
||||
async function sendMessage() {
|
||||
const input = document.getElementById('userInput');
|
||||
const message = input.value.trim();
|
||||
// 侧边栏高亮
|
||||
function updateSidebarSlider() {
|
||||
const activeItem = document.querySelector('.sidebar-item-active');
|
||||
const slider = document.getElementById('sidebarSlider');
|
||||
const wrapper = document.querySelector('.nav-item-wrapper');
|
||||
if (activeItem && slider && wrapper) {
|
||||
slider.style.display = 'block';
|
||||
slider.style.width = wrapper.offsetWidth + 'px';
|
||||
slider.style.top = wrapper.offsetTop + 'px';
|
||||
slider.style.height = wrapper.offsetHeight + 'px';
|
||||
}
|
||||
}
|
||||
|
||||
if (!message) return;
|
||||
if (isLoading) return;
|
||||
// 加载模型信息
|
||||
async function loadModelInfo() {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/model-manage/trained-models`);
|
||||
const result = await response.json();
|
||||
|
||||
if (result.code !== 0) {
|
||||
console.error('[DEBUG] 获取模型列表失败');
|
||||
return;
|
||||
}
|
||||
|
||||
const models = result.data?.models || [];
|
||||
const modelInfo = models.find(m => m.name === currentModel.name);
|
||||
console.log('[DEBUG] 模型查找:', currentModel.name, modelInfo);
|
||||
|
||||
if (modelInfo) {
|
||||
currentModel.baseModel = modelInfo.base_model_path || '';
|
||||
document.getElementById('baseModel').textContent = currentModel.baseModel || '未知';
|
||||
} else {
|
||||
document.getElementById('baseModel').textContent = '未知';
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[DEBUG] 加载模型信息失败:', error);
|
||||
document.getElementById('baseModel').textContent = '未知';
|
||||
}
|
||||
}
|
||||
|
||||
// 合并权重
|
||||
async function mergeWeights() {
|
||||
if (!currentModel.name) {
|
||||
alert('请先选择模型');
|
||||
alert('模型参数无效');
|
||||
return;
|
||||
}
|
||||
|
||||
// 添加用户消息
|
||||
addMessage('user', message);
|
||||
chatHistory.push({ role: 'user', content: message });
|
||||
const mergeBtn = document.getElementById('mergeBtn');
|
||||
const mergeText = document.getElementById('mergeText');
|
||||
const originalBtnText = mergeBtn.innerHTML;
|
||||
|
||||
// 清空输入框
|
||||
input.value = '';
|
||||
|
||||
// 显示加载状态
|
||||
isLoading = true;
|
||||
updateSendButton(true);
|
||||
|
||||
// 添加AI加载提示
|
||||
const loadingId = addLoadingMessage();
|
||||
mergeBtn.disabled = true;
|
||||
mergeBtn.innerHTML = '<i class="fa fa-spinner fa-spin mr-2"></i>合并中...';
|
||||
mergeText.textContent = '正在合并模型权重,请稍候...';
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/model-chat/trained`, {
|
||||
const response = await fetch(`${API_BASE}/model-manage/merge`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
model_name: currentModel.name,
|
||||
train_method: currentModel.trainMethod || 'lora',
|
||||
system_prompt: '',
|
||||
user_question: message,
|
||||
temperature: parseFloat(document.getElementById('temperature').value),
|
||||
max_tokens: 2048
|
||||
base_model_path: currentModel.baseModel
|
||||
})
|
||||
});
|
||||
|
||||
const result = await response.json();
|
||||
|
||||
// 移除加载提示
|
||||
removeLoadingMessage(loadingId);
|
||||
console.log('[DEBUG] 合并响应:', result);
|
||||
|
||||
if (result.code === 0) {
|
||||
const aiResponse = result.data?.response || '(无回复)';
|
||||
addMessage('assistant', aiResponse);
|
||||
chatHistory.push({ role: 'assistant', content: aiResponse });
|
||||
mergeText.textContent = '合并成功!';
|
||||
document.getElementById('mergeState').classList.add('hidden');
|
||||
document.getElementById('mergeResult').classList.remove('hidden');
|
||||
document.getElementById('resultText').textContent = result.message || '模型权重已成功合并';
|
||||
} else {
|
||||
addMessage('assistant', `推理失败: ${result.message || '未知错误'}`);
|
||||
mergeText.textContent = result.message || '合并失败';
|
||||
mergeBtn.disabled = false;
|
||||
mergeBtn.innerHTML = originalBtnText;
|
||||
}
|
||||
} catch (error) {
|
||||
removeLoadingMessage(loadingId);
|
||||
addMessage('assistant', `请求异常: ${error.message}`);
|
||||
} finally {
|
||||
isLoading = false;
|
||||
updateSendButton(false);
|
||||
}
|
||||
}
|
||||
|
||||
// 添加消息到聊天框
|
||||
function addMessage(role, content) {
|
||||
const container = document.getElementById('chatContainer');
|
||||
const isUser = role === 'user';
|
||||
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = 'chat-message flex justify-' + (isUser ? 'end' : 'start');
|
||||
|
||||
const avatar = isUser
|
||||
? `<div class="w-8 h-8 rounded-full bg-gray-400 flex items-center justify-center flex-shrink-0">
|
||||
<i class="fa fa-user text-white text-sm"></i>
|
||||
</div>`
|
||||
: `<div class="w-8 h-8 rounded-full bg-primary flex items-center justify-center flex-shrink-0">
|
||||
<i class="fa fa-robot text-white text-sm"></i>
|
||||
</div>`;
|
||||
|
||||
const bgClass = isUser ? 'bg-primary text-white' : 'bg-white border border-gray-200';
|
||||
|
||||
messageDiv.innerHTML = `
|
||||
<div class="max-w-3xl ${bgClass} rounded-lg shadow-sm p-4">
|
||||
<div class="flex items-start ${isUser ? 'flex-row-reverse' : ''}">
|
||||
${avatar}
|
||||
<div class="ml-3 ${isUser ? 'text-right' : ''} flex-1">
|
||||
<div class="text-xs ${isUser ? 'text-gray-300' : 'text-gray-400'} mb-1">
|
||||
${isUser ? '你' : 'AI 助手'}
|
||||
</div>
|
||||
<div class="text-sm leading-relaxed whitespace-pre-wrap">${escapeHtml(content)}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
container.appendChild(messageDiv);
|
||||
container.scrollTop = container.scrollHeight;
|
||||
}
|
||||
|
||||
// 添加加载中的消息
|
||||
function addLoadingMessage() {
|
||||
const container = document.getElementById('chatContainer');
|
||||
const loadingId = 'loading-' + Date.now();
|
||||
|
||||
const loadingDiv = document.createElement('div');
|
||||
loadingDiv.id = loadingId;
|
||||
loadingDiv.className = 'chat-message flex justify-start';
|
||||
|
||||
loadingDiv.innerHTML = `
|
||||
<div class="max-w-3xl bg-white rounded-lg shadow-sm p-4 border border-gray-200">
|
||||
<div class="flex items-start">
|
||||
<div class="w-8 h-8 rounded-full bg-primary flex items-center justify-center flex-shrink-0">
|
||||
<i class="fa fa-robot text-white text-sm"></i>
|
||||
</div>
|
||||
<div class="ml-3">
|
||||
<div class="text-xs text-gray-400 mb-1">AI 助手</div>
|
||||
<div class="typing-indicator">
|
||||
<span></span><span></span><span></span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
container.appendChild(loadingDiv);
|
||||
container.scrollTop = container.scrollHeight;
|
||||
|
||||
return loadingId;
|
||||
}
|
||||
|
||||
// 移除加载消息
|
||||
function removeLoadingMessage(id) {
|
||||
const loadingDiv = document.getElementById(id);
|
||||
if (loadingDiv) {
|
||||
loadingDiv.remove();
|
||||
}
|
||||
}
|
||||
|
||||
// 更新发送按钮状态
|
||||
function updateSendButton(loading) {
|
||||
const btn = document.getElementById('sendBtn');
|
||||
if (loading) {
|
||||
btn.innerHTML = '<i class="fa fa-spinner fa-spin mr-2"></i>处理中';
|
||||
btn.disabled = true;
|
||||
btn.classList.add('opacity-50', 'cursor-not-allowed');
|
||||
} else {
|
||||
btn.innerHTML = '<i class="fa fa-paper-plane mr-2"></i>发送';
|
||||
btn.disabled = false;
|
||||
btn.classList.remove('opacity-50', 'cursor-not-allowed');
|
||||
}
|
||||
}
|
||||
|
||||
// 清空对话
|
||||
function clearChat() {
|
||||
const container = document.getElementById('chatContainer');
|
||||
container.innerHTML = `
|
||||
<div class="chat-message flex justify-start">
|
||||
<div class="max-w-3xl bg-white rounded-lg shadow-sm p-4 border border-gray-200">
|
||||
<div class="flex items-start">
|
||||
<div class="w-8 h-8 rounded-full bg-primary flex items-center justify-center flex-shrink-0">
|
||||
<i class="fa fa-robot text-white text-sm"></i>
|
||||
</div>
|
||||
<div class="ml-3">
|
||||
<div class="text-xs text-gray-400 mb-1">AI 助手</div>
|
||||
<div class="text-sm text-gray-700 leading-relaxed">
|
||||
对话已清空,请继续提问。
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
chatHistory = [];
|
||||
}
|
||||
|
||||
// HTML转义
|
||||
function escapeHtml(text) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
// 更新侧边栏滑块位置
|
||||
function updateSidebarSlider() {
|
||||
const slider = document.getElementById('sidebar-slider');
|
||||
if (!slider) return;
|
||||
const activeLink = document.querySelector('.nav-link.bg-\\[\\#1890ff\\]\\/10');
|
||||
if (activeLink) {
|
||||
const wrapper = activeLink.closest('.nav-item-wrapper');
|
||||
if (wrapper) {
|
||||
slider.style.top = wrapper.offsetTop + 'px';
|
||||
slider.style.height = wrapper.offsetHeight + 'px';
|
||||
}
|
||||
console.error('[DEBUG] 合并失败:', error);
|
||||
mergeText.textContent = '合并失败: ' + error.message;
|
||||
mergeBtn.disabled = false;
|
||||
mergeBtn.innerHTML = originalBtnText;
|
||||
}
|
||||
}
|
||||
|
||||
// 暴露到全局
|
||||
window.sendMessage = sendMessage;
|
||||
window.clearChat = clearChat;
|
||||
window.mergeWeights = mergeWeights;
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
Reference in New Issue
Block a user