From 0f98d67e41590117ad5ab3f8ee9604948d15f770 Mon Sep 17 00:00:00 2001 From: "WIN-JHFT4D3SIVT\\caoxiaozhu" Date: Thu, 29 Jan 2026 17:39:06 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=BF=98=E4=BA=86=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E6=8C=89=E9=92=AE=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/model_chat.py | 122 +++++++++++++++++++++++ src/api/model_manage.py | 153 ++++++++++++++++++++++++----- web/pages/model-manage-create.html | 1 + 3 files changed, 254 insertions(+), 22 deletions(-) diff --git a/src/api/model_chat.py b/src/api/model_chat.py index 7b20f04..fba96e3 100644 --- a/src/api/model_chat.py +++ b/src/api/model_chat.py @@ -6,6 +6,7 @@ import pymysql import yaml import requests import concurrent.futures +import subprocess from flask import Blueprint, request, jsonify # 获取项目根目录 @@ -208,3 +209,124 @@ def model_chat_batch(): results.append(future.result()) return jsonify({'code': 0, 'data': results}) + + +@model_chat_bp.route('/trained', methods=['POST']) +def chat_trained_model(): + """使用已训练模型进行对话推理""" + import pymysql + import yaml + + data = request.json + model_name = data.get('model_name') # 模型名称 + train_method = data.get('train_method', 'lora') # 训练方法: lora, full + system_prompt = data.get('system_prompt', '') + user_question = data.get('user_question') + temperature = data.get('temperature', 0.7) + max_tokens = data.get('max_tokens', 2048) + + if not model_name: + return jsonify({'code': 1, 'message': '缺少模型名称'}) + if not user_question: + return jsonify({'code': 1, 'message': '缺少用户提问'}) + + # 获取项目根目录 + PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') + with open(CONFIG_PATH, 'r', encoding='utf-8') as f: + CONFIG = yaml.safe_load(f) + + # 获取基座模型路径 - 从数据库查询 model_name 对应的路径 + try: + db_config = CONFIG['database'] + conn = pymysql.connect( + host=db_config['host'], + port=db_config['port'], + user=db_config['username'], + password=db_config['password'], + database=db_config['name'], + charset=db_config.get('charset', 'utf8mb4'), + cursorclass=pymysql.cursors.DictCursor + ) + cursor = conn.cursor() + cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) + model_result = cursor.fetchone() + conn.close() + + if not model_result or not model_result.get('path'): + return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'}) + + base_model_path = model_result['path'] + except Exception as e: + return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) + + # 训练后的模型路径 + trained_model_path = f"/app/base/saves/{train_method}/{model_name}" + + if not os.path.exists(trained_model_path): + return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) + + # 构建 llamafactory-cli chat 命令 + work_dir = '/app/base' + llamafactory_dir = '/app/base' + + # 准备消息 + messages = [] + if system_prompt: + messages.append({'role': 'system', 'content': system_prompt}) + messages.append({'role': 'user', 'content': user_question}) + + # 将消息转换为 JSON 字符串 + messages_json = json.dumps(messages, ensure_ascii=False) + + # 构建 llamafactory-cli chat 命令 + full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}' + + try: + # 执行命令 + result = subprocess.run( + full_cmd, + shell=True, + capture_output=True, + text=True, + timeout=120, + cwd=work_dir + ) + + output = result.stdout + error = result.stderr + + # 解析输出,提取assistant回复 + # llamafactory-cli chat 输出格式通常是: + # <|im_start|>assistant + # xxx + # <|im_end|> + + assistant_content = '' + if '<|im_start|>assistant' in output: + parts = output.split('<|im_start|>assistant') + if len(parts) > 1: + content_part = parts[1].split('<|im_end|>')[0].strip() + # 移除可能存在的换行前缀 + content_part = content_part.lstrip('\n').strip() + assistant_content = content_part + elif result.returncode == 0: + # 如果没有特殊标记,尝试提取最后一部分作为回复 + lines = output.strip().split('\n') + assistant_content = '\n'.join(lines).strip() + else: + return jsonify({'code': 1, 'message': f'推理失败: {error or output}'}) + + return jsonify({ + 'code': 0, + 'data': { + 'model_name': model_name, + 'train_method': train_method, + 'response': assistant_content + } + }) + + except subprocess.TimeoutExpired: + return jsonify({'code': 1, 'message': '推理超时,请稍后重试'}) + except Exception as e: + return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'}) diff --git a/src/api/model_manage.py b/src/api/model_manage.py index 3d7fb1e..b9532cd 100644 --- a/src/api/model_manage.py +++ b/src/api/model_manage.py @@ -45,6 +45,49 @@ def generic_get_all(table_name, order_by='create_time DESC'): return result +def get_model_path_by_name(model_name): + """根据模型名称查询模型路径(用于获取基座模型路径)""" + try: + conn = get_db_connection() + cursor = conn.cursor() + + # 优先从训练任务表查询基座模型 + cursor.execute(""" + SELECT base_model FROM fine_tune + WHERE output_model_name LIKE %s OR output_model_name LIKE %s + LIMIT 1 + """, (f'%/{model_name}', f'%{model_name}%')) + ft_result = cursor.fetchone() + + if ft_result and ft_result.get('base_model'): + base_model_val = ft_result['base_model'] + # 如果是数字ID,查询模型管理表获取路径 + if str(base_model_val).isdigit(): + cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) + model_result = cursor.fetchone() + if model_result: + cursor.close() + conn.close() + return model_result.get('path') + else: + # 直接是路径 + cursor.close() + conn.close() + return base_model_val + + # 如果训练任务表没找到,尝试从模型管理表按名称查询 + cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) + result = cursor.fetchone() + cursor.close() + conn.close() + if result: + return result.get('path') + return None + except Exception as e: + logger.error(f"[ERROR] 查询模型路径失败: {e}") + return None + + def generic_create(table_name, data): """通用创建""" conn = get_db_connection() @@ -226,42 +269,108 @@ def get_trained_models(): try: # 路径结构: /app/base/saves/{train_method}/{model_name}/ # train_method: lora, full, qlora, dpo, cpt 等 + # 同时兼容老结构: /app/base/saves/{model_name}/ - for train_method in os.listdir(base_path): - train_method_path = os.path.join(base_path, train_method) - if not os.path.isdir(train_method_path): + train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft'] + + for item in os.listdir(base_path): + item_path = os.path.join(base_path, item) + if not os.path.isdir(item_path): continue - logger.info(f"[DEBUG] 检查训练方法目录: {train_method}") - model_count = 0 + # 情况1: 新结构 {train_method}/{model_name} + if item in train_methods: + logger.info(f"[DEBUG] 检查训练方法目录: {item}") + model_count = 0 - # 遍历模型文件夹 - for model_name in os.listdir(train_method_path): - model_path = os.path.join(train_method_path, model_name) - if not os.path.isdir(model_path): - continue + for model_name in os.listdir(item_path): + model_path = os.path.join(item_path, model_name) + if not os.path.isdir(model_path): + continue - # 检查是否有模型文件 + try: + files = os.listdir(model_path) + has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files) + + if has_model: + logger.info(f"[DEBUG] 找到模型: {item}/{model_name}") + # 获取文件创建时间 + try: + import time + stat = os.stat(model_path) + create_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime)) + except: + create_time = None + + # 查询基座模型路径 + base_model_path = get_model_path_by_name(model_name) + + models.append({ + 'name': model_name, + 'path': model_path, + 'base_model_path': base_model_path, + 'create_time': create_time, + 'train_methods': [{ + 'name': item, + 'path': model_path + }] + }) + model_count += 1 + except Exception as file_err: + logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}") + + logger.info(f"[DEBUG] {item} 找到 {model_count} 个模型") + + # 情况2: 老结构 {model_name} 直接在 saves 下 + else: + logger.info(f"[DEBUG] 检查老结构模型目录: {item}") try: - files = os.listdir(model_path) - logger.info(f"[DEBUG] {train_method}/{model_name} 文件: {files[:5]}...") + files = os.listdir(item_path) has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files) if has_model: - logger.info(f"[DEBUG] 找到模型: {train_method}/{model_name}") + logger.info(f"[DEBUG] 找到模型: {item}") + + # 尝试从 adapter_config.json 推断 train_method + inferred_method = 'lora' # 默认 + config_file = os.path.join(item_path, 'adapter_config.json') + if os.path.exists(config_file): + try: + import json + with open(config_file, 'r', encoding='utf-8') as f: + config = json.load(f) + if 'peft_type' in config: + peft_type = config['peft_type'].lower() + if 'lora' in peft_type: + inferred_method = 'lora' + elif 'full' in peft_type or 'pt' in peft_type: + inferred_method = 'full' + except: + pass + + # 获取文件创建时间 + try: + import time + stat = os.stat(item_path) + create_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime)) + except: + create_time = None + + # 查询基座模型路径 + base_model_path = get_model_path_by_name(item) + models.append({ - 'name': model_name, - 'path': model_path, + 'name': item, + 'path': item_path, + 'base_model_path': base_model_path, + 'create_time': create_time, 'train_methods': [{ - 'name': train_method, - 'path': model_path + 'name': inferred_method, + 'path': item_path }] }) - model_count += 1 except Exception as file_err: - logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}") - - logger.info(f"[DEBUG] {train_method} 找到 {model_count} 个模型") + logger.error(f"[DEBUG] 读取 {item_path} 失败: {file_err}") except Exception as list_err: logger.error(f"[DEBUG] 遍历目录失败: {list_err}") diff --git a/web/pages/model-manage-create.html b/web/pages/model-manage-create.html index b5020cf..c901cf9 100644 --- a/web/pages/model-manage-create.html +++ b/web/pages/model-manage-create.html @@ -675,6 +675,7 @@ function goBack() { window.location.href = backUrl; } + window.goBack = goBack; })();