1. 修改了合并模型导出模型的逻辑

2. 修改了一些冗余的bug
3. 页面上表格的调整
This commit is contained in:
2026-01-29 23:10:21 +08:00
parent 0f98d67e41
commit 03b6071856
10 changed files with 1008 additions and 460 deletions

View File

@@ -0,0 +1,8 @@
{
"123": {
"file_name": "1769495241519_8_liangce_257.json"
},
"liangce": {
"file_name": "1769605160299_1_liangce_257.json"
}
}

View File

@@ -0,0 +1,30 @@
{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 40960,
"max_window_layers": 28,
"model_type": "qwen3",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
"sliding_window": null,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
}

View File

@@ -0,0 +1 @@
{"framework": "pytorch", "task": "text-generation", "allow_remote": true}

View File

@@ -0,0 +1,13 @@
{
"bos_token_id": 151643,
"do_sample": true,
"eos_token_id": [
151645,
151643
],
"pad_token_id": 151643,
"temperature": 0.6,
"top_k": 20,
"top_p": 0.95,
"transformers_version": "4.51.0"
}

View File

@@ -4,6 +4,7 @@
import os
import pymysql
import yaml
import json
import requests
import concurrent.futures
import subprocess
@@ -211,6 +212,200 @@ def model_chat_batch():
return jsonify({'code': 0, 'data': results})
@model_chat_bp.route('/trained/preload', methods=['POST'])
def preload_trained_model():
"""预加载已训练模型(使用 llamafactory"""
import sys as sys_module
import pymysql
import yaml
import logging
logger = logging.getLogger(__name__)
data = request.json
model_name = data.get('model_name') # 模型名称
train_method = data.get('train_method', 'lora') # 训练方法: lora, full
base_model_path = data.get('base_model_path') # 前端传递的基座模型路径
if not model_name:
return jsonify({'code': 1, 'message': '缺少模型名称'})
logger.info(f"[PRELOAD] 开始预加载模型: {model_name}, 方法: {train_method}")
logger.info(f"[PRELOAD] 前端传递的基座模型路径: {base_model_path}")
# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
except Exception as e:
return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'})
# 优先使用前端传递的基座模型路径,否则从数据库查询
if not base_model_path:
try:
db_config = CONFIG['database']
conn = pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
cursor = conn.cursor()
# 优先从训练任务表查询基座模型
logger.info(f"[PRELOAD] 尝试从fine_tune表查询模型: {model_name}")
cursor.execute("""
SELECT base_model, output_model_name FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
logger.info(f"[PRELOAD] fine_tune查询结果: {ft_result}")
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
logger.info(f"[PRELOAD] base_model_val: {base_model_val}")
# 如果是数字ID查询模型管理表获取路径
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
logger.info(f"[PRELOAD] model_manage查询结果(数字ID): {model_result}")
if model_result:
base_model_path = model_result.get('path')
else:
# 直接是路径
base_model_path = base_model_val
# 如果训练任务表没找到,尝试从模型管理表按名称查询
if not base_model_path:
logger.info(f"[PRELOAD] 尝试从model_manage表查询...")
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
logger.info(f"[PRELOAD] model_manage查询结果: {model_result}")
if model_result:
base_model_path = model_result.get('path')
conn.close()
if not base_model_path:
logger.error(f"[PRELOAD] 未找到模型 {model_name} 的基座模型配置")
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
except Exception as e:
logger.error(f"[PRELOAD] 查询模型配置失败: {e}")
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
else:
logger.info(f"[PRELOAD] 使用前端传递的基座模型路径: {base_model_path}")
# 训练后的模型路径
trained_model_path = f"/app/base/saves/{train_method}/{model_name}"
# 检查路径是否存在兼容Windows和Linux
if not os.path.exists(trained_model_path):
logger.warning(f"[PRELOAD] 训练模型路径不存在: {trained_model_path}")
# 尝试查找适配器文件来确定模型是否存在
adapter_path = os.path.join(trained_model_path, 'adapter_model.bin')
safetensors_path = os.path.join(trained_model_path, 'model.safetensors')
if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
# 预热消息 - 使用一个简单的问候语来加载模型
work_dir = '/app/base'
warmup_messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello'}]
# 根据训练方法选择 finetuning_type
finetuning_type = 'lora' if train_method == 'lora' else 'full'
# 构建 llamafactory 预热脚本
preload_script = f'''
import sys
import json
import logging
logging.basicConfig(level=logging.WARNING)
from llmtuner import ChatModel
def main():
chat_model = ChatModel({{
"model_name_or_path": "{base_model_path}",
"adapter_name_or_path": "{trained_model_path}",
"template": "llama3",
"finetuning_type": "{finetuning_type}",
"temperature": 0.1,
"max_new_tokens": 1
}})
messages = {json.dumps(warmup_messages, ensure_ascii=False)}
# 执行推理以加载模型
response = chat_model.chat(messages)
print("Model loaded successfully")
if __name__ == "__main__":
main()
'''
try:
work_dir = '/app/base'
script_path = os.path.join(work_dir, 'temp_preload.py')
with open(script_path, 'w', encoding='utf-8') as f:
f.write(preload_script)
# 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory
import sys as sys_module
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
# 尝试查找 llamafactory 目录并添加到 PYTHONPATH
llm_factory_paths = ['/app/base', '/app', '/app/base/src/llamafactory']
for path in llm_factory_paths:
if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
env['PYTHONPATH'] = path
logger.info(f"[PRELOAD] 设置 PYTHONPATH={path}")
break
logger.info(f"[PRELOAD] 执行预热脚本...")
result = subprocess.run(
[sys_module.executable, 'temp_preload.py'],
capture_output=True,
text=True,
timeout=180,
cwd=work_dir,
env=env
)
os.remove(script_path)
logger.info(f"[PRELOAD] 命令返回码: {result.returncode}")
logger.info(f"[PRELOAD] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
logger.info(f"[PRELOAD] stderr: {result.stderr[:500] if result.stderr else 'empty'}")
if result.returncode == 0:
return jsonify({
'code': 0,
'message': '模型预加载成功',
'data': {
'model_name': model_name,
'train_method': train_method,
'base_model': base_model_path
}
})
else:
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
if not error_msg:
error_msg = f'命令执行失败,返回码: {result.returncode}'
return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'})
except subprocess.TimeoutExpired:
logger.error("[PRELOAD] 预加载超时")
return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'})
except Exception as e:
logger.error(f"[PRELOAD] 预加载异常: {str(e)}")
return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'})
@model_chat_bp.route('/trained', methods=['POST'])
def chat_trained_model():
"""使用已训练模型进行对话推理"""
@@ -220,6 +415,7 @@ def chat_trained_model():
data = request.json
model_name = data.get('model_name') # 模型名称
train_method = data.get('train_method', 'lora') # 训练方法: lora, full
base_model_path = data.get('base_model_path') # 前端传递的基座模型路径
system_prompt = data.get('system_prompt', '')
user_question = data.get('user_question')
temperature = data.get('temperature', 0.7)
@@ -236,39 +432,64 @@ def chat_trained_model():
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
# 获取基座模型路径 - 从数据库查询 model_name 对应的路径
try:
db_config = CONFIG['database']
conn = pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
cursor = conn.cursor()
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
conn.close()
# 优先使用前端传递的基座模型路径,否则从数据库查询
if not base_model_path:
try:
db_config = CONFIG['database']
conn = pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
cursor = conn.cursor()
if not model_result or not model_result.get('path'):
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'})
# 优先从训练任务表查询基座模型
cursor.execute("""
SELECT base_model, output_model_name FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
base_model_path = model_result['path']
except Exception as e:
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
# 如果是数字ID查询模型管理表获取路径
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
if model_result:
base_model_path = model_result.get('path')
else:
# 直接是路径
base_model_path = base_model_val
# 如果训练任务表没找到,尝试从模型管理表按名称查询
if not base_model_path:
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
if model_result:
base_model_path = model_result.get('path')
conn.close()
if not base_model_path:
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
except Exception as e:
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
# 训练后的模型路径
trained_model_path = f"/app/base/saves/{train_method}/{model_name}"
# 检查路径是否存在兼容Windows和Linux
if not os.path.exists(trained_model_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
# 构建 llamafactory-cli chat 命令
work_dir = '/app/base'
llamafactory_dir = '/app/base'
adapter_path = os.path.join(trained_model_path, 'adapter_model.bin')
safetensors_path = os.path.join(trained_model_path, 'model.safetensors')
if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
# 准备消息
messages = []
@@ -276,55 +497,76 @@ def chat_trained_model():
messages.append({'role': 'system', 'content': system_prompt})
messages.append({'role': 'user', 'content': user_question})
# 将消息转换为 JSON 字符串
messages_json = json.dumps(messages, ensure_ascii=False)
# 构建 llamafactory 推理脚本
inference_script = f'''
import sys
import json
import logging
logging.basicConfig(level=logging.WARNING)
# 构建 llamafactory-cli chat 命令
full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}'
from llmtuner import ChatModel
def main():
chat_model = ChatModel({{
"model_name_or_path": "{base_model_path}",
"adapter_name_or_path": "{trained_model_path}",
"template": "llama3",
"finetuning_type": "lora",
"temperature": {temperature},
"max_new_tokens": {max_tokens}
}})
messages = {json.dumps(messages, ensure_ascii=False)}
response = chat_model.chat(messages)
print(response)
if __name__ == "__main__":
main()
'''
# 写入临时脚本
work_dir = '/app/base'
script_path = os.path.join(work_dir, 'temp_inference.py')
try:
# 执行命令
with open(script_path, 'w', encoding='utf-8') as f:
f.write(inference_script)
# 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory
import sys as sys_module
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
for path in ['/app/base', '/app', '/app/base/src/llamafactory']:
if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
env['PYTHONPATH'] = path
break
# 执行推理脚本
result = subprocess.run(
full_cmd,
shell=True,
[sys_module.executable, 'temp_inference.py'],
capture_output=True,
text=True,
timeout=120,
cwd=work_dir
timeout=180,
cwd=work_dir,
env=env
)
output = result.stdout
error = result.stderr
# 清理临时脚本
os.remove(script_path)
# 解析输出提取assistant回复
# llamafactory-cli chat 输出格式通常是:
# <|im_start|>assistant
# xxx
# <|im_end|>
assistant_content = ''
if '<|im_start|>assistant' in output:
parts = output.split('<|im_start|>assistant')
if len(parts) > 1:
content_part = parts[1].split('<|im_end|>')[0].strip()
# 移除可能存在的换行前缀
content_part = content_part.lstrip('\n').strip()
assistant_content = content_part
elif result.returncode == 0:
# 如果没有特殊标记,尝试提取最后一部分作为回复
lines = output.strip().split('\n')
assistant_content = '\n'.join(lines).strip()
if result.returncode == 0:
assistant_content = result.stdout.strip()
return jsonify({
'code': 0,
'data': {
'model_name': model_name,
'train_method': train_method,
'response': assistant_content
}
})
else:
return jsonify({'code': 1, 'message': f'推理失败: {error or output}'})
return jsonify({
'code': 0,
'data': {
'model_name': model_name,
'train_method': train_method,
'response': assistant_content
}
})
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'})
except subprocess.TimeoutExpired:
return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})

