""" 模型对话 API 路由 """ import os import pymysql import yaml import requests import concurrent.futures import subprocess from flask import Blueprint, request, jsonify # 获取项目根目录 PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 创建蓝图 model_chat_bp = Blueprint('model_chat', __name__, url_prefix='/api/model-chat') def get_db_connection(): """获取数据库连接""" CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') with open(CONFIG_PATH, 'r', encoding='utf-8') as f: CONFIG = yaml.safe_load(f) db_config = CONFIG['database'] return pymysql.connect( host=db_config['host'], port=db_config['port'], user=db_config['username'], password=db_config['password'], database=db_config['name'], charset=db_config.get('charset', 'utf8mb4'), cursorclass=pymysql.cursors.DictCursor ) def generic_get_by_id(table_name, id_val): """按ID查询""" conn = get_db_connection() cursor = conn.cursor() cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,)) result = cursor.fetchone() cursor.close() conn.close() return result def call_api_model(model_config, messages, temperature, max_tokens): """调用API模型(OpenAI兼容格式)""" api_url = model_config.get('api_url') api_key = model_config.get('api_key') model_name = model_config.get('model_name', '') # 构造OpenAI兼容的完整URL # 支持: https://api.openai.com/v1/chat/completions 或 https://api.example.com/v1 # 如果URL已经包含 /chat/completions 则直接使用,否则追加 if '/chat/completions' in api_url: full_url = api_url else: # 去掉末尾的斜杠,然后追加 /chat/completions full_url = api_url.rstrip('/') + '/chat/completions' headers = { 'Content-Type': 'application/json', 'Authorization': f'Bearer {api_key}' } payload = { 'model': model_name, 'messages': messages, 'temperature': temperature, 'max_tokens': max_tokens } try: response = requests.post(full_url, headers=headers, json=payload, timeout=120) response.raise_for_status() result = response.json() if 'choices' in result and len(result['choices']) > 0: return { 'success': True, 'content': result['choices'][0]['message'].get('content', '') } return {'success': False, 'error': 'API返回格式异常'} except requests.exceptions.RequestException as e: return {'success': False, 'error': str(e)} def call_local_model(model_config, messages, temperature, max_tokens): """调用本地模型(通过vLLM OpenAI兼容API)""" api_url = model_config.get('path') # 本地模型path字段存储API地址 model_name = model_config.get('model_name', '') if not api_url: return {'success': False, 'error': '本地模型API地址未配置'} headers = {'Content-Type': 'application/json'} payload = { 'model': model_name, 'messages': messages, 'temperature': temperature, 'max_tokens': max_tokens } try: response = requests.post(api_url, headers=headers, json=payload, timeout=120) response.raise_for_status() result = response.json() if 'choices' in result and len(result['choices']) > 0: return { 'success': True, 'content': result['choices'][0]['message'].get('content', '') } return {'success': False, 'error': 'API返回格式异常'} except requests.exceptions.RequestException as e: return {'success': False, 'error': str(e)} @model_chat_bp.route('', methods=['POST']) def model_chat(): """模型对话接口""" data = request.json model_id = data.get('model_id') system_prompt = data.get('system_prompt', '') user_question = data.get('user_question') temperature = data.get('temperature', 0.7) max_tokens = data.get('max_tokens', 2048) if not model_id: return jsonify({'code': 1, 'message': '缺少模型ID'}) if not user_question: return jsonify({'code': 1, 'message': '缺少用户提问'}) # 获取模型配置 model = generic_get_by_id('model_manage', model_id) if not model: return jsonify({'code': 1, 'message': '模型不存在'}) # 构建消息 messages = [] if system_prompt: messages.append({'role': 'system', 'content': system_prompt}) messages.append({'role': 'user', 'content': user_question}) # 根据模型类型调用 if model.get('model_source') == 'api': result = call_api_model(model, messages, temperature, max_tokens) else: result = call_local_model(model, messages, temperature, max_tokens) if result.get('success'): return jsonify({ 'code': 0, 'data': { 'model_id': model_id, 'model_name': model.get('name'), 'response': result['content'] } }) else: return jsonify({'code': 1, 'message': result.get('error', '调用失败')}) @model_chat_bp.route('/batch', methods=['POST']) def model_chat_batch(): """批量模型对话接口(并发调用多个模型)""" data = request.json model_ids = data.get('model_ids', []) system_prompt = data.get('system_prompt', '') user_question = data.get('user_question') temperature = data.get('temperature', 0.7) max_tokens = data.get('max_tokens', 2048) if not model_ids: return jsonify({'code': 1, 'message': '缺少模型ID列表'}) if not user_question: return jsonify({'code': 1, 'message': '缺少用户提问'}) def call_single_model(model_id): model = generic_get_by_id('model_manage', model_id) if not model: return {'model_id': model_id, 'success': False, 'error': '模型不存在'} messages = [] if system_prompt: messages.append({'role': 'system', 'content': system_prompt}) messages.append({'role': 'user', 'content': user_question}) if model.get('model_source') == 'api': result = call_api_model(model, messages, temperature, max_tokens) else: result = call_local_model(model, messages, temperature, max_tokens) return { 'model_id': model_id, 'model_name': model.get('name'), 'success': result.get('success', False), 'response': result.get('content', ''), 'error': result.get('error', '') } # 并发调用所有模型 results = [] with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(model_ids), 4)) as executor: future_to_model = {executor.submit(call_single_model, mid): mid for mid in model_ids} for future in concurrent.futures.as_completed(future_to_model): results.append(future.result()) return jsonify({'code': 0, 'data': results}) @model_chat_bp.route('/trained', methods=['POST']) def chat_trained_model(): """使用已训练模型进行对话推理""" import pymysql import yaml data = request.json model_name = data.get('model_name') # 模型名称 train_method = data.get('train_method', 'lora') # 训练方法: lora, full system_prompt = data.get('system_prompt', '') user_question = data.get('user_question') temperature = data.get('temperature', 0.7) max_tokens = data.get('max_tokens', 2048) if not model_name: return jsonify({'code': 1, 'message': '缺少模型名称'}) if not user_question: return jsonify({'code': 1, 'message': '缺少用户提问'}) # 获取项目根目录 PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') with open(CONFIG_PATH, 'r', encoding='utf-8') as f: CONFIG = yaml.safe_load(f) # 获取基座模型路径 - 从数据库查询 model_name 对应的路径 try: db_config = CONFIG['database'] conn = pymysql.connect( host=db_config['host'], port=db_config['port'], user=db_config['username'], password=db_config['password'], database=db_config['name'], charset=db_config.get('charset', 'utf8mb4'), cursorclass=pymysql.cursors.DictCursor ) cursor = conn.cursor() cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) model_result = cursor.fetchone() conn.close() if not model_result or not model_result.get('path'): return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'}) base_model_path = model_result['path'] except Exception as e: return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) # 训练后的模型路径 trained_model_path = f"/app/base/saves/{train_method}/{model_name}" if not os.path.exists(trained_model_path): return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'}) # 构建 llamafactory-cli chat 命令 work_dir = '/app/base' llamafactory_dir = '/app/base' # 准备消息 messages = [] if system_prompt: messages.append({'role': 'system', 'content': system_prompt}) messages.append({'role': 'user', 'content': user_question}) # 将消息转换为 JSON 字符串 messages_json = json.dumps(messages, ensure_ascii=False) # 构建 llamafactory-cli chat 命令 full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}' try: # 执行命令 result = subprocess.run( full_cmd, shell=True, capture_output=True, text=True, timeout=120, cwd=work_dir ) output = result.stdout error = result.stderr # 解析输出,提取assistant回复 # llamafactory-cli chat 输出格式通常是: # <|im_start|>assistant # xxx # <|im_end|> assistant_content = '' if '<|im_start|>assistant' in output: parts = output.split('<|im_start|>assistant') if len(parts) > 1: content_part = parts[1].split('<|im_end|>')[0].strip() # 移除可能存在的换行前缀 content_part = content_part.lstrip('\n').strip() assistant_content = content_part elif result.returncode == 0: # 如果没有特殊标记,尝试提取最后一部分作为回复 lines = output.strip().split('\n') assistant_content = '\n'.join(lines).strip() else: return jsonify({'code': 1, 'message': f'推理失败: {error or output}'}) return jsonify({ 'code': 0, 'data': { 'model_name': model_name, 'train_method': train_method, 'response': assistant_content } }) except subprocess.TimeoutExpired: return jsonify({'code': 1, 'message': '推理超时,请稍后重试'}) except Exception as e: return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})