Files
YG_FT_Platform/src/api/model_chat.py

211 lines
7.2 KiB
Python
Raw Normal View History

"""
模型对话 API 路由
"""
import os
import pymysql
import yaml
import requests
import concurrent.futures
from flask import Blueprint, request, jsonify
# Absolute path three directory levels above this file, i.e. the project root
# when this module lives at <root>/src/api/. Used to locate config.yaml.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Blueprint holding the model-chat endpoints; every route below is served under
# the /api/model-chat prefix.
# NOTE(review): presumably registered on the Flask app elsewhere — confirm.
model_chat_bp = Blueprint('model_chat', __name__, url_prefix='/api/model-chat')
def get_db_connection():
    """Open a new MySQL connection configured from ``<PROJECT_ROOT>/config.yaml``.

    The YAML file is re-read on every call, so edits to it apply to later
    connections. Its ``database`` section must provide ``host``, ``port``,
    ``username``, ``password`` and ``name``; ``charset`` is optional and
    defaults to ``utf8mb4``. Cursors created from the returned connection
    yield rows as dicts (``DictCursor``).

    Returns:
        A fresh pymysql connection. The caller owns it and must close it.

    Raises:
        Whatever opening/parsing the config file raises (e.g. ``OSError``),
        ``KeyError`` for a missing required key, and any error raised by
        ``pymysql.connect``.
    """
    config_file = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_file, 'r', encoding='utf-8') as fh:
        settings = yaml.safe_load(fh)
    db = settings['database']
    connect_kwargs = {
        'host': db['host'],
        'port': db['port'],
        'user': db['username'],
        'password': db['password'],
        'database': db['name'],
        'charset': db.get('charset', 'utf8mb4'),
        'cursorclass': pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**connect_kwargs)
def generic_get_by_id(table_name, id_val):
    """Fetch a single row from ``table_name`` whose ``id`` column equals ``id_val``.

    Args:
        table_name: Table to query. It is interpolated directly into the SQL
            text (identifiers cannot be bound as parameters), so it must come
            from trusted code, never from request data.
        id_val: Value matched against the ``id`` column; passed as a bound
            query parameter.

    Returns:
        The first matching row as returned by the cursor's ``fetchone()``
        (a dict with the ``DictCursor`` used by ``get_db_connection``), or
        ``None`` when no row matches.

    The cursor and connection are released even if the query raises; the
    exception itself still propagates to the caller.
    """
    conn = get_db_connection()
    try:
        # pymysql cursors close themselves when the ``with`` block exits.
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
            return cursor.fetchone()
    finally:
        # Previously the connection leaked whenever execute/fetchone raised.
        conn.close()
def _chat_completions_url(api_url):
    """Return the OpenAI-compatible chat-completions endpoint for ``api_url``.

    ``api_url`` may already be the full endpoint
    (``https://api.openai.com/v1/chat/completions``), in which case it is used
    as-is, or a base URL (``https://api.example.com/v1``), in which case
    ``/chat/completions`` is appended after dropping any trailing slash.
    Surrounding whitespace (e.g. a stray newline in the stored value) is
    stripped first.
    """
    url = api_url.strip()
    if '/chat/completions' in url:
        return url
    return url.rstrip('/') + '/chat/completions'


def _chat_payload(model_name, messages, temperature, max_tokens):
    """Build the JSON request body shared by API and local model calls."""
    return {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens,
    }


def _extract_content(result):
    """Pull ``choices[0].message.content`` out of a decoded chat-completions reply.

    Returns ``{'success': True, 'content': ...}`` on success (a JSON ``null``
    content is normalized to ``''``), or
    ``{'success': False, 'error': 'API返回格式异常'}`` when the reply does not
    have that shape (no/empty ``choices``, missing ``message``, non-dict body).
    """
    try:
        message = result['choices'][0]['message']
        content = message.get('content')
    except (KeyError, IndexError, TypeError, AttributeError):
        return {'success': False, 'error': 'API返回格式异常'}
    return {'success': True, 'content': '' if content is None else content}


def _post_chat_completion(url, headers, payload, timeout=120):
    """POST ``payload`` as JSON to ``url`` and parse the chat-completions reply.

    Network errors, timeouts and non-2xx statuses are returned as
    ``{'success': False, 'error': str(exc)}``. A body that is not valid JSON
    or lacks the expected structure yields ``'API返回格式异常'``. Otherwise
    the result of ``_extract_content`` is returned.
    """
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        return {'success': False, 'error': str(e)}
    except ValueError:
        # Non-JSON body. In requests >= 2.27 the decode error is also a
        # RequestException (handled above); older versions raise plain ValueError.
        return {'success': False, 'error': 'API返回格式异常'}
    return _extract_content(result)


def call_api_model(model_config, messages, temperature, max_tokens):
    """Call a remote model through an OpenAI-compatible chat-completions API.

    Args:
        model_config: Model row. Reads ``api_url`` (full endpoint or base URL,
            see ``_chat_completions_url``), ``api_key`` (sent as a Bearer token
            only when non-empty) and ``model_name``.
        messages: Chat messages in OpenAI format (list of role/content dicts).
        temperature: Sampling temperature, forwarded unchanged.
        max_tokens: Generation limit, forwarded unchanged.

    Returns:
        ``{'success': True, 'content': ...}`` or
        ``{'success': False, 'error': ...}``. A missing URL, network failure or
        malformed reply is reported this way rather than raised.
    """
    api_url = (model_config.get('api_url') or '').strip()
    if not api_url:
        # Previously ``'/chat/completions' in None`` raised TypeError here.
        return {'success': False, 'error': 'API模型地址未配置'}

    headers = {'Content-Type': 'application/json'}
    api_key = model_config.get('api_key')
    if api_key:
        # Avoid sending the literal "Bearer None" when no key is configured.
        headers['Authorization'] = f'Bearer {api_key}'

    payload = _chat_payload(model_config.get('model_name', ''), messages, temperature, max_tokens)
    return _post_chat_completion(_chat_completions_url(api_url), headers, payload)