View File

@@ -47,24 +47,32 @@ def generic_get_all(table_name, order_by='create_time DESC'):
def get_model_path_by_name(model_name):
"""根据模型名称查询模型路径(用于获取基座模型路径)"""
import logging
logger = logging.getLogger(__name__)
logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}")
try:
conn = get_db_connection()
cursor = conn.cursor()
# 优先从训练任务表查询基座模型
logger.info(f"[DEBUG get_model_path_by_name] 尝试从fine_tune表查询...")
cursor.execute("""
SELECT base_model FROM fine_tune
SELECT base_model, output_model_name FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
logger.info(f"[DEBUG get_model_path_by_name] fine_tune查询结果: {ft_result}")
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
logger.info(f"[DEBUG get_model_path_by_name] base_model_val: {base_model_val}")
# 如果是数字ID查询模型管理表获取路径
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果(数字ID): {model_result}")
if model_result:
cursor.close()
conn.close()
@@ -76,12 +84,15 @@ def get_model_path_by_name(model_name):
return base_model_val
# 如果训练任务表没找到,尝试从模型管理表按名称查询
logger.info(f"[DEBUG get_model_path_by_name] 尝试从model_manage表查询...")
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
result = cursor.fetchone()
logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果: {result}")
cursor.close()
conn.close()
if result:
return result.get('path')
logger.info(f"[DEBUG get_model_path_by_name] 未找到任何匹配返回None")
return None
except Exception as e:
logger.error(f"[ERROR] 查询模型路径失败: {e}")
@@ -377,6 +388,16 @@ def get_trained_models():
logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型")
# 检查每个模型是否已合并或正在合并
local_trained_path = os.path.join(PROJECT_ROOT, 'local_trained_models')
for model in models:
model_name = model['name']
merged_path = os.path.join(local_trained_path, model_name)
lock_file = os.path.join(local_trained_path, f'.merging_{model_name}.lock')
model['merged'] = os.path.exists(merged_path)
model['merging'] = os.path.exists(lock_file)
logger.info(f"[DEBUG] 模型 {model_name} 已合并: {model['merged']}, 正在合并: {model['merging']}")
return jsonify({
'code': 0,
'data': {
@@ -387,3 +408,264 @@ def get_trained_models():
except Exception as e:
logger.error(f"获取已训练模型列表失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
# ============ 合并权重接口 ============
@model_manage_bp.route('/merge', methods=['POST'])
def merge_model():
"""合并模型权重将LoRA适配器合并到基座模型"""
import subprocess
import sys
import logging
logger = logging.getLogger(__name__)
data = request.json
model_name = data.get('model_name') # 模型名称
train_method = data.get('train_method', 'lora') # 训练方法
base_model_path = data.get('base_model_path') # 基座模型路径
if not model_name:
return jsonify({'code': 1, 'message': '缺少模型名称'})
logger.info(f"[MERGE] 开始合并模型: {model_name}, 方法: {train_method}")
# 如果没有提供基座模型路径,从数据库查询
if not base_model_path:
try:
conn = get_db_connection()
cursor = conn.cursor()
# 优先从训练任务表查询
cursor.execute("""
SELECT base_model FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
if model_result:
base_model_path = model_result.get('path')
else:
base_model_path = base_model_val
# 如果没找到,尝试从模型管理表按名称查询
if not base_model_path:
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
if model_result:
base_model_path = model_result.get('path')
conn.close()
if not base_model_path:
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
except Exception as e:
logger.error(f"[MERGE] 查询模型配置失败: {e}")
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
# 训练后的模型路径LoRA适配器
adapter_path = f"/app/base/saves/{train_method}/{model_name}"
# 检查路径是否存在
if not os.path.exists(adapter_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {adapter_path}'})
# 合并后的输出路径
output_path = f"/app/base/local_trained_models/{model_name}"
# 合并状态锁文件
lock_file = f"/app/base/local_trained_models/.merging_{model_name}.lock"
# 创建输出目录
os.makedirs(output_path, exist_ok=True)
# 创建锁文件表示正在合并中
try:
with open(lock_file, 'w') as f:
f.write('merging')
work_dir = '/app/base'
# 设置环境变量
env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
# 使用 llamafactory-cli export 命令(假设已在系统 PATH 中,与训练命令一致)
cli_cmd = ['llamafactory-cli', 'export']
# 检查 llamafactory-cli 是否存在
try:
# 尝试使用 which 命令Linux/Mac
subprocess.run(['which', 'llamafactory-cli'], capture_output=True, check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
# Windows 上没有 which 命令,直接尝试执行
logger.info("[MERGE] which 命令不可用,直接尝试执行 llamafactory-cli")
# 构建完整命令参数
export_args = [
'--model_name_or_path', base_model_path,
'--adapter_name_or_path', adapter_path,
'--export_dir', output_path
]
logger.info(f"[MERGE] 执行合并命令: {' '.join(cli_cmd)} {' '.join(export_args)}")
# 直接执行 llamafactory-cli export 命令
result = subprocess.run(
cli_cmd + export_args,
capture_output=True,
text=True,
timeout=600,
cwd=work_dir or '/app/base',
env=env
)
logger.info(f"[MERGE] 命令返回码: {result.returncode}")
logger.info(f"[MERGE] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
logger.info(f"[MERGE] stderr: {result.stderr[:500] if result.stderr else 'empty'}")
# 等待输出目录完全创建
import time
max_wait = 5 # 最多等待5秒
waited = 0
while not os.path.exists(output_path) and waited < max_wait:
time.sleep(0.5)
waited += 0.5
# 无论成功失败,都删除锁文件
if os.path.exists(lock_file):
os.remove(lock_file)
if result.returncode == 0:
# 确保目录存在才返回成功
if os.path.exists(output_path):
return jsonify({
'code': 0,
'message': f'模型权重已成功合并到 {output_path}',
'data': {
'model_name': model_name,
'output_path': output_path
}
})
else:
return jsonify({'code': 1, 'message': '合并失败:输出目录未创建'})
else:
error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
if not error_msg:
error_msg = f'命令执行失败,返回码: {result.returncode}'
return jsonify({'code': 1, 'message': f'合并失败: {error_msg}'})
except subprocess.TimeoutExpired:
logger.error("[MERGE] 合并超时")
# 删除锁文件
if os.path.exists(lock_file):
os.remove(lock_file)
return jsonify({'code': 1, 'message': '合并超时,请稍后重试'})
except Exception as e:
logger.error(f"[MERGE] 合并异常: {str(e)}")
return jsonify({'code': 1, 'message': f'合并异常: {str(e)}'})
# ============ 删除已训练模型接口 ============
@model_manage_bp.route('/trained-models/<model_name>', methods=['DELETE'])
def delete_trained_model(model_name):
"""删除已训练模型从local_trained_models目录"""
import shutil
import logging
logger = logging.getLogger(__name__)
try:
# 删除 local_trained_models 目录下的模型
model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
if not os.path.exists(model_path):
return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'})
# 删除目录
shutil.rmtree(model_path)
logger.info(f"[DELETE] 已删除模型: {model_path}")
return jsonify({'code': 0, 'message': '删除成功'})
except Exception as e:
logger.error(f"[DELETE] 删除模型失败: {str(e)}")
return jsonify({'code': 1, 'message': f'删除失败: {str(e)}'})
# ============ 导出已训练模型接口 ============
@model_manage_bp.route('/trained-models/<model_name>/export', methods=['GET'])
def export_trained_model(model_name):
"""导出已训练模型打包成zip下载"""
import shutil
import logging
from flask import send_file
logger = logging.getLogger(__name__)
try:
# 优先从 local_trained_models 目录查找(合并后的模型)
model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
# 如果本地模型目录不存在,尝试从 saves 目录查找(未合并的模型)
if not os.path.exists(model_path):
# 查找 saves 目录下的模型
saves_path = os.path.join(PROJECT_ROOT, 'saves')
train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']
for method in train_methods:
potential_path = os.path.join(saves_path, method, model_name)
if os.path.exists(potential_path):
model_path = potential_path
logger.info(f"[EXPORT] 从 saves/{method} 目录找到模型: {model_path}")
break
# 如果还是找不到,返回错误
if not os.path.exists(model_path):
return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'})
# 创建临时 zip 文件
zip_path = os.path.join(PROJECT_ROOT, 'temp_exports')
os.makedirs(zip_path, exist_ok=True)
zip_file = os.path.join(zip_path, f'{model_name}.zip')
# 如果已存在先删除
if os.path.exists(zip_file):
os.remove(zip_file)
# 打包成 zip
shutil.make_archive(zip_file[:-4], 'zip', model_path)
logger.info(f"[EXPORT] 已打包模型: {zip_file}")
# 发送文件给前端
response = send_file(
zip_file,
as_attachment=True,
download_name=f'{model_name}.zip',
mimetype='application/zip'
)
# 注册回调,删除临时文件
def cleanup():
try:
if os.path.exists(zip_file):
os.remove(zip_file)
logger.info(f"[EXPORT] 已清理临时文件: {zip_file}")
except:
pass
# 使用 after_request 清理
@response.call_on_close
def cleanup_after_request():
cleanup()
return response
except Exception as e:
logger.error(f"[EXPORT] 导出模型失败: {str(e)}")
return jsonify({'code': 1, 'message': f'导出失败: {str(e)}'})

View File

@@ -223,6 +223,11 @@
</div>
<script>
// 全局返回函数
function goBack() {
window.location.href = 'main.html?page=dataset-manage';
}
// 使用 IIFE 避免全局变量污染
(function() {
// API 基础地址 - 优先使用 main.html 中定义的全局变量

View File

@@ -429,9 +429,10 @@
createText: '创建训练任务',
columns: [
{ title: '任务名称', key: 'name' },
{ title: '基础模型', key: 'base_model', render: (val, row) => `<span class="model-name-cell" data-model-id="${val}">加载中...</span>` },
{ title: '状态', key: 'status', render: (val) => `<span class="px-2 py-1 rounded text-xs ${val === 'running' ? 'bg-green-100 text-green-700' : val === 'failed' ? 'bg-red-100 text-red-700' : 'bg-gray-100 text-gray-700'}">${val}</span>` },
{ title: '创建时间', key: 'create_time', render: (val) => val ? new Date(val).toLocaleString('zh-CN') : '-' }
{ title: '任务状态', key: 'status', render: (val) => `<span class="px-2 py-1 rounded text-xs ${val === 'running' ? 'bg-green-100 text-green-700' : val === 'failed' ? 'bg-red-100 text-red-700' : 'bg-gray-100 text-gray-700'}">${val}</span>` },
{ title: '训练方式', key: 'train_type', render: (val) => val === 'SFT' ? 'SFT 微调训练' : (val === 'DPO' ? 'DPO 偏好训练' : (val === 'CPT' ? 'CPT 继续预训练' : '-')) },
{ title: '训练模板', key: 'template', render: (val) => val || '-' },
{ title: '基座模型', key: 'base_model', render: (val, row) => `<span class="model-name-cell" data-model-id="${val}">加载中...</span>` }
],
actions: ['stop', 'logs', 'delete']
},
@@ -608,7 +609,7 @@
'edit': '编辑',
'compare': '开始对话',
'chat': '对话',
'view': '去推理'
'view': '合并权重'
};
// 训练进度缓存
@@ -1010,21 +1011,47 @@
// 删除数据
async function deleteItem(api, id) {
showConfirm('确认删除', '确定要删除这条记录吗?', async () => {
// 如果是我的模型,提示删除合并模型
const confirmMessage = api === 'model-manage/trained-models'
? '是否删除合并模型?'
: '确定要删除这条记录吗?';
showConfirm('确认删除', confirmMessage, async () => {
try {
const response = await fetch(`${API_BASE}/${api}/${id}`, {
method: 'DELETE'
});
const result = await response.json();
if (result.code === 0) {
// 刷新当前页面
clearSelection(); // 清除选中状态
const activeLink = document.querySelector('.nav-link.sidebar-item-active');
if (activeLink) {
loadPage(activeLink.dataset.page);
// 如果是我的模型调用删除本地训练模型的API
if (api === 'model-manage/trained-models') {
const response = await fetch(`${API_BASE}/model-manage/trained-models/${id}`, {
method: 'DELETE'
});
const result = await response.json();
if (result.code === 0) {
showMessage('成功', '删除成功', 'success');
// 清除合并状态缓存
sessionStorage.removeItem('merge_status_' + id);
// 刷新当前页面
clearSelection();
const activeLink = document.querySelector('.nav-link.sidebar-item-active');
if (activeLink) {
loadPage(activeLink.dataset.page);
}
} else {
showMessage('错误', result.message || '删除失败', 'error');
}
} else {
showMessage('错误', result.message || '删除失败', 'error');
const response = await fetch(`${API_BASE}/${api}/${id}`, {
method: 'DELETE'
});
const result = await response.json();
if (result.code === 0) {
// 刷新当前页面
clearSelection(); // 清除选中状态
const activeLink = document.querySelector('.nav-link.sidebar-item-active');
if (activeLink) {
loadPage(activeLink.dataset.page);
}
} else {
showMessage('错误', result.message || '删除失败', 'error');
}
}
} catch (error) {
showMessage('错误', '删除失败: ' + error.message, 'error');
@@ -1070,6 +1097,15 @@
loadPage('logs');
}
// 查看调优任务日志 - 跳转到training-log.html页面
function viewFineTuneLogs(taskId, taskName) {
// 保存 taskId 到 sessionStorage
sessionStorage.setItem('trainingLogTaskId', taskId.toString());
sessionStorage.setItem('trainingLogTaskName', taskName);
// 跳转到日志页面
navigateToPage('training-log');
}
// 更新模型用途
async function updateModelPurpose(id, purpose) {
try {
@@ -1108,11 +1144,11 @@
function toggleSelectAll(checkbox, api) {
// 使用保存的当前页面数据
if (checkbox.checked) {
// 全选当前页面的所有数据
currentPageData.forEach(item => selectedItems.add(item.id));
// 全选当前页面的所有数据(支持 name 或 id
currentPageData.forEach(item => selectedItems.add(item.name || item.id));
} else {
// 取消全选,移除当前页面所有数据的选中状态
currentPageData.forEach(item => selectedItems.delete(item.id));
currentPageData.forEach(item => selectedItems.delete(item.name || item.id));
}
refreshCurrentPage();
}
@@ -1266,14 +1302,16 @@
// 渲染表格页面
function renderTablePage(config, data) {
const createButton = config.hasCreate ? `
<button onclick="showCreateModal('${config.api}')" class="bg-primary text-white px-3 py-1.5 rounded text-sm hover:bg-primary/90 transition-colors">
<i class="fa fa-plus mr-1"></i>${config.createText}
</button>
` : '';
// 使用配置中的列定义,如果没有则使用默认列
const columns = config.columns || [
{ title: '模型名称', key: 'name' },
{ title: '训练方法', key: 'train_methods', render: (val) => val && val[0] ? val[0].name : '-' },
{ title: '基座模型', key: 'base_model_path', render: (val) => `<span class="text-xs text-gray-500 truncate block" title="${val}">${val || '-'}</span>` },
{ title: '创建时间', key: 'create_time', render: (val) => val ? new Date(val).toLocaleString('zh-CN') : '-' }
];
// 搜索框(模型管理和数据集管理)
const searchBox = (config.api === 'model-manage' || config.api === 'dataset-manage') ? `
// 搜索框
const searchBox = (config.api === 'model-manage' || config.api === 'model-manage/trained-models' || config.api === 'dataset-manage' || config.api === 'fine-tune') ? `
<div class="relative">
<input type="text" id="tableSearchInput" placeholder="搜索${config.title}..."
class="w-72 pl-9 pr-3 py-1.5 rounded border border-gray-300 text-sm focus:outline-none focus:border-primary focus:ring-1 focus:border-primary"
@@ -1283,7 +1321,22 @@
` : '';
// 是否支持多选(模型管理和数据集管理)
const supportsMultiSelect = config.api === 'model-manage' || config.api === 'dataset-manage';
const supportsMultiSelect = config.api === 'model-manage' || config.api === 'model-manage/trained-models' || config.api === 'dataset-manage' || config.api === 'fine-tune';
// 创建按钮根据API类型决定是否显示
const createButton = config.api === 'model-manage' ? `
<button onclick="showCreateModal('${config.api}')" class="bg-primary text-white px-3 py-1.5 rounded text-sm hover:bg-primary/90 transition-colors">
<i class="fa fa-plus mr-1"></i>新建模型
</button>
` : (config.api === 'dataset-manage' ? `
<button onclick="showCreateModal('${config.api}')" class="bg-primary text-white px-3 py-1.5 rounded text-sm hover:bg-primary/90 transition-colors">
<i class="fa fa-plus mr-1"></i>创建数据集
</button>
` : (config.api === 'fine-tune' ? `
<button onclick="navigateToPage('fine-tune-create')" class="bg-primary text-white px-3 py-1.5 rounded text-sm hover:bg-primary/90 transition-colors">
<i class="fa fa-plus mr-1"></i>新建调优任务
</button>
` : ''));
// 批量删除按钮(仅当有选中项时显示)
const batchDeleteButton = supportsMultiSelect && selectedItems.size > 0 ? `
@@ -1292,7 +1345,6 @@
</button>
` : '';
const columns = config.columns;
const hasData = data && data.length > 0;
// 多选列头
@@ -1312,6 +1364,11 @@
${createButton}
</div>
</div>
${config.api === 'model-manage/trained-models' ? `
<div class="px-4 border-b border-gray-100">
<div class="flex space-x-1" id="modelTabs">
</div>
` : ''}
${supportsMultiSelect ? `
<div id="batchActions" class="px-4 py-2 bg-blue-50 border-b border-blue-100 flex items-center justify-between ${selectedItems.size > 0 ? '' : 'hidden'}">
<div class="flex items-center text-sm text-blue-700">
@@ -1334,12 +1391,12 @@
</thead>
<tbody>
${hasData ? data.map(item => `
<tr class="border-b border-gray-100 table-row-hover ${selectedItems.has(item.id) ? 'bg-blue-50' : ''}">
<tr class="border-b border-gray-100 table-row-hover ${selectedItems.has(item.name || item.id) ? 'bg-blue-50' : ''}">
${supportsMultiSelect ? `
<td class="px-4 py-4 text-sm text-center">
<input type="checkbox" class="w-4 h-4 text-primary rounded border-gray-300 cursor-pointer"
${selectedItems.has(item.id) ? 'checked' : ''}
onchange="toggleItemSelection(${item.id}, '${config.api}')">
${selectedItems.has(item.name || item.id) ? 'checked' : ''}
onchange="toggleItemSelection('${item.name || item.id}', '${config.api}')">
</td>
` : ''}
${columns.map(col => `
@@ -1349,38 +1406,23 @@
`).join('')}
<td class="px-4 py-4 text-sm text-center">
<div class="flex justify-center space-x-2">
${config.actions.map(action => {
let onclick = '';
let btnClass = 'text-primary hover:text-primary/80';
// 对于 fine-tune 的停止按钮,检查状态
if (action === 'stop' && config.api === 'fine-tune') {
// 状态为 completed 或 failed 时隐藏停止按钮
if (item.status === 'completed' || item.status === 'failed') {
return '';
}
onclick = `stopItem(${item.id})`;
btnClass = 'text-orange-500 hover:text-orange-600';
} else if (action === 'delete') {
onclick = `deleteItem('${config.api}', ${item.id})`;
btnClass = 'text-danger hover:text-danger/80';
} else if (action === 'edit') {
onclick = `editItem('${config.api}', ${item.id})`;
} else if (action === 'preview' && config.api === 'dataset-manage') {
onclick = `window.location.href = 'dataset-preview.html?id=${item.id}'`;
} else if (action === 'download' && config.api === 'dataset-manage') {
onclick = `downloadDataset('${item.id}')`;
} else if (action === 'compare' && config.api === 'model-compare') {
onclick = `startCompare(${item.id})`;
} else if (action === 'logs' && config.api === 'fine-tune') {
onclick = `navigateToTrainingLog(${item.id})`;
} else if (action === 'view' && config.api === 'model-manage/trained-models') {
onclick = `viewTrainedModel('${item.name}', '${item.train_methods?.[0]?.name || '-'}', '${item.path || ''}')`;
} else {
onclick = `showMessage('提示', '${actionLabels[action] || action}功能开发中...', 'info')`;
}
return `<button onclick="${onclick}" class="${btnClass}">${actionLabels[action] || action}</button>`;
}).join('')}
${config.api === 'fine-tune' ? `
<button onclick="viewFineTuneLogs('${item.id}', '${item.name}')" class="bg-blue-500 text-white px-3 py-1 rounded text-xs hover:bg-blue-600">查看日志</button>
<button onclick="deleteItem('${config.api}', '${item.id}')" class="bg-red-500 text-white px-3 py-1 rounded text-xs hover:bg-red-600">删除任务</button>
` : (config.api === 'model-manage/trained-models' ? `
${getMergeButtonHtml(item.name, item.train_methods?.[0]?.name || 'lora', item.base_model_path || '', item.merged, item.merging)}
${(item.merged && !item.merging) ? `
<button onclick="exportModel('${item.name}')" class="bg-green-500 text-white px-3 py-1 rounded text-xs hover:bg-green-600">导出权重</button>
<button onclick="deleteItem('${config.api}', '${item.name || item.id}')" class="bg-red-500 text-white px-3 py-1 rounded text-xs hover:bg-red-600">删除</button>
` : ''}
` : (config.api === 'model-manage' ? `
<button onclick="editModel('${item.id}')" class="bg-blue-500 text-white px-3 py-1 rounded text-xs hover:bg-blue-600">编辑</button>
<button onclick="deleteItem('${config.api}', '${item.id}')" class="bg-red-500 text-white px-3 py-1 rounded text-xs hover:bg-red-600">删除</button>
` : (config.api === 'dataset-manage' ? `
<button onclick="previewDataset('${item.id}')" class="bg-blue-500 text-white px-3 py-1 rounded text-xs hover:bg-blue-600">预览</button>
<button onclick="downloadDataset('${item.id}')" class="bg-green-500 text-white px-3 py-1 rounded text-xs hover:bg-green-600">下载</button>
<button onclick="deleteItem('${config.api}', '${item.id}')" class="bg-red-500 text-white px-3 py-1 rounded text-xs hover:bg-red-600">删除</button>
` : '')))}
</div>
</td>
</tr>
@@ -2506,8 +2548,8 @@
}
}
// 返回列表页
function goBack() {
// 返回列表页 - 全局
window.goBack = function() {
if (currentParentPage) {
currentPage = currentParentPage;
currentParentPage = null;
@@ -3164,10 +3206,113 @@
document.body.style.overflow = '';
}
// 查看已训练模型详情 - 跳转到推理页面
// 刷新表格数据 - 重新加载当前页面(必须在 viewTrainedModel 之前定义)
window.loadTableData = function() {
const activeLink = document.querySelector('.nav-link.sidebar-item-active');
if (activeLink) {
loadPage(activeLink.dataset.page);
}
};
// 获取合并按钮HTML根据合并状态显示不同按钮
function getMergeButtonHtml(name, method, path, merged, merging) {
// 优先检查 sessionStorage 中的临时状态(用于前端实时显示)
const tempStatus = sessionStorage.getItem('merge_status_' + name);
console.log('[DEBUG] getMergeButtonHtml:', name, 'tempStatus:', tempStatus, 'merged:', merged, 'merging:', merging);
// 如果前端正在合并中,显示合并中
if (tempStatus === 'merging') {
return `<button class="bg-gray-300 text-gray-500 px-3 py-1 rounded text-xs cursor-not-allowed flex items-center" disabled>
<i class="fa fa-spinner fa-spin mr-1"></i>合并中...
</button>`;
}
// 如果后端返回正在合并中(锁文件存在)
if (merging) {
return `<button class="bg-gray-300 text-gray-500 px-3 py-1 rounded text-xs cursor-not-allowed flex items-center" disabled>
<i class="fa fa-spinner fa-spin mr-1"></i>合并中...
</button>`;
}
// 如果前端成功状态且后端也返回已合并,显示成功
if (tempStatus === 'success' && merged) {
// 清除临时成功状态,让后端状态接管
sessionStorage.removeItem('merge_status_' + name);
return `<button class="bg-gray-300 text-gray-500 px-3 py-1 rounded text-xs cursor-not-allowed" disabled>合并成功</button>`;
}
// 如果前端成功状态但后端返回未合并,说明目录被删除,重置状态
if (tempStatus === 'success' && !merged) {
sessionStorage.removeItem('merge_status_' + name);
}
// 如果后端返回已合并,显示成功
if (merged) {
return `<button class="bg-gray-300 text-gray-500 px-3 py-1 rounded text-xs cursor-not-allowed" disabled>合并成功</button>`;
}
return `<button onclick="startMerge('${name}', '${method}', '${path}')" class="bg-primary text-white px-3 py-1 rounded text-xs hover:bg-primary/90">合并权重</button>`;
}
// 启动合并任务
async function startMerge(name, method, path) {
// 先设置状态为"合并中"(存储到 sessionStorage
sessionStorage.setItem('merge_status_' + name, 'merging');
// 刷新表格显示合并中状态
loadTableData();
try {
const response = await fetch(`${API_BASE}/model-manage/merge`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model_name: name,
train_method: method || 'lora',
base_model_path: path
})
});
const result = await response.json();
if (result.code === 0) {
// 设置为成功状态,确保即使后端还没更新也能显示成功
sessionStorage.setItem('merge_status_' + name, 'success');
// 延迟刷新表格,用后端真实状态替换前端状态
setTimeout(() => loadTableData(), 1500);
} else {
// 清除合并状态
sessionStorage.removeItem('merge_status_' + name);
showMessage('失败', result.message || '合并失败', 'error');
loadTableData();
}
} catch (error) {
console.error('[DEBUG] 合并失败:', error);
// 清除合并状态
sessionStorage.removeItem('merge_status_' + name);
showMessage('错误', '合并失败: ' + error.message, 'error');
loadTableData();
}
}
// 合并模型权重(保留兼容)
window.viewTrainedModel = function(name, method, path) {
// 跳转到推理测试页面main.html在pages目录下所以直接用文件名
window.location.href = `model-inference.html?model=${encodeURIComponent(name)}&method=${encodeURIComponent(method)}`;
startMerge(name, method, path);
};
// 导出模型权重(打包下载)
function exportModel(modelName) {
// 直接跳转到导出接口下载文件
window.open(`${API_BASE}/model-manage/trained-models/${encodeURIComponent(modelName)}/export`, '_blank');
}
// 编辑模型 - 全局
window.editModel = function(modelId) {
window.location.href = `model-manage-create.html?id=${modelId}`;
};
// 预览数据集 - 全局
window.previewDataset = function(datasetId) {
window.location.href = `dataset-preview.html?id=${datasetId}`;
};
// 下载数据集 - 全局
window.downloadDataset = function(datasetId) {
window.open(`${API_BASE}/dataset-manage/download/${datasetId}`, '_blank');
};
// 确认弹窗(两个按钮)- 使用 window 确保全局可访问
@@ -3446,5 +3591,15 @@
</div>
</div>
</div>
<!-- 加载中弹窗 -->
<div id="loadingModal" class="hidden fixed inset-0 bg-black/50 z-50 flex items-center justify-center">
<div class="bg-white rounded-xl shadow-xl max-w-sm w-full mx-4 overflow-hidden transform transition-all">
<div class="flex flex-col items-center justify-center min-h-[160px] py-6">
<i class="fa fa-spinner fa-spin text-3xl text-primary mb-4"></i>
<p id="loadingMessage" class="text-gray-600 text-sm">正在处理...</p>
</div>
</div>
</div>
</body>
</html>

View File

@@ -3,126 +3,71 @@
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>模型推理 - 远光软件微调平台</title>
<script src="../lib/tailwindcss/tailwind.js"></script>
<title>合并权重 - YG_FT</title>
<link href="../../assets/libs/font-awesome-4.7.0/css/font-awesome.min.css" rel="stylesheet">
<script src="../../assets/libs/tailwindcss.js"></script>
<script>
if (typeof console !== 'undefined' && console.warn) {
const originalWarn = console.warn;
console.warn = function(...args) {
if (args[0] && args[0].includes && args[0].includes('cdn.tailwindcss.com')) {
return;
tailwind.config = {
theme: {
extend: {
colors: {
primary: '#1890ff'
}
}
originalWarn.apply(console, args);
};
}
}
</script>
<link href="../lib/font-awesome/css/font-awesome.min.css" rel="stylesheet">
<style>
:root { --primary: #1890ff; --danger: #f5222d; --success: #52c41a; }
.sidebar-section-title {
padding: 0.5rem 1rem;
font-size: 0.75rem;
color: rgba(191, 203, 217, 0.7);
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.05em;
}
.nav-link:hover { background-color: rgba(0, 21, 41, 0.2); }
.sidebar-item-active {
background-color: rgba(24, 144, 255, 0.1);
color: #1890ff;
border-left: 4px solid #1890ff;
background-color: #1890ff !important;
color: white !important;
}
/* 侧边栏滑块动画 */
.sidebar-slider {
position: absolute;
width: 4px;
height: 0;
background-color: #1890ff;
border-radius: 0 2px 2px 0;
transition: top 0.3s cubic-bezier(0.4, 0, 0.2, 1),
height 0.3s cubic-bezier(0.4, 0, 0.2, 1);
pointer-events: none;
z-index: 10;
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
}
.nav-item-wrapper {
position: relative;
}
.nav-link {
position: relative;
z-index: 1;
}
.chat-message { animation: fadeIn 0.3s ease; }
@keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } }
.typing-indicator span {
display: inline-block;
width: 6px;
height: 6px;
background-color: #1890ff;
border-radius: 50%;
margin: 0 2px;
animation: typing 1.4s infinite ease-in-out;
}
.typing-indicator span:nth-child(2) { animation-delay: 0.2s; }
.typing-indicator span:nth-child(3) { animation-delay: 0.4s; }
@keyframes typing { 0%, 60%, 100% { transform: translateY(0); } 30% { transform: translateY(-8px); } }
.bg-primary { background-color: #1890ff; }
.text-primary { color: #1890ff; }
.border-primary { border-color: #1890ff; }
</style>
</head>
<body class="antialiased bg-gray-100 flex h-screen overflow-hidden">
<!-- 侧边导航 -->
<aside class="w-64 text-[#bfcbd9] flex-shrink-0 hidden md:block flex flex-col h-full" style="background-color: #001529;">
<!-- 平台LOGO区域 -->
<div class="pt-5 pb-3 border-b border-[#001529]/30 flex items-center justify-center pl-2">
<img src="../assets/logo/logo.png" alt="Logo" class="w-8 h-8 object-contain mr-2">
<span class="text-white font-medium text-base">远光软件微调平台</span>
<body class="bg-gray-100 h-screen flex overflow-hidden">
<!-- 侧边 -->
<aside class="w-56 bg-[#001529] text-white flex flex-col shrink-0">
<div class="h-14 flex items-center px-4 border-b border-[#001529]/30">
<i class="fa fa-cube text-primary text-xl"></i>
<span class="ml-2 font-medium text-lg">YG_FT</span>
</div>
<!-- 导航主区域 -->
<nav class="flex-1 overflow-y-auto py-2 relative">
<!-- 滑块指示器 -->
<div class="sidebar-slider" id="sidebar-slider"></div>
<!-- 第一分区:模型服务 -->
<div class="sidebar-section-title">模型服务</div>
<div class="nav-item-wrapper">
<a href="#" data-page="fine-tune" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-cogs w-5 text-center"></i>
<span class="ml-2">模型调优</span>
</a>
</div>
<nav class="flex-1 overflow-y-auto py-4">
<!-- 第一分区:模型管理 -->
<div class="sidebar-section-title">模型管理</div>
<div class="nav-item-wrapper">
<a href="main.html?page=my-models" data-page="my-models" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-database w-5 text-center"></i>
<i class="fa fa-cube w-5 text-center"></i>
<span class="ml-2">我的模型</span>
</a>
</div>
<div class="nav-item-wrapper">
<a href="#" data-page="model-eval" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-line-chart w-5 text-center"></i>
<span class="ml-2">模型评测</span>
<a href="main.html?page=model-create" data-page="model-create" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-plus w-5 text-center"></i>
<span class="ml-2">添加模型</span>
</a>
</div>
<div class="nav-item-wrapper">
<a href="#" data-page="model-compare" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-server w-5 text-center"></i>
<a href="main.html?page=model-compare" data-page="model-compare" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-clone w-5 text-center"></i>
<span class="ml-2">模型对比</span>
</a>
</div>
<!-- 第二分区:资源管理 -->
<div class="sidebar-section-title mt-6">资源管理</div>
<!-- 第二分区:训练管理 -->
<div class="sidebar-section-title mt-6">训练管理</div>
<div class="nav-item-wrapper">
<a href="#" data-page="model-manage" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-cube w-5 text-center"></i>
<span class="ml-2">模型管理</span>
<a href="main.html?page=fine-tune" data-page="fine-tune" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-magic w-5 text-center"></i>
<span class="ml-2">训练任务</span>
</a>
</div>
<div class="nav-item-wrapper">
<a href="#" data-page="dataset-manage" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-file-text w-5 text-center"></i>
<a href="main.html?page=data-manage" data-page="data-manage" class="nav-link flex items-center px-4 py-2.5 hover:bg-[#001529]/20 transition-colors">
<i class="fa fa-database w-5 text-center"></i>
<span class="ml-2">数据集管理</span>
</a>
</div>
@@ -177,7 +122,7 @@
</div>
</header>
<!-- 对话区域 -->
<!-- 主内容 -->
<main class="flex-1 flex flex-col overflow-hidden bg-gray-50">
<!-- 模型信息栏 -->
<div class="bg-white border-b border-gray-200 px-6 py-3 flex items-center justify-between">
@@ -191,57 +136,35 @@
<span id="baseModel" class="text-sm text-gray-700">-</span>
</div>
</div>
<div class="flex items-center space-x-4">
<div class="flex items-center">
<label class="text-xs text-gray-400 mr-2">Temperature:</label>
<input type="range" id="temperature" min="0" max="1" step="0.1" value="0.7" class="w-24">
<span id="temperatureValue" class="ml-2 text-sm text-gray-600">0.7</span>
</div>
<!-- 合并权重区域 -->
<div id="mergeContainer" class="flex-1 overflow-y-auto p-6">
<!-- 合并状态 -->
<div id="mergeState" class="flex flex-col items-center justify-center h-full text-center py-12">
<div class="w-16 h-16 rounded-full bg-blue-100 flex items-center justify-center mb-4">
<i class="fa fa-compress text-2xl text-blue-500"></i>
</div>
<button onclick="clearChat()" class="text-xs text-gray-500 hover:text-gray-700">
<i class="fa fa-trash-o mr-1"></i>清空对话
<h3 class="text-lg font-medium text-gray-600 mb-2">合并模型权重</h3>
<p class="text-sm text-gray-400 max-w-sm mb-4" id="mergeText">将LoRA适配器权重合并到基座模型</p>
<button onclick="mergeWeights()" id="mergeBtn"
class="px-6 py-2.5 bg-primary text-white rounded-lg hover:bg-primary/90 transition-colors flex items-center">
<i class="fa fa-compress mr-2"></i>
<span>开始合并</span>
</button>
</div>
</div>
<!-- 聊天消息区域 -->
<div id="chatContainer" class="flex-1 overflow-y-auto p-6 space-y-4">
<!-- 欢迎消息 -->
<div class="chat-message flex justify-start">
<div class="max-w-3xl bg-white rounded-lg shadow-sm p-4 border border-gray-200">
<div class="flex items-start">
<div class="w-8 h-8 rounded-full bg-primary flex items-center justify-center flex-shrink-0">
<i class="fa fa-robot text-white text-sm"></i>
</div>
<div class="ml-3">
<div class="text-xs text-gray-400 mb-1">AI 助手</div>
<div class="text-sm text-gray-700 leading-relaxed">
您好我是基于该模型训练的AI助手。请在下方输入您的问题我会尽力为您解答。
</div>
</div>
</div>
</div>
</div>
</div>
<!-- 输入区域 -->
<div class="bg-white border-t border-gray-200 p-4">
<div class="max-w-4xl mx-auto">
<div class="flex items-stretch gap-3">
<div class="flex-1">
<textarea id="userInput" placeholder="请输入您的问题..." rows="2"
class="w-full px-4 py-2.5 border border-gray-300 rounded-lg focus:border-primary focus:outline-none resize-none text-sm h-full"
onkeydown="if(event.key === 'Enter' && !event.shiftKey) { event.preventDefault(); sendMessage(); }"></textarea>
</div>
<button onclick="sendMessage()" id="sendBtn"
class="px-5 py-2 bg-primary text-white rounded-lg hover:bg-primary/90 transition-colors flex items-center justify-center h-auto self-center">
<i class="fa fa-paper-plane mr-1.5"></i>
<span>发送</span>
</button>
</div>
<div class="mt-2 text-xs text-gray-400 flex items-center">
<i class="fa fa-info-circle mr-1"></i>
按 Enter 发送Shift+Enter 换行
<!-- 合并结果 -->
<div id="mergeResult" class="hidden flex flex-col items-center justify-center h-full text-center py-12">
<div class="w-16 h-16 rounded-full bg-green-100 flex items-center justify-center mb-4">
<i class="fa fa-check-circle text-2xl text-green-500"></i>
</div>
<h3 class="text-lg font-medium text-gray-600 mb-2">合并完成</h3>
<p class="text-sm text-gray-400 max-w-sm mb-4" id="resultText"></p>
<button onclick="location.reload()"
class="px-6 py-2.5 bg-primary text-white rounded-lg hover:bg-primary/90 transition-colors">
<i class="fa fa-refresh mr-2"></i>
<span>返回</span>
</button>
</div>
</div>
</main>
@@ -262,11 +185,9 @@
trainMethod: '',
baseModel: ''
};
let chatHistory = [];
let isLoading = false;
// 页面初始化
document.addEventListener('DOMContentLoaded', function() {
document.addEventListener('DOMContentLoaded', async function() {
// 解析URL参数
const urlParams = new URLSearchParams(window.location.search);
const modelName = urlParams.get('model');
@@ -289,6 +210,11 @@
document.getElementById('trainMethod').textContent = methodDisplay[currentModel.trainMethod] || currentModel.trainMethod;
}
// 加载模型信息
if (currentModel.name) {
await loadModelInfo();
}
// 绑定侧边栏导航点击事件
document.querySelectorAll('.nav-link').forEach(link => {
link.addEventListener('click', function(e) {
@@ -306,217 +232,97 @@
}
});
updateSidebarSlider();
// 绑定温度滑块
const tempSlider = document.getElementById('temperature');
const tempValue = document.getElementById('temperatureValue');
tempSlider.addEventListener('input', function() {
tempValue.textContent = this.value;
});
// 聚焦输入框
document.getElementById('userInput').focus();
console.log('[DEBUG] 模型推理页面初始化:', currentModel);
});
// 发送消息
async function sendMessage() {
const input = document.getElementById('userInput');
const message = input.value.trim();
// 侧边栏高亮
function updateSidebarSlider() {
const activeItem = document.querySelector('.sidebar-item-active');
const slider = document.getElementById('sidebarSlider');
const wrapper = document.querySelector('.nav-item-wrapper');
if (activeItem && slider && wrapper) {
slider.style.display = 'block';
slider.style.width = wrapper.offsetWidth + 'px';
slider.style.top = wrapper.offsetTop + 'px';
slider.style.height = wrapper.offsetHeight + 'px';
}
}
if (!message) return;
if (isLoading) return;
// 加载模型信息
async function loadModelInfo() {
try {
const response = await fetch(`${API_BASE}/model-manage/trained-models`);
const result = await response.json();
if (result.code !== 0) {
console.error('[DEBUG] 获取模型列表失败');
return;
}
const models = result.data?.models || [];
const modelInfo = models.find(m => m.name === currentModel.name);
console.log('[DEBUG] 模型查找:', currentModel.name, modelInfo);
if (modelInfo) {
currentModel.baseModel = modelInfo.base_model_path || '';
document.getElementById('baseModel').textContent = currentModel.baseModel || '未知';
} else {
document.getElementById('baseModel').textContent = '未知';
}
} catch (error) {
console.error('[DEBUG] 加载模型信息失败:', error);
document.getElementById('baseModel').textContent = '未知';
}
}
// 合并权重
async function mergeWeights() {
if (!currentModel.name) {
alert('请先选择模型');
alert('模型参数无效');
return;
}
// 添加用户消息
addMessage('user', message);
chatHistory.push({ role: 'user', content: message });
const mergeBtn = document.getElementById('mergeBtn');
const mergeText = document.getElementById('mergeText');
const originalBtnText = mergeBtn.innerHTML;
// 清空输入框
input.value = '';
// 显示加载状态
isLoading = true;
updateSendButton(true);
// 添加AI加载提示
const loadingId = addLoadingMessage();
mergeBtn.disabled = true;
mergeBtn.innerHTML = '<i class="fa fa-spinner fa-spin mr-2"></i>合并中...';
mergeText.textContent = '正在合并模型权重,请稍候...';
try {
const response = await fetch(`${API_BASE}/model-chat/trained`, {
const response = await fetch(`${API_BASE}/model-manage/merge`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model_name: currentModel.name,
train_method: currentModel.trainMethod || 'lora',
system_prompt: '',
user_question: message,
temperature: parseFloat(document.getElementById('temperature').value),
max_tokens: 2048
base_model_path: currentModel.baseModel
})
});
const result = await response.json();
// 移除加载提示
removeLoadingMessage(loadingId);
console.log('[DEBUG] 合并响应:', result);
if (result.code === 0) {
const aiResponse = result.data?.response || '(无回复)';
addMessage('assistant', aiResponse);
chatHistory.push({ role: 'assistant', content: aiResponse });
mergeText.textContent = '合并成功!';
document.getElementById('mergeState').classList.add('hidden');
document.getElementById('mergeResult').classList.remove('hidden');
document.getElementById('resultText').textContent = result.message || '模型权重已成功合并';
} else {
addMessage('assistant', `推理失败: ${result.message || '未知错误'}`);
mergeText.textContent = result.message || '合并失败';
mergeBtn.disabled = false;
mergeBtn.innerHTML = originalBtnText;
}
} catch (error) {
removeLoadingMessage(loadingId);
addMessage('assistant', `请求异常: ${error.message}`);
} finally {
isLoading = false;
updateSendButton(false);
}
}
// 添加消息到聊天框
function addMessage(role, content) {
const container = document.getElementById('chatContainer');
const isUser = role === 'user';
const messageDiv = document.createElement('div');
messageDiv.className = 'chat-message flex justify-' + (isUser ? 'end' : 'start');
const avatar = isUser
? `<div class="w-8 h-8 rounded-full bg-gray-400 flex items-center justify-center flex-shrink-0">
<i class="fa fa-user text-white text-sm"></i>
</div>`
: `<div class="w-8 h-8 rounded-full bg-primary flex items-center justify-center flex-shrink-0">
<i class="fa fa-robot text-white text-sm"></i>
</div>`;
const bgClass = isUser ? 'bg-primary text-white' : 'bg-white border border-gray-200';
messageDiv.innerHTML = `
<div class="max-w-3xl ${bgClass} rounded-lg shadow-sm p-4">
<div class="flex items-start ${isUser ? 'flex-row-reverse' : ''}">
${avatar}
<div class="ml-3 ${isUser ? 'text-right' : ''} flex-1">
<div class="text-xs ${isUser ? 'text-gray-300' : 'text-gray-400'} mb-1">
${isUser ? '你' : 'AI 助手'}
</div>
<div class="text-sm leading-relaxed whitespace-pre-wrap">${escapeHtml(content)}</div>
</div>
</div>
</div>
`;
container.appendChild(messageDiv);
container.scrollTop = container.scrollHeight;
}
// 添加加载中的消息
function addLoadingMessage() {
const container = document.getElementById('chatContainer');
const loadingId = 'loading-' + Date.now();
const loadingDiv = document.createElement('div');
loadingDiv.id = loadingId;
loadingDiv.className = 'chat-message flex justify-start';
loadingDiv.innerHTML = `
<div class="max-w-3xl bg-white rounded-lg shadow-sm p-4 border border-gray-200">
<div class="flex items-start">
<div class="w-8 h-8 rounded-full bg-primary flex items-center justify-center flex-shrink-0">
<i class="fa fa-robot text-white text-sm"></i>
</div>
<div class="ml-3">
<div class="text-xs text-gray-400 mb-1">AI 助手</div>
<div class="typing-indicator">
<span></span><span></span><span></span>
</div>
</div>
</div>
</div>
`;
container.appendChild(loadingDiv);
container.scrollTop = container.scrollHeight;
return loadingId;
}
// 移除加载消息
function removeLoadingMessage(id) {
const loadingDiv = document.getElementById(id);
if (loadingDiv) {
loadingDiv.remove();
}
}
// 更新发送按钮状态
function updateSendButton(loading) {
const btn = document.getElementById('sendBtn');
if (loading) {
btn.innerHTML = '<i class="fa fa-spinner fa-spin mr-2"></i>处理中';
btn.disabled = true;
btn.classList.add('opacity-50', 'cursor-not-allowed');
} else {
btn.innerHTML = '<i class="fa fa-paper-plane mr-2"></i>发送';
btn.disabled = false;
btn.classList.remove('opacity-50', 'cursor-not-allowed');
}
}
// 清空对话
function clearChat() {
const container = document.getElementById('chatContainer');
container.innerHTML = `
<div class="chat-message flex justify-start">
<div class="max-w-3xl bg-white rounded-lg shadow-sm p-4 border border-gray-200">
<div class="flex items-start">
<div class="w-8 h-8 rounded-full bg-primary flex items-center justify-center flex-shrink-0">
<i class="fa fa-robot text-white text-sm"></i>
</div>
<div class="ml-3">
<div class="text-xs text-gray-400 mb-1">AI 助手</div>
<div class="text-sm text-gray-700 leading-relaxed">
对话已清空,请继续提问。
</div>
</div>
</div>
</div>
</div>
`;
chatHistory = [];
}
// HTML转义
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
// 更新侧边栏滑块位置
function updateSidebarSlider() {
const slider = document.getElementById('sidebar-slider');
if (!slider) return;
const activeLink = document.querySelector('.nav-link.bg-\\[\\#1890ff\\]\\/10');
if (activeLink) {
const wrapper = activeLink.closest('.nav-item-wrapper');
if (wrapper) {
slider.style.top = wrapper.offsetTop + 'px';
slider.style.height = wrapper.offsetHeight + 'px';
}
console.error('[DEBUG] 合并失败:', error);
mergeText.textContent = '合并失败: ' + error.message;
mergeBtn.disabled = false;
mergeBtn.innerHTML = originalBtnText;
}
}
// 暴露到全局
window.sendMessage = sendMessage;
window.clearChat = clearChat;
window.mergeWeights = mergeWeights;
</script>
</body>
</html>

View File

@@ -210,7 +210,7 @@
<div class="flex items-center text-sm">
<span id="breadcrumbParent" class="text-primary cursor-pointer hover:underline" onclick="goBack()">模型管理</span>
<span class="mx-2 text-gray-300">/</span>
<span class="text-gray-800 font-medium">添加模型</span>
<span id="pageTitle" class="text-gray-800 font-medium">添加模型</span>
</div>
</div>
@@ -382,6 +382,12 @@
if (breadcrumbParent) {
breadcrumbParent.textContent = '模型管理';
}
// 修改页面标题
const pageTitle = document.getElementById('pageTitle');
if (pageTitle) {
pageTitle.textContent = '编辑模型';
}
document.title = '编辑模型 / 远光软件微调平台';
// 修改按钮文字
const saveBtn = document.querySelector('button[onclick="submitForm()"]');
if (saveBtn) {