diff --git a/src/api/__init__.py b/src/api/__init__.py index 84d6dc2..abb8722 100644 --- a/src/api/__init__.py +++ b/src/api/__init__.py @@ -3,9 +3,11 @@ API 路由包 """ from .datasets import datasets_bp from .model_manage import model_manage_bp +from .model_chat import model_chat_bp # 注册所有蓝图 def register_blueprints(app): """注册所有蓝图""" app.register_blueprint(datasets_bp) app.register_blueprint(model_manage_bp) + app.register_blueprint(model_chat_bp) diff --git a/src/api/model_chat.py b/src/api/model_chat.py new file mode 100644 index 0000000..7b20f04 --- /dev/null +++ b/src/api/model_chat.py @@ -0,0 +1,210 @@ +""" +模型对话 API 路由 +""" +import os +import pymysql +import yaml +import requests +import concurrent.futures +from flask import Blueprint, request, jsonify + +# 获取项目根目录 +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# 创建蓝图 +model_chat_bp = Blueprint('model_chat', __name__, url_prefix='/api/model-chat') + + +def get_db_connection(): + """获取数据库连接""" + CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') + with open(CONFIG_PATH, 'r', encoding='utf-8') as f: + CONFIG = yaml.safe_load(f) + db_config = CONFIG['database'] + return pymysql.connect( + host=db_config['host'], + port=db_config['port'], + user=db_config['username'], + password=db_config['password'], + database=db_config['name'], + charset=db_config.get('charset', 'utf8mb4'), + cursorclass=pymysql.cursors.DictCursor + ) + + +def generic_get_by_id(table_name, id_val): + """按ID查询""" + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,)) + result = cursor.fetchone() + cursor.close() + conn.close() + return result + + +def call_api_model(model_config, messages, temperature, max_tokens): + """调用API模型(OpenAI兼容格式)""" + api_url = model_config.get('api_url') + api_key = model_config.get('api_key') + model_name = model_config.get('model_name', '') + + # 构造OpenAI兼容的完整URL + # 支持: https://api.openai.com/v1/chat/completions 或 https://api.example.com/v1 + # 如果URL已经包含 /chat/completions 则直接使用,否则追加 + if '/chat/completions' in api_url: + full_url = api_url + else: + # 去掉末尾的斜杠,然后追加 /chat/completions + full_url = api_url.rstrip('/') + '/chat/completions' + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } + + payload = { + 'model': model_name, + 'messages': messages, + 'temperature': temperature, + 'max_tokens': max_tokens + } + + try: + response = requests.post(full_url, headers=headers, json=payload, timeout=120) + response.raise_for_status() + result = response.json() + + if 'choices' in result and len(result['choices']) > 0: + return { + 'success': True, + 'content': result['choices'][0]['message'].get('content', '') + } + return {'success': False, 'error': 'API返回格式异常'} + except requests.exceptions.RequestException as e: + return {'success': False, 'error': str(e)} + + +def call_local_model(model_config, messages, temperature, max_tokens): + """调用本地模型(通过vLLM OpenAI兼容API)""" + api_url = model_config.get('path') # 本地模型path字段存储API地址 + model_name = model_config.get('model_name', '') + + if not api_url: + return {'success': False, 'error': '本地模型API地址未配置'} + + headers = {'Content-Type': 'application/json'} + + payload = { + 'model': model_name, + 'messages': messages, + 'temperature': temperature, + 'max_tokens': max_tokens + } + + try: + response = requests.post(api_url, headers=headers, json=payload, timeout=120) + response.raise_for_status() + result = response.json() + + if 'choices' in result and len(result['choices']) > 0: + return { + 'success': True, + 'content': result['choices'][0]['message'].get('content', '') + } + return {'success': False, 'error': 'API返回格式异常'} + except requests.exceptions.RequestException as e: + return {'success': False, 'error': str(e)} + + +@model_chat_bp.route('', methods=['POST']) +def model_chat(): + """模型对话接口""" + data = request.json + model_id = data.get('model_id') + system_prompt = data.get('system_prompt', '') + user_question = data.get('user_question') + temperature = data.get('temperature', 0.7) + max_tokens = data.get('max_tokens', 2048) + + if not model_id: + return jsonify({'code': 1, 'message': '缺少模型ID'}) + if not user_question: + return jsonify({'code': 1, 'message': '缺少用户提问'}) + + # 获取模型配置 + model = generic_get_by_id('model_manage', model_id) + if not model: + return jsonify({'code': 1, 'message': '模型不存在'}) + + # 构建消息 + messages = [] + if system_prompt: + messages.append({'role': 'system', 'content': system_prompt}) + messages.append({'role': 'user', 'content': user_question}) + + # 根据模型类型调用 + if model.get('model_source') == 'api': + result = call_api_model(model, messages, temperature, max_tokens) + else: + result = call_local_model(model, messages, temperature, max_tokens) + + if result.get('success'): + return jsonify({ + 'code': 0, + 'data': { + 'model_id': model_id, + 'model_name': model.get('name'), + 'response': result['content'] + } + }) + else: + return jsonify({'code': 1, 'message': result.get('error', '调用失败')}) + + +@model_chat_bp.route('/batch', methods=['POST']) +def model_chat_batch(): + """批量模型对话接口(并发调用多个模型)""" + data = request.json + model_ids = data.get('model_ids', []) + system_prompt = data.get('system_prompt', '') + user_question = data.get('user_question') + temperature = data.get('temperature', 0.7) + max_tokens = data.get('max_tokens', 2048) + + if not model_ids: + return jsonify({'code': 1, 'message': '缺少模型ID列表'}) + if not user_question: + return jsonify({'code': 1, 'message': '缺少用户提问'}) + + def call_single_model(model_id): + model = generic_get_by_id('model_manage', model_id) + if not model: + return {'model_id': model_id, 'success': False, 'error': '模型不存在'} + + messages = [] + if system_prompt: + messages.append({'role': 'system', 'content': system_prompt}) + messages.append({'role': 'user', 'content': user_question}) + + if model.get('model_source') == 'api': + result = call_api_model(model, messages, temperature, max_tokens) + else: + result = call_local_model(model, messages, temperature, max_tokens) + + return { + 'model_id': model_id, + 'model_name': model.get('name'), + 'success': result.get('success', False), + 'response': result.get('content', ''), + 'error': result.get('error', '') + } + + # 并发调用所有模型 + results = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(model_ids), 4)) as executor: + future_to_model = {executor.submit(call_single_model, mid): mid for mid in model_ids} + for future in concurrent.futures.as_completed(future_to_model): + results.append(future.result()) + + return jsonify({'code': 0, 'data': results}) diff --git a/web/pages/main.html b/web/pages/main.html index b1309ac..f3a5e37 100644 --- a/web/pages/main.html +++ b/web/pages/main.html @@ -5,6 +5,7 @@