def call_local_model(model_config, messages, temperature, max_tokens):
    """Call a locally served model through its OpenAI-compatible API (e.g. vLLM).

    Args:
        model_config: Model row. For local models the ``path`` field holds the
            API address, which is used verbatim (no ``/chat/completions``
            suffix is appended); ``model_name`` is sent as the model id.
        messages: Chat messages in OpenAI format (list of role/content dicts).
        temperature: Sampling temperature, forwarded unchanged.
        max_tokens: Generation limit, forwarded unchanged.

    Returns:
        ``{'success': True, 'content': ...}`` or
        ``{'success': False, 'error': ...}``; errors are reported, not raised.
    """
    api_url = (model_config.get('path') or '').strip()
    if not api_url:
        return {'success': False, 'error': '本地模型API地址未配置'}

    headers = {'Content-Type': 'application/json'}
    payload = _chat_payload(model_config.get('model_name', ''), messages, temperature, max_tokens)
    return _post_chat_completion(api_url, headers, payload)
@model_chat_bp.route('', methods=['POST'])
def model_chat():
    """Chat with a single configured model.

    Expects a JSON object body with:
        model_id (required): id of a row in ``model_manage``.
        user_question (required): the user's message.
        system_prompt (optional): prepended as a ``system`` message when non-empty.
        temperature (optional, default 0.7) and max_tokens (optional,
            default 2048): forwarded unchanged to the model.

    Models whose ``model_source`` is ``'api'`` go through ``call_api_model``;
    all others through ``call_local_model``.

    Returns:
        ``{'code': 0, 'data': {'model_id', 'model_name', 'response'}}`` on
        success, or ``{'code': 1, 'message': ...}`` for invalid input, an
        unknown model, or a failed model call.
    """
    # silent=True: a missing or unparsable JSON body yields None instead of an
    # HTTP 400/415, so it is reported through the same code=1 envelope.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Also covers JSON arrays/strings, which previously crashed on .get().
        data = {}

    model_id = data.get('model_id')
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_id:
        return jsonify({'code': 1, 'message': '缺少模型ID'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    model = generic_get_by_id('model_manage', model_id)
    if not model:
        return jsonify({'code': 1, 'message': '模型不存在'})

    messages = [{'role': 'system', 'content': system_prompt}] if system_prompt else []
    messages.append({'role': 'user', 'content': user_question})

    caller = call_api_model if model.get('model_source') == 'api' else call_local_model
    result = caller(model, messages, temperature, max_tokens)

    if not result.get('success'):
        # Fall back to the generic message when the error text is missing or empty.
        return jsonify({'code': 1, 'message': result.get('error') or '调用失败'})
    return jsonify({
        'code': 0,
        'data': {
            'model_id': model_id,
            'model_name': model.get('name'),
            'response': result.get('content', ''),
        }
    })
@model_chat_bp.route('/batch', methods=['POST'])
def model_chat_batch():
    """Send the same prompt to several models concurrently.

    Expects a JSON object body with:
        model_ids (required): non-empty list of ids of rows in ``model_manage``.
        user_question (required): the user's message.
        system_prompt (optional): prepended as a ``system`` message when non-empty.
        temperature (optional, default 0.7) and max_tokens (optional,
            default 2048): forwarded unchanged to every model.

    Returns:
        ``{'code': 0, 'data': [...]}`` with one entry per requested id, in the
        same order as ``model_ids``. A failure for one model (unknown id,
        failed call, unexpected exception) is reported in that model's entry
        with ``success: False`` instead of failing the whole request.
        Invalid input returns ``{'code': 1, 'message': ...}``.
    """
    import logging

    # silent=True: a missing or unparsable JSON body yields None instead of an
    # HTTP 400/415, so it is reported through the same code=1 envelope.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_ids = data.get('model_ids', [])
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_ids:
        return jsonify({'code': 1, 'message': '缺少模型ID列表'})
    if not isinstance(model_ids, list):
        # A bare string would otherwise be iterated character by character.
        return jsonify({'code': 1, 'message': 'model_ids必须是列表'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    logger = logging.getLogger(__name__)

    # The prompt is identical for every model, so build it once; the model call
    # functions only read it when building their request payload.
    messages = [{'role': 'system', 'content': system_prompt}] if system_prompt else []
    messages.append({'role': 'user', 'content': user_question})

    def call_single_model(model_id):
        """Query one model and return its result entry; never raises."""
        try:
            model = generic_get_by_id('model_manage', model_id)
            if not model:
                return {'model_id': model_id, 'success': False, 'error': '模型不存在'}
            caller = call_api_model if model.get('model_source') == 'api' else call_local_model
            result = caller(model, messages, temperature, max_tokens)
        except Exception:
            # Per-model boundary: log the traceback server-side and mark only
            # this model as failed so the other results are still returned.
            logger.exception('model_chat_batch: model %r failed', model_id)
            return {'model_id': model_id, 'success': False, 'error': '调用失败'}
        return {
            'model_id': model_id,
            'model_name': model.get('name'),
            'success': result.get('success', False),
            'response': result.get('content', ''),
            'error': result.get('error', ''),
        }

    # The work is I/O-bound (DB lookup + HTTP call), so threads overlap the
    # waits; at most 4 models are queried at once.
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(model_ids), 4)) as executor:
        # map() returns results in input order, unlike as_completed().
        results = list(executor.map(call_single_model, model_ids))
    return jsonify({'code': 0, 'data': results})