""" 模型对话 API 路由 """ import os import pymysql import yaml import json import requests import concurrent.futures import subprocess from flask import Blueprint, request, jsonify # 获取项目根目录 PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 创建蓝图 model_chat_bp = Blueprint('model_chat', __name__, url_prefix='/api/model-chat') def get_db_connection(): """获取数据库连接""" CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') with open(CONFIG_PATH, 'r', encoding='utf-8') as f: CONFIG = yaml.safe_load(f) db_config = CONFIG['database'] return pymysql.connect( host=db_config['host'], port=db_config['port'], user=db_config['username'], password=db_config['password'], database=db_config['name'], charset=db_config.get('charset', 'utf8mb4'), cursorclass=pymysql.cursors.DictCursor ) def generic_get_by_id(table_name, id_val): """按ID查询""" conn = get_db_connection() cursor = conn.cursor() cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,)) result = cursor.fetchone() cursor.close() conn.close() return result def call_api_model(model_config, messages, temperature, max_tokens): """调用API模型(OpenAI兼容格式)""" api_url = model_config.get('api_url') api_key = model_config.get('api_key') model_name = model_config.get('model_name', '') # 构造OpenAI兼容的完整URL # 支持: https://api.openai.com/v1/chat/completions 或 https://api.example.com/v1 # 如果URL已经包含 /chat/completions 则直接使用,否则追加 if '/chat/completions' in api_url: full_url = api_url else: # 去掉末尾的斜杠,然后追加 /chat/completions full_url = api_url.rstrip('/') + '/chat/completions' headers = { 'Content-Type': 'application/json', 'Authorization': f'Bearer {api_key}' } payload = { 'model': model_name, 'messages': messages, 'temperature': temperature, 'max_tokens': max_tokens } try: response = requests.post(full_url, headers=headers, json=payload, timeout=120) response.raise_for_status() result = response.json() if 'choices' in result and len(result['choices']) > 0: return { 'success': True, 'content': result['choices'][0]['message'].get('content', '') } return {'success': False, 'error': 'API返回格式异常'} except requests.exceptions.RequestException as e: return {'success': False, 'error': str(e)} def call_local_model(model_config, messages, temperature, max_tokens): """调用本地模型(通过vLLM OpenAI兼容API)""" api_url = model_config.get('path') # 本地模型path字段存储API地址 model_name = model_config.get('model_name', '') if not api_url: return {'success': False, 'error': '本地模型API地址未配置'} headers = {'Content-Type': 'application/json'} payload = { 'model': model_name, 'messages': messages, 'temperature': temperature, 'max_tokens': max_tokens } try: response = requests.post(api_url, headers=headers, json=payload, timeout=120) response.raise_for_status() result = response.json() if 'choices' in result and len(result['choices']) > 0: return { 'success': True, 'content': result['choices'][0]['message'].get('content', '') } return {'success': False, 'error': 'API返回格式异常'} except requests.exceptions.RequestException as e: return {'success': False, 'error': str(e)} @model_chat_bp.route('', methods=['POST']) def model_chat(): """模型对话接口""" data = request.json model_id = data.get('model_id') system_prompt = data.get('system_prompt', '') user_question = data.get('user_question') temperature = data.get('temperature', 0.7) max_tokens = data.get('max_tokens', 2048) if not model_id: return jsonify({'code': 1, 'message': '缺少模型ID'}) if not user_question: return jsonify({'code': 1, 'message': '缺少用户提问'}) # 获取模型配置 model = generic_get_by_id('model_manage', model_id) if not model: return jsonify({'code': 1, 'message': '模型不存在'}) # 构建消息 messages = [] if system_prompt: messages.append({'role': 'system', 'content': system_prompt}) messages.append({'role': 'user', 'content': user_question}) # 根据模型类型调用 if model.get('model_source') == 'api': result = call_api_model(model, messages, temperature, max_tokens) else: result = call_local_model(model, messages, temperature, max_tokens) if result.get('success'): return jsonify({ 'code': 0, 'data': { 'model_id': model_id, 'model_name': model.get('name'), 'response': result['content'] } }) else: return jsonify({'code': 1, 'message': result.get('error', '调用失败')}) @model_chat_bp.route('/batch', methods=['POST']) def model_chat_batch(): """批量模型对话接口(并发调用多个模型)""" data = request.json model_ids = data.get('model_ids', []) system_prompt = data.get('system_prompt', '') user_question = data.get('user_question') temperature = data.get('temperature', 0.7) max_tokens = data.get('max_tokens', 2048) if not model_ids: return jsonify({'code': 1, 'message': '缺少模型ID列表'}) if not user_question: return jsonify({'code': 1, 'message': '缺少用户提问'}) def call_single_model(model_id): model = generic_get_by_id('model_manage', model_id) if not model: return {'model_id': model_id, 'success': False, 'error': '模型不存在'} messages = [] if system_prompt: messages.append({'role': 'system', 'content': system_prompt}) messages.append({'role': 'user', 'content': user_question}) if model.get('model_source') == 'api': result = call_api_model(model, messages, temperature, max_tokens) else: result = call_local_model(model, messages, temperature, max_tokens) return { 'model_id': model_id, 'model_name': model.get('name'), 'success': result.get('success', False), 'response': result.get('content', ''), 'error': result.get('error', '') } # 并发调用所有模型 results = [] with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(model_ids), 4)) as executor: future_to_model = {executor.submit(call_single_model, mid): mid for mid in model_ids} for future in concurrent.futures.as_completed(future_to_model): results.append(future.result()) return jsonify({'code': 0, 'data': results}) @model_chat_bp.route('/trained/preload', methods=['POST']) def preload_trained_model(): """预加载已训练模型(使用 llamafactory)""" import sys as sys_module import pymysql import yaml import logging logger = logging.getLogger(__name__) data = request.json model_name = data.get('model_name') # 模型名称 train_method = data.get('train_method', 'lora') # 训练方法: lora, full base_model_path = data.get('base_model_path') # 前端传递的基座模型路径 if not model_name: return jsonify({'code': 1, 'message': '缺少模型名称'}) logger.info(f"[PRELOAD] 开始预加载模型: {model_name}, 方法: {train_method}") logger.info(f"[PRELOAD] 前端传递的基座模型路径: {base_model_path}") # 获取项目根目录 PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') try: with open(CONFIG_PATH, 'r', encoding='utf-8') as f: CONFIG = yaml.safe_load(f) except Exception as e: return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'}) # 优先使用前端传递的基座模型路径,否则从数据库查询 if not base_model_path: try: db_config = CONFIG['database'] conn = pymysql.connect( host=db_config['host'], port=db_config['port'], user=db_config['username'], password=db_config['password'], database=db_config['name'], charset=db_config.get('charset', 'utf8mb4'), cursorclass=pymysql.cursors.DictCursor ) cursor = conn.cursor() # 优先从训练任务表查询基座模型 logger.info(f"[PRELOAD] 尝试从fine_tune表查询模型: {model_name}") cursor.execute(""" SELECT base_model, output_model_name FROM fine_tune WHERE output_model_name LIKE %s OR output_model_name LIKE %s LIMIT 1 """, (f'%/{model_name}', f'%{model_name}%')) ft_result = cursor.fetchone() logger.info(f"[PRELOAD] fine_tune查询结果: {ft_result}") if ft_result and ft_result.get('base_model'): base_model_val = ft_result['base_model'] logger.info(f"[PRELOAD] base_model_val: {base_model_val}") # 如果是数字ID,查询模型管理表获取路径 if str(base_model_val).isdigit(): cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) model_result = cursor.fetchone() logger.info(f"[PRELOAD] model_manage查询结果(数字ID): {model_result}") if model_result: base_model_path = model_result.get('path') else: # 直接是路径 base_model_path = base_model_val # 如果训练任务表没找到,尝试从模型管理表按名称查询 if not base_model_path: logger.info(f"[PRELOAD] 尝试从model_manage表查询...") cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) model_result = cursor.fetchone() logger.info(f"[PRELOAD] model_manage查询结果: {model_result}") if model_result: base_model_path = model_result.get('path') conn.close() if not base_model_path: logger.error(f"[PRELOAD] 未找到模型 {model_name} 的基座模型配置") return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) except Exception as e: logger.error(f"[PRELOAD] 查询模型配置失败: {e}") return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) else: logger.info(f"[PRELOAD] 使用前端传递的基座模型路径: {base_model_path}") # 训练后的模型路径 trained_model_path = f"/app/base/saves/{train_method}/{model_name}" # 检查路径是否存在(兼容Windows和Linux) if not os.path.exists(trained_model_path): logger.warning(f"[PRELOAD] 训练模型路径不存在: {trained_model_path}") # 尝试查找适配器文件来确定模型是否存在 adapter_path = os.path.join(trained_model_path, 'adapter_model.bin') safetensors_path = os.path.join(trained_model_path, 'model.safetensors') if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path): return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) # 预热消息 - 使用一个简单的问候语来加载模型 work_dir = '/app/base' warmup_messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello'}] # 根据训练方法选择 finetuning_type finetuning_type = 'lora' if train_method == 'lora' else 'full' # 构建 llamafactory 预热脚本 preload_script = f''' import sys import json import logging logging.basicConfig(level=logging.WARNING) from llmtuner import ChatModel def main(): chat_model = ChatModel({{ "model_name_or_path": "{base_model_path}", "adapter_name_or_path": "{trained_model_path}", "template": "llama3", "finetuning_type": "{finetuning_type}", "temperature": 0.1, "max_new_tokens": 1 }}) messages = {json.dumps(warmup_messages, ensure_ascii=False)} # 执行推理以加载模型 response = chat_model.chat(messages) print("Model loaded successfully") if __name__ == "__main__": main() ''' try: work_dir = '/app/base' script_path = os.path.join(work_dir, 'temp_preload.py') with open(script_path, 'w', encoding='utf-8') as f: f.write(preload_script) # 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory import sys as sys_module env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} # 尝试查找 llamafactory 目录并添加到 PYTHONPATH llm_factory_paths = ['/app/base', '/app', '/app/base/src/llamafactory'] for path in llm_factory_paths: if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')): env['PYTHONPATH'] = path logger.info(f"[PRELOAD] 设置 PYTHONPATH={path}") break logger.info(f"[PRELOAD] 执行预热脚本...") result = subprocess.run( [sys_module.executable, 'temp_preload.py'], capture_output=True, text=True, timeout=180, cwd=work_dir, env=env ) os.remove(script_path) logger.info(f"[PRELOAD] 命令返回码: {result.returncode}") logger.info(f"[PRELOAD] stdout: {result.stdout[:500] if result.stdout else 'empty'}") logger.info(f"[PRELOAD] stderr: {result.stderr[:500] if result.stderr else 'empty'}") if result.returncode == 0: return jsonify({ 'code': 0, 'message': '模型预加载成功', 'data': { 'model_name': model_name, 'train_method': train_method, 'base_model': base_model_path } }) else: error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() if not error_msg: error_msg = f'命令执行失败,返回码: {result.returncode}' return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'}) except subprocess.TimeoutExpired: logger.error("[PRELOAD] 预加载超时") return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'}) except Exception as e: logger.error(f"[PRELOAD] 预加载异常: {str(e)}") return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'}) @model_chat_bp.route('/trained', methods=['POST']) def chat_trained_model(): """使用已训练模型进行对话推理""" import pymysql import yaml data = request.json model_name = data.get('model_name') # 模型名称 train_method = data.get('train_method', 'lora') # 训练方法: lora, full base_model_path = data.get('base_model_path') # 前端传递的基座模型路径 system_prompt = data.get('system_prompt', '') user_question = data.get('user_question') temperature = data.get('temperature', 0.7) max_tokens = data.get('max_tokens', 2048) if not model_name: return jsonify({'code': 1, 'message': '缺少模型名称'}) if not user_question: return jsonify({'code': 1, 'message': '缺少用户提问'}) # 获取项目根目录 PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') with open(CONFIG_PATH, 'r', encoding='utf-8') as f: CONFIG = yaml.safe_load(f) # 优先使用前端传递的基座模型路径,否则从数据库查询 if not base_model_path: try: db_config = CONFIG['database'] conn = pymysql.connect( host=db_config['host'], port=db_config['port'], user=db_config['username'], password=db_config['password'], database=db_config['name'], charset=db_config.get('charset', 'utf8mb4'), cursorclass=pymysql.cursors.DictCursor ) cursor = conn.cursor() # 优先从训练任务表查询基座模型 cursor.execute(""" SELECT base_model, output_model_name FROM fine_tune WHERE output_model_name LIKE %s OR output_model_name LIKE %s LIMIT 1 """, (f'%/{model_name}', f'%{model_name}%')) ft_result = cursor.fetchone() if ft_result and ft_result.get('base_model'): base_model_val = ft_result['base_model'] # 如果是数字ID,查询模型管理表获取路径 if str(base_model_val).isdigit(): cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) model_result = cursor.fetchone() if model_result: base_model_path = model_result.get('path') else: # 直接是路径 base_model_path = base_model_val # 如果训练任务表没找到,尝试从模型管理表按名称查询 if not base_model_path: cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) model_result = cursor.fetchone() if model_result: base_model_path = model_result.get('path') conn.close() if not base_model_path: return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) except Exception as e: return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) # 训练后的模型路径 trained_model_path = f"/app/base/saves/{train_method}/{model_name}" # 检查路径是否存在(兼容Windows和Linux) if not os.path.exists(trained_model_path): adapter_path = os.path.join(trained_model_path, 'adapter_model.bin') safetensors_path = os.path.join(trained_model_path, 'model.safetensors') if not os.path.exists(adapter_path) and not os.path.exists(safetensors_path): return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) # 准备消息 messages = [] if system_prompt: messages.append({'role': 'system', 'content': system_prompt}) messages.append({'role': 'user', 'content': user_question}) # 构建 llamafactory 推理脚本 inference_script = f''' import sys import json import logging logging.basicConfig(level=logging.WARNING) from llmtuner import ChatModel def main(): chat_model = ChatModel({{ "model_name_or_path": "{base_model_path}", "adapter_name_or_path": "{trained_model_path}", "template": "llama3", "finetuning_type": "lora", "temperature": {temperature}, "max_new_tokens": {max_tokens} }}) messages = {json.dumps(messages, ensure_ascii=False)} response = chat_model.chat(messages) print(response) if __name__ == "__main__": main() ''' # 写入临时脚本 work_dir = '/app/base' script_path = os.path.join(work_dir, 'temp_inference.py') try: with open(script_path, 'w', encoding='utf-8') as f: f.write(inference_script) # 设置环境变量,包括 PYTHONPATH 以便找到 llamafactory import sys as sys_module env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} for path in ['/app/base', '/app', '/app/base/src/llamafactory']: if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')): env['PYTHONPATH'] = path break # 执行推理脚本 result = subprocess.run( [sys_module.executable, 'temp_inference.py'], capture_output=True, text=True, timeout=180, cwd=work_dir, env=env ) # 清理临时脚本 os.remove(script_path) if result.returncode == 0: assistant_content = result.stdout.strip() return jsonify({ 'code': 0, 'data': { 'model_name': model_name, 'train_method': train_method, 'response': assistant_content } }) else: error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'}) except subprocess.TimeoutExpired: return jsonify({'code': 1, 'message': '推理超时,请稍后重试'}) except Exception as e: return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})