"""Model chat API routes.

Defines a Flask blueprint mounted at ``/api/model-chat`` with two POST
endpoints: one that sends a question to a single model and one that sends
the same question to several models concurrently. Model records are read
from the ``model_manage`` table and are called through an OpenAI-compatible
chat-completions HTTP API.
"""

import concurrent.futures
import logging
import os

import pymysql
import requests
import yaml
from flask import Blueprint, jsonify, request

logger = logging.getLogger(__name__)

# Project root, obtained by applying ``dirname`` three times to this file's
# absolute path.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Blueprint holding the model-chat routes.
model_chat_bp = Blueprint('model_chat', __name__, url_prefix='/api/model-chat')

# Timeout, in seconds, applied to every outbound model request.
_REQUEST_TIMEOUT_SECONDS = 120

# Upper bound on the number of models called in parallel by the batch route.
_MAX_BATCH_WORKERS = 4

# Error text returned when a model response is not in the expected shape.
_FORMAT_ERROR_MESSAGE = 'API返回格式异常'


def get_db_connection():
    """Open a new database connection using settings from ``config.yaml``.

    The configuration file is read from ``PROJECT_ROOT`` on every call and its
    ``database`` section supplies host, port, username, password, name and an
    optional charset (default ``utf8mb4``).

    Returns:
        A ``pymysql`` connection whose cursors yield rows as dicts.
    """
    config_path = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    db_config = config['database']
    return pymysql.connect(
        host=db_config['host'],
        port=db_config['port'],
        user=db_config['username'],
        password=db_config['password'],
        database=db_config['name'],
        charset=db_config.get('charset', 'utf8mb4'),
        cursorclass=pymysql.cursors.DictCursor
    )


def generic_get_by_id(table_name, id_val):
    """Fetch a single row from ``table_name`` by its ``id`` column.

    Args:
        table_name: Name of the table to query. NOTE: this value is
            interpolated directly into the SQL text, so it must be a trusted
            constant and never user input.
        id_val: Value matched against the ``id`` column (passed as a bound
            parameter).

    Returns:
        The row returned by ``fetchone()``, or ``None`` when no row matches.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
            return cursor.fetchone()
        finally:
            cursor.close()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()


def _format_error():
    """Return a fresh failure result for an unexpected response shape."""
    return {'success': False, 'error': _FORMAT_ERROR_MESSAGE}


def _extract_chat_content(result):
    """Convert a decoded chat-completions response into a result dict.

    Args:
        result: The decoded JSON body of the model response.

    Returns:
        ``{'success': True, 'content': ...}`` when the first choice carries a
        message, otherwise a failure dict with the format-error message.
    """
    if not isinstance(result, dict):
        return _format_error()
    choices = result.get('choices')
    if not isinstance(choices, list) or not choices:
        return _format_error()
    first_choice = choices[0]
    message = first_choice.get('message') if isinstance(first_choice, dict) else None
    if not isinstance(message, dict):
        return _format_error()
    content = message.get('content', '')
    # A JSON null content is normalised to an empty string.
    return {'success': True, 'content': '' if content is None else content}


def _post_chat_completion(url, headers, payload):
    """POST a chat-completions payload and return the parsed result dict.

    Args:
        url: Full endpoint URL.
        headers: HTTP headers to send.
        payload: JSON-serialisable request body.

    Returns:
        A dict with ``success`` and either ``content`` or ``error``. Network
        errors, non-2xx statuses and undecodable bodies are reported as
        failures instead of being raised.
    """
    try:
        response = requests.post(
            url, headers=headers, json=payload, timeout=_REQUEST_TIMEOUT_SECONDS
        )
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        return {'success': False, 'error': str(e)}
    except ValueError:
        # Some requests versions raise a plain ValueError for a non-JSON body.
        return _format_error()
    return _extract_chat_content(result)


def call_api_model(model_config, messages, temperature, max_tokens):
    """Call a remote API model using the OpenAI-compatible format.

    The configured ``api_url`` is used as-is when it already contains
    ``/chat/completions``; otherwise trailing slashes are stripped and
    ``/chat/completions`` is appended.

    Args:
        model_config: Model record; reads ``api_url``, ``api_key`` and
            ``model_name``.
        messages: List of chat message dicts.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Token limit forwarded to the API.

    Returns:
        A dict with ``success`` and either ``content`` or ``error``.
    """
    api_url = model_config.get('api_url')
    api_key = model_config.get('api_key')
    model_name = model_config.get('model_name', '')

    if not api_url:
        return {'success': False, 'error': 'API模型地址未配置'}

    if '/chat/completions' in api_url:
        full_url = api_url
    else:
        full_url = api_url.rstrip('/') + '/chat/completions'

    headers = {'Content-Type': 'application/json'}
    if api_key:
        # Only send credentials when a key is configured; avoids "Bearer None".
        headers['Authorization'] = f'Bearer {api_key}'

    payload = {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens
    }
    return _post_chat_completion(full_url, headers, payload)


def call_local_model(model_config, messages, temperature, max_tokens):
    """Call a local model through its OpenAI-compatible HTTP API.

    For local models the ``path`` field of the record holds the API address,
    which is used verbatim (no suffix is appended) and without credentials.

    Args:
        model_config: Model record; reads ``path`` and ``model_name``.
        messages: List of chat message dicts.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Token limit forwarded to the API.

    Returns:
        A dict with ``success`` and either ``content`` or ``error``.
    """
    api_url = model_config.get('path')
    model_name = model_config.get('model_name', '')

    if not api_url:
        return {'success': False, 'error': '本地模型API地址未配置'}

    headers = {'Content-Type': 'application/json'}
    payload = {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens
    }
    return _post_chat_completion(api_url, headers, payload)


def _get_json_payload():
    """Return the request's JSON object, or ``{}`` when absent or invalid."""
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _build_messages(system_prompt, user_question):
    """Build the chat message list, with the system prompt first if given."""
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})
    return messages


