From 03b6071856abd5efe5596e3ab0e3824968ad4a02 Mon Sep 17 00:00:00 2001 From: "WIN-JHFT4D3SIVT\\caoxiaozhu" Date: Thu, 29 Jan 2026 23:10:21 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E4=BF=AE=E6=94=B9=E4=BA=86=E5=90=88?= =?UTF-8?q?=E5=B9=B6=E6=A8=A1=E5=9E=8B=E5=AF=BC=E5=87=BA=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E7=9A=84=E9=80=BB=E8=BE=91=202.=20=E4=BF=AE=E6=94=B9=E4=BA=86?= =?UTF-8?q?=E4=B8=80=E4=BA=9B=E5=86=97=E4=BD=99=E7=9A=84bug=203.=20?= =?UTF-8?q?=E9=A1=B5=E9=9D=A2=E4=B8=8A=E8=A1=A8=E6=A0=BC=E7=9A=84=E8=B0=83?= =?UTF-8?q?=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datasets/dataset_info.json | 8 + local_models/Qwen3-0.6B/config.json | 30 ++ local_models/Qwen3-0.6B/configuration.json | 1 + .../Qwen3-0.6B/generation_config.json | 13 + src/api/model_chat.py | 372 +++++++++++--- src/api/model_manage.py | 284 ++++++++++- web/pages/dataset-preview.html | 5 + web/pages/main.html | 291 ++++++++--- web/pages/model-inference.html | 456 +++++------------- web/pages/model-manage-create.html | 8 +- 10 files changed, 1008 insertions(+), 460 deletions(-) create mode 100644 datasets/dataset_info.json create mode 100644 local_models/Qwen3-0.6B/config.json create mode 100644 local_models/Qwen3-0.6B/configuration.json create mode 100644 local_models/Qwen3-0.6B/generation_config.json diff --git a/datasets/dataset_info.json b/datasets/dataset_info.json new file mode 100644 index 0000000..7043e33 --- /dev/null +++ b/datasets/dataset_info.json @@ -0,0 +1,8 @@ +{ + "123": { + "file_name": "1769495241519_8_liangce_257.json" + }, + "liangce": { + "file_name": "1769605160299_1_liangce_257.json" + } +} \ No newline at end of file diff --git a/local_models/Qwen3-0.6B/config.json b/local_models/Qwen3-0.6B/config.json new file mode 100644 index 0000000..f5c3703 --- /dev/null +++ b/local_models/Qwen3-0.6B/config.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 40960, + "max_window_layers": 28, + "model_type": "qwen3", + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936 +} \ No newline at end of file diff --git a/local_models/Qwen3-0.6B/configuration.json b/local_models/Qwen3-0.6B/configuration.json new file mode 100644 index 0000000..bbeeda1 --- /dev/null +++ b/local_models/Qwen3-0.6B/configuration.json @@ -0,0 +1 @@ +{"framework": "pytorch", "task": "text-generation", "allow_remote": true} \ No newline at end of file diff --git a/local_models/Qwen3-0.6B/generation_config.json b/local_models/Qwen3-0.6B/generation_config.json new file mode 100644 index 0000000..20a8a91 --- /dev/null +++ b/local_models/Qwen3-0.6B/generation_config.json @@ -0,0 +1,13 @@ +{ + "bos_token_id": 151643, + "do_sample": true, + "eos_token_id": [ + 151645, + 151643 + ], + "pad_token_id": 151643, + "temperature": 0.6, + "top_k": 20, + "top_p": 0.95, + "transformers_version": "4.51.0" +} \ No newline at end of file diff --git a/src/api/model_chat.py b/src/api/model_chat.py index fba96e3..5609e00 100644 --- a/src/api/model_chat.py +++ b/src/api/model_chat.py @@ -4,6 +4,7 @@ import os import pymysql import yaml +import json import requests import concurrent.futures import subprocess @@ -211,6 +212,200 @@ def model_chat_batch(): return jsonify({'code': 0, 'data': results}) +@model_chat_bp.route('/trained/preload', methods=['POST']) +def preload_trained_model(): + """预加载已训练模型(使用 llamafactory)""" + import sys as sys_module + import pymysql + import yaml + import logging + logger = logging.getLogger(__name__) + + data = request.json + model_name = data.get('model_name') # 模型名称 + train_method = data.get('train_method', 'lora') # 训练方法: lora, full + base_model_path = data.get('base_model_path') # 前端传递的基座模型路径 + + if not model_name: + return jsonify({'code': 1, 'message': '缺少模型名称'}) + + logger.info(f"[PRELOAD] 开始预加载模型: {model_name}, 方法: {train_method}") + logger.info(f"[PRELOAD] 前端传递的基座模型路径: {base_model_path}") + + # 获取项目根目录 + PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') + try: + with open(CONFIG_PATH, 'r', encoding='utf-8') as f: + CONFIG = yaml.safe_load(f) + except Exception as e: + return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'}) + + # 优先使用前端传递的基座模型路径,否则从数据库查询 + if not base_model_path: + try: + db_config = CONFIG['database'] + conn = pymysql.connect( + host=db_config['host'], + port=db_config['port'], + user=db_config['username'], + password=db_config['password'], + database=db_config['name'], + charset=db_config.get('charset', 'utf8mb4'), + cursorclass=pymysql.cursors.DictCursor + ) + cursor = conn.cursor() + + # 优先从训练任务表查询基座模型 + logger.info(f"[PRELOAD] 尝试从fine_tune表查询模型: {model_name}") + cursor.execute(""" + SELECT base_model, output_model_name FROM fine_tune + WHERE output_model_name LIKE %s OR output_model_name LIKE %s + LIMIT 1 + """, (f'%/{model_name}', f'%{model_name}%')) + ft_result = cursor.fetchone() + logger.info(f"[PRELOAD] fine_tune查询结果: {ft_result}") + + if ft_result and ft_result.get('base_model'): + base_model_val = ft_result['base_model'] + logger.info(f"[PRELOAD] base_model_val: {base_model_val}") + # 如果是数字ID,查询模型管理表获取路径 + if str(base_model_val).isdigit(): + cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) + model_result = cursor.fetchone() + logger.info(f"[PRELOAD] model_manage查询结果(数字ID): {model_result}") + if model_result: + base_model_path = model_result.get('path') + else: + # 直接是路径 + base_model_path = base_model_val + + # 如果训练任务表没找到,尝试从模型管理表按名称查询 + if not base_model_path: + logger.info(f"[PRELOAD] 尝试从model_manage表查询...") + cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) + model_result = cursor.fetchone() + logger.info(f"[PRELOAD] model_manage查询结果: {model_result}") + if model_result: + base_model_path = model_result.get('path') + + conn.close() + + if not base_model_path: + logger.error(f"[PRELOAD] 未找到模型 {model_name} 的基座模型配置") + return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) + except Exception as e: + logger.error(f"[PRELOAD] 查询模型配置失败: {e}") + return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) + else: + logger.info(f"[PRELOAD] 使用前端传递的基座模型路径: {base_model_path}") + + # 训练后的模型路径 + trained_model_path = f"/app/base/saves/{train_method}/{model_name}" + + # 检查路径是否存在(兼容Windows和Linux) + if not os.path.exists(trained_model_path): + logger.warning(f"[PRELOAD] 训练模型路径不存在: {trained_model_path}") + # 尝试查找适配器文件来确定模型是否存在 + adapter_path = os.path.join(trained_model_path, 'adapter_model.bin') + safetensors_path = os.path.join(trained_model_path, 'model.safetensors') + if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path): + return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) + + # 预热消息 - 使用一个简单的问候语来加载模型 + work_dir = '/app/base' + warmup_messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello'}] + + # 根据训练方法选择 finetuning_type + finetuning_type = 'lora' if train_method == 'lora' else 'full' + + # 构建 llamafactory 预热脚本 + preload_script = f''' +import sys +import json +import logging +logging.basicConfig(level=logging.WARNING) + +from llmtuner import ChatModel + +def main(): + chat_model = ChatModel({{ + "model_name_or_path": "{base_model_path}", + "adapter_name_or_path": "{trained_model_path}", + "template": "llama3", + "finetuning_type": "{finetuning_type}", + "temperature": 0.1, + "max_new_tokens": 1 + }}) + + messages = {json.dumps(warmup_messages, ensure_ascii=False)} + + # 执行推理以加载模型 + response = chat_model.chat(messages) + print("Model loaded successfully") + +if __name__ == "__main__": + main() +''' + + try: + work_dir = '/app/base' + script_path = os.path.join(work_dir, 'temp_preload.py') + + with open(script_path, 'w', encoding='utf-8') as f: + f.write(preload_script) + + # 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory + import sys as sys_module + env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} + # 尝试查找 llamafactory 目录并添加到 PYTHONPATH + llm_factory_paths = ['/app/base', '/app', '/app/base/src/llamafactory'] + for path in llm_factory_paths: + if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')): + env['PYTHONPATH'] = path + logger.info(f"[PRELOAD] 设置 PYTHONPATH={path}") + break + + logger.info(f"[PRELOAD] 执行预热脚本...") + result = subprocess.run( + [sys_module.executable, 'temp_preload.py'], + capture_output=True, + text=True, + timeout=180, + cwd=work_dir, + env=env + ) + + os.remove(script_path) + + logger.info(f"[PRELOAD] 命令返回码: {result.returncode}") + logger.info(f"[PRELOAD] stdout: {result.stdout[:500] if result.stdout else 'empty'}") + logger.info(f"[PRELOAD] stderr: {result.stderr[:500] if result.stderr else 'empty'}") + + if result.returncode == 0: + return jsonify({ + 'code': 0, + 'message': '模型预加载成功', + 'data': { + 'model_name': model_name, + 'train_method': train_method, + 'base_model': base_model_path + } + }) + else: + error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() + if not error_msg: + error_msg = f'命令执行失败,返回码: {result.returncode}' + return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'}) + + except subprocess.TimeoutExpired: + logger.error("[PRELOAD] 预加载超时") + return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'}) + except Exception as e: + logger.error(f"[PRELOAD] 预加载异常: {str(e)}") + return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'}) + + @model_chat_bp.route('/trained', methods=['POST']) def chat_trained_model(): """使用已训练模型进行对话推理""" @@ -220,6 +415,7 @@ def chat_trained_model(): data = request.json model_name = data.get('model_name') # 模型名称 train_method = data.get('train_method', 'lora') # 训练方法: lora, full + base_model_path = data.get('base_model_path') # 前端传递的基座模型路径 system_prompt = data.get('system_prompt', '') user_question = data.get('user_question') temperature = data.get('temperature', 0.7) @@ -236,39 +432,64 @@ def chat_trained_model(): with open(CONFIG_PATH, 'r', encoding='utf-8') as f: CONFIG = yaml.safe_load(f) - # 获取基座模型路径 - 从数据库查询 model_name 对应的路径 - try: - db_config = CONFIG['database'] - conn = pymysql.connect( - host=db_config['host'], - port=db_config['port'], - user=db_config['username'], - password=db_config['password'], - database=db_config['name'], - charset=db_config.get('charset', 'utf8mb4'), - cursorclass=pymysql.cursors.DictCursor - ) - cursor = conn.cursor() - cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) - model_result = cursor.fetchone() - conn.close() + # 优先使用前端传递的基座模型路径,否则从数据库查询 + if not base_model_path: + try: + db_config = CONFIG['database'] + conn = pymysql.connect( + host=db_config['host'], + port=db_config['port'], + user=db_config['username'], + password=db_config['password'], + database=db_config['name'], + charset=db_config.get('charset', 'utf8mb4'), + cursorclass=pymysql.cursors.DictCursor + ) + cursor = conn.cursor() - if not model_result or not model_result.get('path'): - return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'}) + # 优先从训练任务表查询基座模型 + cursor.execute(""" + SELECT base_model, output_model_name FROM fine_tune + WHERE output_model_name LIKE %s OR output_model_name LIKE %s + LIMIT 1 + """, (f'%/{model_name}', f'%{model_name}%')) + ft_result = cursor.fetchone() - base_model_path = model_result['path'] - except Exception as e: - return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) + if ft_result and ft_result.get('base_model'): + base_model_val = ft_result['base_model'] + # 如果是数字ID,查询模型管理表获取路径 + if str(base_model_val).isdigit(): + cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) + model_result = cursor.fetchone() + if model_result: + base_model_path = model_result.get('path') + else: + # 直接是路径 + base_model_path = base_model_val + + # 如果训练任务表没找到,尝试从模型管理表按名称查询 + if not base_model_path: + cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) + model_result = cursor.fetchone() + if model_result: + base_model_path = model_result.get('path') + + conn.close() + + if not base_model_path: + return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) + except Exception as e: + return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) # 训练后的模型路径 trained_model_path = f"/app/base/saves/{train_method}/{model_name}" + # 检查路径是否存在(兼容Windows和Linux) if not os.path.exists(trained_model_path): - return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) - - # 构建 llamafactory-cli chat 命令 - work_dir = '/app/base' - llamafactory_dir = '/app/base' + adapter_path = os.path.join(trained_model_path, 'adapter_model.bin') + safetensors_path = os.path.join(trained_model_path, 'model.safetensors') + if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path): + return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) # 准备消息 messages = [] @@ -276,55 +497,76 @@ def chat_trained_model(): messages.append({'role': 'system', 'content': system_prompt}) messages.append({'role': 'user', 'content': user_question}) - # 将消息转换为 JSON 字符串 - messages_json = json.dumps(messages, ensure_ascii=False) + # 构建 llamafactory 推理脚本 + inference_script = f''' +import sys +import json +import logging +logging.basicConfig(level=logging.WARNING) - # 构建 llamafactory-cli chat 命令 - full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}' +from llmtuner import ChatModel + +def main(): + chat_model = ChatModel({{ + "model_name_or_path": "{base_model_path}", + "adapter_name_or_path": "{trained_model_path}", + "template": "llama3", + "finetuning_type": "lora", + "temperature": {temperature}, + "max_new_tokens": {max_tokens} + }}) + + messages = {json.dumps(messages, ensure_ascii=False)} + + response = chat_model.chat(messages) + print(response) + +if __name__ == "__main__": + main() +''' + + # 写入临时脚本 + work_dir = '/app/base' + script_path = os.path.join(work_dir, 'temp_inference.py') try: - # 执行命令 + with open(script_path, 'w', encoding='utf-8') as f: + f.write(inference_script) + + # 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory + import sys as sys_module + env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} + for path in ['/app/base', '/app', '/app/base/src/llamafactory']: + if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')): + env['PYTHONPATH'] = path + break + + # 执行推理脚本 result = subprocess.run( - full_cmd, - shell=True, + [sys_module.executable, 'temp_inference.py'], capture_output=True, text=True, - timeout=120, - cwd=work_dir + timeout=180, + cwd=work_dir, + env=env ) - output = result.stdout - error = result.stderr + # 清理临时脚本 + os.remove(script_path) - # 解析输出,提取assistant回复 - # llamafactory-cli chat 输出格式通常是: - # <|im_start|>assistant - # xxx - # <|im_end|> - - assistant_content = '' - if '<|im_start|>assistant' in output: - parts = output.split('<|im_start|>assistant') - if len(parts) > 1: - content_part = parts[1].split('<|im_end|>')[0].strip() - # 移除可能存在的换行前缀 - content_part = content_part.lstrip('\n').strip() - assistant_content = content_part - elif result.returncode == 0: - # 如果没有特殊标记,尝试提取最后一部分作为回复 - lines = output.strip().split('\n') - assistant_content = '\n'.join(lines).strip() + if result.returncode == 0: + assistant_content = result.stdout.strip() + return jsonify({ + 'code': 0, + 'data': { + 'model_name': model_name, + 'train_method': train_method, + 'response': assistant_content + } + }) else: - return jsonify({'code': 1, 'message': f'推理失败: {error or output}'}) - - return jsonify({ - 'code': 0, - 'data': { - 'model_name': model_name, - 'train_method': train_method, - 'response': assistant_content - } - }) + error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() + return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'}) except subprocess.TimeoutExpired: return jsonify({'code': 1, 'message': '推理超时,请稍后重试'}) diff --git a/src/api/model_manage.py b/src/api/model_manage.py index b9532cd..0c21264 100644 --- a/src/api/model_manage.py +++ b/src/api/model_manage.py @@ -47,24 +47,32 @@ def generic_get_all(table_name, order_by='create_time DESC'): def get_model_path_by_name(model_name): """根据模型名称查询模型路径(用于获取基座模型路径)""" + import logging + logger = logging.getLogger(__name__) + logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}") + try: conn = get_db_connection() cursor = conn.cursor() # 优先从训练任务表查询基座模型 + logger.info(f"[DEBUG get_model_path_by_name] 尝试从fine_tune表查询...") cursor.execute(""" - SELECT base_model FROM fine_tune + SELECT base_model, output_model_name FROM fine_tune WHERE output_model_name LIKE %s OR output_model_name LIKE %s LIMIT 1 """, (f'%/{model_name}', f'%{model_name}%')) ft_result = cursor.fetchone() + logger.info(f"[DEBUG get_model_path_by_name] fine_tune查询结果: {ft_result}") if ft_result and ft_result.get('base_model'): base_model_val = ft_result['base_model'] + logger.info(f"[DEBUG get_model_path_by_name] base_model_val: {base_model_val}") # 如果是数字ID,查询模型管理表获取路径 if str(base_model_val).isdigit(): cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) model_result = cursor.fetchone() + logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果(数字ID): {model_result}") if model_result: cursor.close() conn.close() @@ -76,12 +84,15 @@ def get_model_path_by_name(model_name): return base_model_val # 如果训练任务表没找到,尝试从模型管理表按名称查询 + logger.info(f"[DEBUG get_model_path_by_name] 尝试从model_manage表查询...") cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) result = cursor.fetchone() + logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果: {result}") cursor.close() conn.close() if result: return result.get('path') + logger.info(f"[DEBUG get_model_path_by_name] 未找到任何匹配,返回None") return None except Exception as e: logger.error(f"[ERROR] 查询模型路径失败: {e}") @@ -377,6 +388,16 @@ def get_trained_models(): logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型") + # 检查每个模型是否已合并或正在合并 + local_trained_path = os.path.join(PROJECT_ROOT, 'local_trained_models') + for model in models: + model_name = model['name'] + merged_path = os.path.join(local_trained_path, model_name) + lock_file = os.path.join(local_trained_path, f'.merging_{model_name}.lock') + model['merged'] = os.path.exists(merged_path) + model['merging'] = os.path.exists(lock_file) + logger.info(f"[DEBUG] 模型 {model_name} 已合并: {model['merged']}, 正在合并: {model['merging']}") + return jsonify({ 'code': 0, 'data': { @@ -387,3 +408,264 @@ def get_trained_models(): except Exception as e: logger.error(f"获取已训练模型列表失败: {e}") return jsonify({'code': 1, 'message': str(e)}) + + +# ============ 合并权重接口 ============ + +@model_manage_bp.route('/merge', methods=['POST']) +def merge_model(): + """合并模型权重(将LoRA适配器合并到基座模型)""" + import subprocess + import sys + import logging + logger = logging.getLogger(__name__) + + data = request.json + model_name = data.get('model_name') # 模型名称 + train_method = data.get('train_method', 'lora') # 训练方法 + base_model_path = data.get('base_model_path') # 基座模型路径 + + if not model_name: + return jsonify({'code': 1, 'message': '缺少模型名称'}) + + logger.info(f"[MERGE] 开始合并模型: {model_name}, 方法: {train_method}") + + # 如果没有提供基座模型路径,从数据库查询 + if not base_model_path: + try: + conn = get_db_connection() + cursor = conn.cursor() + + # 优先从训练任务表查询 + cursor.execute(""" + SELECT base_model FROM fine_tune + WHERE output_model_name LIKE %s OR output_model_name LIKE %s + LIMIT 1 + """, (f'%/{model_name}', f'%{model_name}%')) + ft_result = cursor.fetchone() + + if ft_result and ft_result.get('base_model'): + base_model_val = ft_result['base_model'] + if str(base_model_val).isdigit(): + cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) + model_result = cursor.fetchone() + if model_result: + base_model_path = model_result.get('path') + else: + base_model_path = base_model_val + + # 如果没找到,尝试从模型管理表按名称查询 + if not base_model_path: + cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) + model_result = cursor.fetchone() + if model_result: + base_model_path = model_result.get('path') + + conn.close() + + if not base_model_path: + return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) + except Exception as e: + logger.error(f"[MERGE] 查询模型配置失败: {e}") + return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) + + # 训练后的模型路径(LoRA适配器) + adapter_path = f"/app/base/saves/{train_method}/{model_name}" + + # 检查路径是否存在 + if not os.path.exists(adapter_path): + return jsonify({'code': 1, 'message': f'训练模型不存在: {adapter_path}'}) + + # 合并后的输出路径 + output_path = f"/app/base/local_trained_models/{model_name}" + + # 合并状态锁文件 + lock_file = f"/app/base/local_trained_models/.merging_{model_name}.lock" + + # 创建输出目录 + os.makedirs(output_path, exist_ok=True) + + # 创建锁文件表示正在合并中 + try: + with open(lock_file, 'w') as f: + f.write('merging') + + work_dir = '/app/base' + + # 设置环境变量 + env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} + + # 使用 llamafactory-cli export 命令(假设已在系统 PATH 中,与训练命令一致) + cli_cmd = ['llamafactory-cli', 'export'] + + # 检查 llamafactory-cli 是否存在 + try: + # 尝试使用 which 命令(Linux/Mac) + subprocess.run(['which', 'llamafactory-cli'], capture_output=True, check=True) + except (subprocess.CalledProcessError, FileNotFoundError): + # Windows 上没有 which 命令,直接尝试执行 + logger.info("[MERGE] which 命令不可用,直接尝试执行 llamafactory-cli") + + # 构建完整命令参数 + export_args = [ + '--model_name_or_path', base_model_path, + '--adapter_name_or_path', adapter_path, + '--export_dir', output_path + ] + + logger.info(f"[MERGE] 执行合并命令: {' '.join(cli_cmd)} {' '.join(export_args)}") + + # 直接执行 llamafactory-cli export 命令 + result = subprocess.run( + cli_cmd + export_args, + capture_output=True, + text=True, + timeout=600, + cwd=work_dir or '/app/base', + env=env + ) + + logger.info(f"[MERGE] 命令返回码: {result.returncode}") + logger.info(f"[MERGE] stdout: {result.stdout[:500] if result.stdout else 'empty'}") + logger.info(f"[MERGE] stderr: {result.stderr[:500] if result.stderr else 'empty'}") + + # 等待输出目录完全创建 + import time + max_wait = 5 # 最多等待5秒 + waited = 0 + while not os.path.exists(output_path) and waited < max_wait: + time.sleep(0.5) + waited += 0.5 + + # 无论成功失败,都删除锁文件 + if os.path.exists(lock_file): + os.remove(lock_file) + + if result.returncode == 0: + # 确保目录存在才返回成功 + if os.path.exists(output_path): + return jsonify({ + 'code': 0, + 'message': f'模型权重已成功合并到 {output_path}', + 'data': { + 'model_name': model_name, + 'output_path': output_path + } + }) + else: + return jsonify({'code': 1, 'message': '合并失败:输出目录未创建'}) + else: + error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() + if not error_msg: + error_msg = f'命令执行失败,返回码: {result.returncode}' + return jsonify({'code': 1, 'message': f'合并失败: {error_msg}'}) + + except subprocess.TimeoutExpired: + logger.error("[MERGE] 合并超时") + # 删除锁文件 + if os.path.exists(lock_file): + os.remove(lock_file) + return jsonify({'code': 1, 'message': '合并超时,请稍后重试'}) + except Exception as e: + logger.error(f"[MERGE] 合并异常: {str(e)}") + return jsonify({'code': 1, 'message': f'合并异常: {str(e)}'}) + + +# ============ 删除已训练模型接口 ============ + +@model_manage_bp.route('/trained-models/', methods=['DELETE']) +def delete_trained_model(model_name): + """删除已训练模型(从local_trained_models目录)""" + import shutil + import logging + logger = logging.getLogger(__name__) + + try: + # 删除 local_trained_models 目录下的模型 + model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name) + + if not os.path.exists(model_path): + return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'}) + + # 删除目录 + shutil.rmtree(model_path) + logger.info(f"[DELETE] 已删除模型: {model_path}") + + return jsonify({'code': 0, 'message': '删除成功'}) + except Exception as e: + logger.error(f"[DELETE] 删除模型失败: {str(e)}") + return jsonify({'code': 1, 'message': f'删除失败: {str(e)}'}) + + +# ============ 导出已训练模型接口 ============ + +@model_manage_bp.route('/trained-models//export', methods=['GET']) +def export_trained_model(model_name): + """导出已训练模型(打包成zip下载)""" + import shutil + import logging + from flask import send_file + logger = logging.getLogger(__name__) + + try: + # 优先从 local_trained_models 目录查找(合并后的模型) + model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name) + + # 如果本地模型目录不存在,尝试从 saves 目录查找(未合并的模型) + if not os.path.exists(model_path): + # 查找 saves 目录下的模型 + saves_path = os.path.join(PROJECT_ROOT, 'saves') + train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft'] + + for method in train_methods: + potential_path = os.path.join(saves_path, method, model_name) + if os.path.exists(potential_path): + model_path = potential_path + logger.info(f"[EXPORT] 从 saves/{method} 目录找到模型: {model_path}") + break + + # 如果还是找不到,返回错误 + if not os.path.exists(model_path): + return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'}) + + # 创建临时 zip 文件 + zip_path = os.path.join(PROJECT_ROOT, 'temp_exports') + os.makedirs(zip_path, exist_ok=True) + + zip_file = os.path.join(zip_path, f'{model_name}.zip') + + # 如果已存在先删除 + if os.path.exists(zip_file): + os.remove(zip_file) + + # 打包成 zip + shutil.make_archive(zip_file[:-4], 'zip', model_path) + logger.info(f"[EXPORT] 已打包模型: {zip_file}") + + # 发送文件给前端 + response = send_file( + zip_file, + as_attachment=True, + download_name=f'{model_name}.zip', + mimetype='application/zip' + ) + + # 注册回调,删除临时文件 + def cleanup(): + try: + if os.path.exists(zip_file): + os.remove(zip_file) + logger.info(f"[EXPORT] 已清理临时文件: {zip_file}") + except: + pass + + # 使用 after_request 清理 + @response.call_on_close + def cleanup_after_request(): + cleanup() + + return response + + except Exception as e: + logger.error(f"[EXPORT] 导出模型失败: {str(e)}") + return jsonify({'code': 1, 'message': f'导出失败: {str(e)}'}) diff --git a/web/pages/dataset-preview.html b/web/pages/dataset-preview.html index 9f1c483..91ff693 100644 --- a/web/pages/dataset-preview.html +++ b/web/pages/dataset-preview.html @@ -223,6 +223,11 @@ + 合并权重 - YG_FT + + - - - -