diff --git a/src/api/model_chat.py b/src/api/model_chat.py index fba96e3..5609e00 100644 --- a/src/api/model_chat.py +++ b/src/api/model_chat.py @@ -4,6 +4,7 @@ import os import pymysql import yaml +import json import requests import concurrent.futures import subprocess @@ -211,6 +212,200 @@ def model_chat_batch(): return jsonify({'code': 0, 'data': results}) +@model_chat_bp.route('/trained/preload', methods=['POST']) +def preload_trained_model(): + """预加载已训练模型(使用 llamafactory)""" + import sys as sys_module + import pymysql + import yaml + import logging + logger = logging.getLogger(__name__) + + data = request.json + model_name = data.get('model_name') # 模型名称 + train_method = data.get('train_method', 'lora') # 训练方法: lora, full + base_model_path = data.get('base_model_path') # 前端传递的基座模型路径 + + if not model_name: + return jsonify({'code': 1, 'message': '缺少模型名称'}) + + logger.info(f"[PRELOAD] 开始预加载模型: {model_name}, 方法: {train_method}") + logger.info(f"[PRELOAD] 前端传递的基座模型路径: {base_model_path}") + + # 获取项目根目录 + PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') + try: + with open(CONFIG_PATH, 'r', encoding='utf-8') as f: + CONFIG = yaml.safe_load(f) + except Exception as e: + return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'}) + + # 优先使用前端传递的基座模型路径,否则从数据库查询 + if not base_model_path: + try: + db_config = CONFIG['database'] + conn = pymysql.connect( + host=db_config['host'], + port=db_config['port'], + user=db_config['username'], + password=db_config['password'], + database=db_config['name'], + charset=db_config.get('charset', 'utf8mb4'), + cursorclass=pymysql.cursors.DictCursor + ) + cursor = conn.cursor() + + # 优先从训练任务表查询基座模型 + logger.info(f"[PRELOAD] 尝试从fine_tune表查询模型: {model_name}") + cursor.execute(""" + SELECT base_model, output_model_name FROM fine_tune + WHERE output_model_name LIKE %s OR output_model_name LIKE %s + LIMIT 1 + """, (f'%/{model_name}', f'%{model_name}%')) + ft_result = cursor.fetchone() + logger.info(f"[PRELOAD] fine_tune查询结果: {ft_result}") + + if ft_result and ft_result.get('base_model'): + base_model_val = ft_result['base_model'] + logger.info(f"[PRELOAD] base_model_val: {base_model_val}") + # 如果是数字ID,查询模型管理表获取路径 + if str(base_model_val).isdigit(): + cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) + model_result = cursor.fetchone() + logger.info(f"[PRELOAD] model_manage查询结果(数字ID): {model_result}") + if model_result: + base_model_path = model_result.get('path') + else: + # 直接是路径 + base_model_path = base_model_val + + # 如果训练任务表没找到,尝试从模型管理表按名称查询 + if not base_model_path: + logger.info(f"[PRELOAD] 尝试从model_manage表查询...") + cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) + model_result = cursor.fetchone() + logger.info(f"[PRELOAD] model_manage查询结果: {model_result}") + if model_result: + base_model_path = model_result.get('path') + + conn.close() + + if not base_model_path: + logger.error(f"[PRELOAD] 未找到模型 {model_name} 的基座模型配置") + return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) + except Exception as e: + logger.error(f"[PRELOAD] 查询模型配置失败: {e}") + return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) + else: + logger.info(f"[PRELOAD] 使用前端传递的基座模型路径: {base_model_path}") + + # 训练后的模型路径 + trained_model_path = f"/app/base/saves/{train_method}/{model_name}" + + # 检查路径是否存在(兼容Windows和Linux) + if not os.path.exists(trained_model_path): + logger.warning(f"[PRELOAD] 训练模型路径不存在: {trained_model_path}") + # 尝试查找适配器文件来确定模型是否存在 + adapter_path = os.path.join(trained_model_path, 'adapter_model.bin') + safetensors_path = os.path.join(trained_model_path, 'model.safetensors') + if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path): + return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) + + # 预热消息 - 使用一个简单的问候语来加载模型 + work_dir = '/app/base' + warmup_messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello'}] + + # 根据训练方法选择 finetuning_type + finetuning_type = 'lora' if train_method == 'lora' else 'full' + + # 构建 llamafactory 预热脚本 + preload_script = f''' +import sys +import json +import logging +logging.basicConfig(level=logging.WARNING) + +from llmtuner import ChatModel + +def main(): + chat_model = ChatModel({{ + "model_name_or_path": "{base_model_path}", + "adapter_name_or_path": "{trained_model_path}", + "template": "llama3", + "finetuning_type": "{finetuning_type}", + "temperature": 0.1, + "max_new_tokens": 1 + }}) + + messages = {json.dumps(warmup_messages, ensure_ascii=False)} + + # 执行推理以加载模型 + response = chat_model.chat(messages) + print("Model loaded successfully") + +if __name__ == "__main__": + main() +''' + + try: + work_dir = '/app/base' + script_path = os.path.join(work_dir, 'temp_preload.py') + + with open(script_path, 'w', encoding='utf-8') as f: + f.write(preload_script) + + # 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory + import sys as sys_module + env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} + # 尝试查找 llamafactory 目录并添加到 PYTHONPATH + llm_factory_paths = ['/app/base', '/app', '/app/base/src/llamafactory'] + for path in llm_factory_paths: + if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')): + env['PYTHONPATH'] = path + logger.info(f"[PRELOAD] 设置 PYTHONPATH={path}") + break + + logger.info(f"[PRELOAD] 执行预热脚本...") + result = subprocess.run( + [sys_module.executable, 'temp_preload.py'], + capture_output=True, + text=True, + timeout=180, + cwd=work_dir, + env=env + ) + + os.remove(script_path) + + logger.info(f"[PRELOAD] 命令返回码: {result.returncode}") + logger.info(f"[PRELOAD] stdout: {result.stdout[:500] if result.stdout else 'empty'}") + logger.info(f"[PRELOAD] stderr: {result.stderr[:500] if result.stderr else 'empty'}") + + if result.returncode == 0: + return jsonify({ + 'code': 0, + 'message': '模型预加载成功', + 'data': { + 'model_name': model_name, + 'train_method': train_method, + 'base_model': base_model_path + } + }) + else: + error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() + if not error_msg: + error_msg = f'命令执行失败,返回码: {result.returncode}' + return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'}) + + except subprocess.TimeoutExpired: + logger.error("[PRELOAD] 预加载超时") + return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'}) + except Exception as e: + logger.error(f"[PRELOAD] 预加载异常: {str(e)}") + return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'}) + + @model_chat_bp.route('/trained', methods=['POST']) def chat_trained_model(): """使用已训练模型进行对话推理""" @@ -220,6 +415,7 @@ def chat_trained_model(): data = request.json model_name = data.get('model_name') # 模型名称 train_method = data.get('train_method', 'lora') # 训练方法: lora, full + base_model_path = data.get('base_model_path') # 前端传递的基座模型路径 system_prompt = data.get('system_prompt', '') user_question = data.get('user_question') temperature = data.get('temperature', 0.7) @@ -236,39 +432,64 @@ def chat_trained_model(): with open(CONFIG_PATH, 'r', encoding='utf-8') as f: CONFIG = yaml.safe_load(f) - # 获取基座模型路径 - 从数据库查询 model_name 对应的路径 - try: - db_config = CONFIG['database'] - conn = pymysql.connect( - host=db_config['host'], - port=db_config['port'], - user=db_config['username'], - password=db_config['password'], - database=db_config['name'], - charset=db_config.get('charset', 'utf8mb4'), - cursorclass=pymysql.cursors.DictCursor - ) - cursor = conn.cursor() - cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) - model_result = cursor.fetchone() - conn.close() + # 优先使用前端传递的基座模型路径,否则从数据库查询 + if not base_model_path: + try: + db_config = CONFIG['database'] + conn = pymysql.connect( + host=db_config['host'], + port=db_config['port'], + user=db_config['username'], + password=db_config['password'], + database=db_config['name'], + charset=db_config.get('charset', 'utf8mb4'), + cursorclass=pymysql.cursors.DictCursor + ) + cursor = conn.cursor() - if not model_result or not model_result.get('path'): - return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'}) + # 优先从训练任务表查询基座模型 + cursor.execute(""" + SELECT base_model, output_model_name FROM fine_tune + WHERE output_model_name LIKE %s OR output_model_name LIKE %s + LIMIT 1 + """, (f'%/{model_name}', f'%{model_name}%')) + ft_result = cursor.fetchone() - base_model_path = model_result['path'] - except Exception as e: - return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) + if ft_result and ft_result.get('base_model'): + base_model_val = ft_result['base_model'] + # 如果是数字ID,查询模型管理表获取路径 + if str(base_model_val).isdigit(): + cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) + model_result = cursor.fetchone() + if model_result: + base_model_path = model_result.get('path') + else: + # 直接是路径 + base_model_path = base_model_val + + # 如果训练任务表没找到,尝试从模型管理表按名称查询 + if not base_model_path: + cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) + model_result = cursor.fetchone() + if model_result: + base_model_path = model_result.get('path') + + conn.close() + + if not base_model_path: + return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) + except Exception as e: + return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) # 训练后的模型路径 trained_model_path = f"/app/base/saves/{train_method}/{model_name}" + # 检查路径是否存在(兼容Windows和Linux) if not os.path.exists(trained_model_path): - return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) - - # 构建 llamafactory-cli chat 命令 - work_dir = '/app/base' - llamafactory_dir = '/app/base' + adapter_path = os.path.join(trained_model_path, 'adapter_model.bin') + safetensors_path = os.path.join(trained_model_path, 'model.safetensors') + if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path): + return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) # 准备消息 messages = [] @@ -276,55 +497,76 @@ def chat_trained_model(): messages.append({'role': 'system', 'content': system_prompt}) messages.append({'role': 'user', 'content': user_question}) - # 将消息转换为 JSON 字符串 - messages_json = json.dumps(messages, ensure_ascii=False) + # 构建 llamafactory 推理脚本 + inference_script = f''' +import sys +import json +import logging +logging.basicConfig(level=logging.WARNING) - # 构建 llamafactory-cli chat 命令 - full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}' +from llmtuner import ChatModel + +def main(): + chat_model = ChatModel({{ + "model_name_or_path": "{base_model_path}", + "adapter_name_or_path": "{trained_model_path}", + "template": "llama3", + "finetuning_type": "lora", + "temperature": {temperature}, + "max_new_tokens": {max_tokens} + }}) + + messages = {json.dumps(messages, ensure_ascii=False)} + + response = chat_model.chat(messages) + print(response) + +if __name__ == "__main__": + main() +''' + + # 写入临时脚本 + work_dir = '/app/base' + script_path = os.path.join(work_dir, 'temp_inference.py') try: - # 执行命令 + with open(script_path, 'w', encoding='utf-8') as f: + f.write(inference_script) + + # 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory + import sys as sys_module + env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} + for path in ['/app/base', '/app', '/app/base/src/llamafactory']: + if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')): + env['PYTHONPATH'] = path + break + + # 执行推理脚本 result = subprocess.run( - full_cmd, - shell=True, + [sys_module.executable, 'temp_inference.py'], capture_output=True, text=True, - timeout=120, - cwd=work_dir + timeout=180, + cwd=work_dir, + env=env ) - output = result.stdout - error = result.stderr + # 清理临时脚本 + os.remove(script_path) - # 解析输出,提取assistant回复 - # llamafactory-cli chat 输出格式通常是: - # <|im_start|>assistant - # xxx - # <|im_end|> - - assistant_content = '' - if '<|im_start|>assistant' in output: - parts = output.split('<|im_start|>assistant') - if len(parts) > 1: - content_part = parts[1].split('<|im_end|>')[0].strip() - # 移除可能存在的换行前缀 - content_part = content_part.lstrip('\n').strip() - assistant_content = content_part - elif result.returncode == 0: - # 如果没有特殊标记,尝试提取最后一部分作为回复 - lines = output.strip().split('\n') - assistant_content = '\n'.join(lines).strip() + if result.returncode == 0: + assistant_content = result.stdout.strip() + return jsonify({ + 'code': 0, + 'data': { + 'model_name': model_name, + 'train_method': train_method, + 'response': assistant_content + } + }) else: - return jsonify({'code': 1, 'message': f'推理失败: {error or output}'}) - - return jsonify({ - 'code': 0, - 'data': { - 'model_name': model_name, - 'train_method': train_method, - 'response': assistant_content - } - }) + error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() + return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'}) except subprocess.TimeoutExpired: return jsonify({'code': 1, 'message': '推理超时,请稍后重试'}) diff --git a/src/api/model_manage.py b/src/api/model_manage.py index b9532cd..7666fc4 100644 --- a/src/api/model_manage.py +++ b/src/api/model_manage.py @@ -47,24 +47,32 @@ def generic_get_all(table_name, order_by='create_time DESC'): def get_model_path_by_name(model_name): """根据模型名称查询模型路径(用于获取基座模型路径)""" + import logging + logger = logging.getLogger(__name__) + logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}") + try: conn = get_db_connection() cursor = conn.cursor() # 优先从训练任务表查询基座模型 + logger.info(f"[DEBUG get_model_path_by_name] 尝试从fine_tune表查询...") cursor.execute(""" - SELECT base_model FROM fine_tune + SELECT base_model, output_model_name FROM fine_tune WHERE output_model_name LIKE %s OR output_model_name LIKE %s LIMIT 1 """, (f'%/{model_name}', f'%{model_name}%')) ft_result = cursor.fetchone() + logger.info(f"[DEBUG get_model_path_by_name] fine_tune查询结果: {ft_result}") if ft_result and ft_result.get('base_model'): base_model_val = ft_result['base_model'] + logger.info(f"[DEBUG get_model_path_by_name] base_model_val: {base_model_val}") # 如果是数字ID,查询模型管理表获取路径 if str(base_model_val).isdigit(): cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) model_result = cursor.fetchone() + logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果(数字ID): {model_result}") if model_result: cursor.close() conn.close() @@ -76,12 +84,15 @@ def get_model_path_by_name(model_name): return base_model_val # 如果训练任务表没找到,尝试从模型管理表按名称查询 + logger.info(f"[DEBUG get_model_path_by_name] 尝试从model_manage表查询...") cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) result = cursor.fetchone() + logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果: {result}") cursor.close() conn.close() if result: return result.get('path') + logger.info(f"[DEBUG get_model_path_by_name] 未找到任何匹配,返回None") return None except Exception as e: logger.error(f"[ERROR] 查询模型路径失败: {e}") @@ -387,3 +398,138 @@ def get_trained_models(): except Exception as e: logger.error(f"获取已训练模型列表失败: {e}") return jsonify({'code': 1, 'message': str(e)}) + + +# ============ 合并权重接口 ============ + +@model_manage_bp.route('/merge', methods=['POST']) +def merge_model(): + """合并模型权重(将LoRA适配器合并到基座模型)""" + import subprocess + import sys + import logging + logger = logging.getLogger(__name__) + + data = request.json + model_name = data.get('model_name') # 模型名称 + train_method = data.get('train_method', 'lora') # 训练方法 + base_model_path = data.get('base_model_path') # 基座模型路径 + + if not model_name: + return jsonify({'code': 1, 'message': '缺少模型名称'}) + + logger.info(f"[MERGE] 开始合并模型: {model_name}, 方法: {train_method}") + + # 如果没有提供基座模型路径,从数据库查询 + if not base_model_path: + try: + conn = get_db_connection() + cursor = conn.cursor() + + # 优先从训练任务表查询 + cursor.execute(""" + SELECT base_model FROM fine_tune + WHERE output_model_name LIKE %s OR output_model_name LIKE %s + LIMIT 1 + """, (f'%/{model_name}', f'%{model_name}%')) + ft_result = cursor.fetchone() + + if ft_result and ft_result.get('base_model'): + base_model_val = ft_result['base_model'] + if str(base_model_val).isdigit(): + cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) + model_result = cursor.fetchone() + if model_result: + base_model_path = model_result.get('path') + else: + base_model_path = base_model_val + + # 如果没找到,尝试从模型管理表按名称查询 + if not base_model_path: + cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) + model_result = cursor.fetchone() + if model_result: + base_model_path = model_result.get('path') + + conn.close() + + if not base_model_path: + return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) + except Exception as e: + logger.error(f"[MERGE] 查询模型配置失败: {e}") + return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) + + # 训练后的模型路径(LoRA适配器) + adapter_path = f"/app/base/saves/{train_method}/{model_name}" + + # 检查路径是否存在 + if not os.path.exists(adapter_path): + return jsonify({'code': 1, 'message': f'训练模型不存在: {adapter_path}'}) + + # 合并后的输出路径 + output_path = f"/app/base/local_trained_models/{model_name}" + + # 创建输出目录 + os.makedirs(output_path, exist_ok=True) + + try: + work_dir = '/app/base' + + # 设置环境变量 + env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} + + # 使用 llamafactory-cli export 命令(假设已在系统 PATH 中,与训练命令一致) + cli_cmd = ['llamafactory-cli', 'export'] + + # 检查 llamafactory-cli 是否存在 + try: + # 尝试使用 which 命令(Linux/Mac) + subprocess.run(['which', 'llamafactory-cli'], capture_output=True, check=True) + except (subprocess.CalledProcessError, FileNotFoundError): + # Windows 上没有 which 命令,直接尝试执行 + logger.info("[MERGE] which 命令不可用,直接尝试执行 llamafactory-cli") + + # 构建完整命令参数 + export_args = [ + '--model_name_or_path', base_model_path, + '--adapter_name_or_path', adapter_path, + '--export_dir', output_path + ] + + logger.info(f"[MERGE] 执行合并命令: {' '.join(cli_cmd)} {' '.join(export_args)}") + + # 直接执行 llamafactory-cli export 命令 + result = subprocess.run( + cli_cmd + export_args, + capture_output=True, + text=True, + timeout=600, + cwd=work_dir or '/app/base', + env=env + ) + + logger.info(f"[MERGE] 命令返回码: {result.returncode}") + logger.info(f"[MERGE] stdout: {result.stdout[:500] if result.stdout else 'empty'}") + logger.info(f"[MERGE] stderr: {result.stderr[:500] if result.stderr else 'empty'}") + + if result.returncode == 0: + return jsonify({ + 'code': 0, + 'message': f'模型权重已成功合并到 {output_path}', + 'data': { + 'model_name': model_name, + 'output_path': output_path + } + }) + else: + error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() + if not error_msg: + error_msg = f'命令执行失败,返回码: {result.returncode}' + return jsonify({'code': 1, 'message': f'合并失败: {error_msg}'}) + + except subprocess.TimeoutExpired: + logger.error("[MERGE] 合并超时") + return jsonify({'code': 1, 'message': '合并超时,请稍后重试'}) + except Exception as e: + logger.error(f"[MERGE] 合并异常: {str(e)}") + return jsonify({'code': 1, 'message': f'合并异常: {str(e)}'}) diff --git a/web/pages/main.html b/web/pages/main.html index c59c6f0..e4d0821 100644 --- a/web/pages/main.html +++ b/web/pages/main.html @@ -608,7 +608,7 @@ 'edit': '编辑', 'compare': '开始对话', 'chat': '对话', - 'view': '去推理' + 'view': '合并权重' }; // 训练进度缓存 @@ -1108,11 +1108,11 @@ function toggleSelectAll(checkbox, api) { // 使用保存的当前页面数据 if (checkbox.checked) { - // 全选当前页面的所有数据 - currentPageData.forEach(item => selectedItems.add(item.id)); + // 全选当前页面的所有数据(支持 name 或 id) + currentPageData.forEach(item => selectedItems.add(item.name || item.id)); } else { // 取消全选,移除当前页面所有数据的选中状态 - currentPageData.forEach(item => selectedItems.delete(item.id)); + currentPageData.forEach(item => selectedItems.delete(item.name || item.id)); } refreshCurrentPage(); } @@ -1266,14 +1266,16 @@ // 渲染表格页面 function renderTablePage(config, data) { - const createButton = config.hasCreate ? ` - - ` : ''; + // 定义表格列 + const columns = [ + { title: '模型名称', key: 'name' }, + { title: '训练方法', key: 'train_methods', render: (val) => val && val[0] ? val[0].name : '-' }, + { title: '基座模型', key: 'base_model_path', render: (val) => `${val || '-'}` }, + { title: '创建时间', key: 'create_time', render: (val) => val ? new Date(val).toLocaleString('zh-CN') : '-' } + ]; - // 搜索框(模型管理和数据集管理) - const searchBox = (config.api === 'model-manage' || config.api === 'dataset-manage') ? ` + // 搜索框 + const searchBox = (config.api === 'model-manage' || config.api === 'model-manage/trained-models' || config.api === 'dataset-manage') ? `
+ 创建数据集 + + ` : ''; // 批量删除按钮(仅当有选中项时显示) const batchDeleteButton = supportsMultiSelect && selectedItems.size > 0 ? ` @@ -1292,7 +1301,6 @@ ` : ''; - const columns = config.columns; const hasData = data && data.length > 0; // 多选列头 @@ -1334,12 +1342,12 @@ ${hasData ? data.map(item => ` - + ${supportsMultiSelect ? ` + ${selectedItems.has(item.name || item.id) ? 'checked' : ''} + onchange="toggleItemSelection('${item.name || item.id}', '${config.api}')"> ` : ''} ${columns.map(col => ` @@ -1349,33 +1357,15 @@ `).join('')}
- ${config.actions.map(action => { + ${['view', 'delete'].map(action => { let onclick = ''; let btnClass = 'text-primary hover:text-primary/80'; - // 对于 fine-tune 的停止按钮,检查状态 - if (action === 'stop' && config.api === 'fine-tune') { - // 状态为 completed 或 failed 时隐藏停止按钮 - if (item.status === 'completed' || item.status === 'failed') { - return ''; - } - onclick = `stopItem(${item.id})`; - btnClass = 'text-orange-500 hover:text-orange-600'; + if (action === 'view') { + onclick = `viewTrainedModel('${item.name}', '${item.train_methods?.[0]?.name || 'lora'}', '${item.base_model_path || ''}')`; } else if (action === 'delete') { - onclick = `deleteItem('${config.api}', ${item.id})`; + onclick = `deleteItem('${config.api}', '${item.id}')`; btnClass = 'text-danger hover:text-danger/80'; - } else if (action === 'edit') { - onclick = `editItem('${config.api}', ${item.id})`; - } else if (action === 'preview' && config.api === 'dataset-manage') { - onclick = `window.location.href = 'dataset-preview.html?id=${item.id}'`; - } else if (action === 'download' && config.api === 'dataset-manage') { - onclick = `downloadDataset('${item.id}')`; - } else if (action === 'compare' && config.api === 'model-compare') { - onclick = `startCompare(${item.id})`; - } else if (action === 'logs' && config.api === 'fine-tune') { - onclick = `navigateToTrainingLog(${item.id})`; - } else if (action === 'view' && config.api === 'model-manage/trained-models') { - onclick = `viewTrainedModel('${item.name}', '${item.train_methods?.[0]?.name || '-'}', '${item.path || ''}')`; } else { onclick = `showMessage('提示', '${actionLabels[action] || action}功能开发中...', 'info')`; } @@ -3164,10 +3154,55 @@ document.body.style.overflow = ''; } - // 查看已训练模型详情 - 跳转到推理页面 - window.viewTrainedModel = function(name, method, path) { - // 跳转到推理测试页面(main.html在pages目录下,所以直接用文件名) - window.location.href = `model-inference.html?model=${encodeURIComponent(name)}&method=${encodeURIComponent(method)}`; + // 刷新表格数据 - 重新加载当前页面(必须在 viewTrainedModel 之前定义) + window.loadTableData = function() { + const activeLink = document.querySelector('.nav-link.sidebar-item-active'); + if (activeLink) { + loadPage(activeLink.dataset.page); + } + }; + + // 合并模型权重 + window.viewTrainedModel = async function(name, method, path) { + // 显示加载中弹窗 + const loadingModal = document.getElementById('loadingModal'); + if (loadingModal) { + document.getElementById('loadingMessage').textContent = '正在合并模型权重,请稍候...'; + loadingModal.classList.remove('hidden'); + } + + try { + const response = await fetch(`${API_BASE}/model-manage/merge`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + model_name: name, + train_method: method || 'lora', + base_model_path: path + }) + }); + + const result = await response.json(); + + // 隐藏加载弹窗 + if (loadingModal) { + loadingModal.classList.add('hidden'); + } + + if (result.code === 0) { + showMessage('成功', '模型权重合并成功!', 'success'); + // 刷新模型列表 + loadTableData(); + } else { + showMessage('失败', result.message || '合并失败', 'error'); + } + } catch (error) { + console.error('[DEBUG] 合并失败:', error); + if (loadingModal) { + loadingModal.classList.add('hidden'); + } + showMessage('错误', '合并失败: ' + error.message, 'error'); + } }; // 确认弹窗(两个按钮)- 使用 window 确保全局可访问 @@ -3446,5 +3481,15 @@
+ + + diff --git a/web/pages/model-inference.html b/web/pages/model-inference.html index e5b2072..10ada18 100644 --- a/web/pages/model-inference.html +++ b/web/pages/model-inference.html @@ -3,126 +3,71 @@ - 模型推理 - 远光软件微调平台 - + 合并权重 - YG_FT + + - - - -