def _dispatch_model_call(model, messages, temperature, max_tokens):
    """Call the model via the API path or the local path by ``model_source``."""
    if model.get('model_source') == 'api':
        return call_api_model(model, messages, temperature, max_tokens)
    return call_local_model(model, messages, temperature, max_tokens)


@model_chat_bp.route('', methods=['POST'])
def model_chat():
    """Chat with a single model.

    Reads ``model_id``, ``user_question`` and the optional ``system_prompt``,
    ``temperature`` (default 0.7) and ``max_tokens`` (default 2048) from the
    JSON body.

    Returns:
        A JSON response with ``code`` 0 and the model reply under ``data`` on
        success, or ``code`` 1 and a ``message`` on failure.
    """
    data = _get_json_payload()
    model_id = data.get('model_id')
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_id:
        return jsonify({'code': 1, 'message': '缺少模型ID'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    model = generic_get_by_id('model_manage', model_id)
    if not model:
        return jsonify({'code': 1, 'message': '模型不存在'})

    messages = _build_messages(system_prompt, user_question)
    result = _dispatch_model_call(model, messages, temperature, max_tokens)

    if not result.get('success'):
        return jsonify({'code': 1, 'message': result.get('error', '调用失败')})

    return jsonify({
        'code': 0,
        'data': {
            'model_id': model_id,
            'model_name': model.get('name'),
            'response': result['content']
        }
    })


@model_chat_bp.route('/batch', methods=['POST'])
def model_chat_batch():
    """Chat with several models concurrently.

    Reads ``model_ids`` (a non-empty list), ``user_question`` and the optional
    ``system_prompt``, ``temperature`` (default 0.7) and ``max_tokens``
    (default 2048) from the JSON body.

    Returns:
        A JSON response with ``code`` 0 and one result entry per requested
        model, in the same order as ``model_ids``; or ``code`` 1 and a
        ``message`` when the request is invalid. A failure for one model is
        reported in that model's entry and does not abort the batch.
    """
    data = _get_json_payload()
    model_ids = data.get('model_ids', [])
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not isinstance(model_ids, list) or not model_ids:
        return jsonify({'code': 1, 'message': '缺少模型ID列表'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # Every model receives the same messages; the list is only read by workers.
    messages = _build_messages(system_prompt, user_question)

    def call_single_model(model_id):
        """Look up and call one model; never raises."""
        try:
            model = generic_get_by_id('model_manage', model_id)
            if not model:
                return {'model_id': model_id, 'success': False, 'error': '模型不存在'}
            result = _dispatch_model_call(model, messages, temperature, max_tokens)
        except Exception as exc:
            # Worker boundary: report the failure for this model only so the
            # remaining models in the batch still return their results.
            logger.exception('model chat failed for model_id=%s', model_id)
            return {'model_id': model_id, 'success': False, 'error': str(exc)}
        return {
            'model_id': model_id,
            'model_name': model.get('name'),
            'success': result.get('success', False),
            'response': result.get('content', ''),
            'error': result.get('error', '')
        }

    worker_count = min(len(model_ids), _MAX_BATCH_WORKERS)
    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
        # ``map`` yields results in input order regardless of completion order.
        results = list(executor.map(call_single_model, model_ids))

    return jsonify({'code': 0, 'data': results})