2026-01-22 14:09:25 +08:00
|
|
|
|
"""
|
|
|
|
|
|
模型对话 API 路由
|
|
|
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
|
|
|
import pymysql
|
|
|
|
|
|
import yaml
|
|
|
|
|
|
import requests
|
|
|
|
|
|
import concurrent.futures
|
2026-01-29 17:39:06 +08:00
|
|
|
|
import subprocess
|
2026-01-22 14:09:25 +08:00
|
|
|
|
from flask import Blueprint, request, jsonify
|
|
|
|
|
|
|
|
|
|
|
|
# Project root directory: three directory levels above this file's absolute path.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Flask blueprint holding the model-chat routes; every route is prefixed with /api/model-chat.
model_chat_bp = Blueprint('model_chat', __name__, url_prefix='/api/model-chat')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db_connection():
    """Open a new MySQL connection described by the ``database`` section of config.yaml.

    The configuration file is read from ``PROJECT_ROOT/config.yaml`` on every
    call. The ``host``, ``port``, ``username``, ``password`` and ``name`` keys
    are required; ``charset`` falls back to ``utf8mb4``.

    Returns:
        A ``pymysql`` connection whose cursors are ``DictCursor`` instances.
    """
    config_file = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_file, 'r', encoding='utf-8') as handle:
        loaded = yaml.safe_load(handle)
    settings = loaded['database']

    connect_kwargs = {
        'host': settings['host'],
        'port': settings['port'],
        'user': settings['username'],
        'password': settings['password'],
        'database': settings['name'],
        'charset': settings.get('charset', 'utf8mb4'),
        'cursorclass': pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**connect_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_get_by_id(table_name, id_val):
    """Fetch the single row of ``table_name`` whose ``id`` column equals ``id_val``.

    Args:
        table_name: Name of the table to query. It is interpolated directly
            into the SQL text (identifiers cannot be bound as parameters), so
            it must come from trusted code and never from request data.
        id_val: Value bound to the ``id = %s`` placeholder.

    Returns:
        The result of ``cursor.fetchone()``: the matching row, or ``None``
        when no row matches.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
            return cursor.fetchone()
        finally:
            # Release the cursor even when the query raises.
            cursor.close()
    finally:
        # Always hand the connection back; previously a failing query leaked it.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def call_api_model(model_config, messages, temperature, max_tokens):
    """Send a chat request to a remote model through an OpenAI-compatible HTTP API.

    Args:
        model_config: Mapping read with ``.get`` for ``api_url``, ``api_key``
            and ``model_name``.
        messages: Chat messages placed in the request payload as ``messages``.
        temperature: Forwarded unchanged as the payload's ``temperature``.
        max_tokens: Forwarded unchanged as the payload's ``max_tokens``.

    Returns:
        ``{'success': True, 'content': <text>}`` when the response contains at
        least one choice, otherwise ``{'success': False, 'error': <message>}``.
        This function reports failures through the return value; it does not
        raise for HTTP errors, non-JSON bodies or malformed payloads.
    """
    api_url = model_config.get('api_url')
    api_key = model_config.get('api_key')
    model_name = model_config.get('model_name', '')

    # Without a URL the substring test below would raise TypeError on None.
    if not api_url:
        return {'success': False, 'error': 'API地址未配置'}

    # Accept either a full endpoint (https://api.openai.com/v1/chat/completions)
    # or a base URL (https://api.example.com/v1); append the path only when the
    # URL does not already contain it.
    if '/chat/completions' in api_url:
        full_url = api_url
    else:
        full_url = api_url.rstrip('/') + '/chat/completions'

    headers = {'Content-Type': 'application/json'}
    if api_key:
        # Only authenticate when a key is configured, instead of sending
        # the literal string "Bearer None".
        headers['Authorization'] = f'Bearer {api_key}'

    payload = {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens,
    }

    try:
        response = requests.post(full_url, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        return {'success': False, 'error': str(e)}
    except ValueError:
        # Body was not valid JSON (older requests versions raise a plain ValueError).
        return {'success': False, 'error': 'API返回格式异常'}

    try:
        content = result['choices'][0]['message'].get('content', '')
    except (KeyError, IndexError, TypeError, AttributeError):
        # Missing/empty "choices" or a choice without a "message" object.
        return {'success': False, 'error': 'API返回格式异常'}

    # A JSON null content is normalised to an empty string.
    return {'success': True, 'content': content if content is not None else ''}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def call_local_model(model_config, messages, temperature, max_tokens):
    """Send a chat request to a locally served model over an OpenAI-compatible HTTP API (e.g. vLLM).

    Args:
        model_config: Mapping read with ``.get`` for ``path`` (which holds the
            API URL for local models) and ``model_name``.
        messages: Chat messages placed in the request payload as ``messages``.
        temperature: Forwarded unchanged as the payload's ``temperature``.
        max_tokens: Forwarded unchanged as the payload's ``max_tokens``.

    Returns:
        ``{'success': True, 'content': <text>}`` when the response contains at
        least one choice, otherwise ``{'success': False, 'error': <message>}``.
        This function reports failures through the return value; it does not
        raise for HTTP errors, non-JSON bodies or malformed payloads.
    """
    # For local models the "path" field stores the API address; it is used
    # verbatim (no "/chat/completions" suffix is appended here).
    api_url = model_config.get('path')
    model_name = model_config.get('model_name', '')

    if not api_url:
        return {'success': False, 'error': '本地模型API地址未配置'}

    headers = {'Content-Type': 'application/json'}
    payload = {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens,
    }

    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        return {'success': False, 'error': str(e)}
    except ValueError:
        # Body was not valid JSON (older requests versions raise a plain ValueError).
        return {'success': False, 'error': 'API返回格式异常'}

    try:
        content = result['choices'][0]['message'].get('content', '')
    except (KeyError, IndexError, TypeError, AttributeError):
        # Missing/empty "choices" or a choice without a "message" object.
        return {'success': False, 'error': 'API返回格式异常'}

    # A JSON null content is normalised to an empty string.
    return {'success': True, 'content': content if content is not None else ''}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@model_chat_bp.route('', methods=['POST'])
def model_chat():
    """Single-model chat endpoint (POST on the blueprint root).

    Reads a JSON object with:
        model_id: id of the row in ``model_manage`` to use (required).
        user_question: the user's message (required).
        system_prompt: optional system message, default ``''``.
        temperature: default ``0.7``.
        max_tokens: default ``2048``.

    Returns:
        A JSON response. On success ``{'code': 0, 'data': {'model_id',
        'model_name', 'response'}}``; on any failure ``{'code': 1,
        'message': <reason>}``.
    """
    # silent=True yields None instead of aborting on a missing/invalid JSON
    # body; anything that is not a JSON object is treated as an empty request
    # so the validation below produces the usual error envelope.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_id = data.get('model_id')
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_id:
        return jsonify({'code': 1, 'message': '缺少模型ID'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # Look up the model configuration; report lookup failures in the same
    # envelope the /trained endpoint uses instead of an unhandled error.
    try:
        model = generic_get_by_id('model_manage', model_id)
    except Exception as e:
        return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
    if not model:
        return jsonify({'code': 1, 'message': '模型不存在'})

    # Build the message list: optional system prompt followed by the user turn.
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})

    # Dispatch on the model source: 'api' uses the remote caller, anything
    # else is treated as a local model.
    if model.get('model_source') == 'api':
        result = call_api_model(model, messages, temperature, max_tokens)
    else:
        result = call_local_model(model, messages, temperature, max_tokens)

    if not result.get('success'):
        return jsonify({'code': 1, 'message': result.get('error', '调用失败')})

    return jsonify({
        'code': 0,
        'data': {
            'model_id': model_id,
            'model_name': model.get('name'),
            'response': result['content'],
        },
    })
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@model_chat_bp.route('/batch', methods=['POST'])
def model_chat_batch():
    """Batch chat endpoint: send the same prompt to several models concurrently.

    Reads a JSON object with:
        model_ids: non-empty list of ``model_manage`` ids (required).
        user_question: the user's message (required).
        system_prompt: optional system message, default ``''``.
        temperature: default ``0.7``.
        max_tokens: default ``2048``.

    Returns:
        A JSON response. ``{'code': 1, 'message': ...}`` when validation
        fails, otherwise ``{'code': 0, 'data': [...]}`` with one entry per
        requested id. Entries are appended in completion order, not request
        order. Each entry carries ``model_id`` and ``success`` plus either
        ``model_name``/``response``/``error`` or just ``error``.
    """
    # silent=True yields None instead of aborting on a missing/invalid JSON
    # body; anything that is not a JSON object is treated as an empty request.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_ids = data.get('model_ids', [])
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    # Reject non-list values too (a string would otherwise be iterated
    # character by character).
    if not isinstance(model_ids, (list, tuple)) or not model_ids:
        return jsonify({'code': 1, 'message': '缺少模型ID列表'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    def call_single_model(model_id):
        """Resolve one model id and run the chat call for it."""
        model = generic_get_by_id('model_manage', model_id)
        if not model:
            return {'model_id': model_id, 'success': False, 'error': '模型不存在'}

        messages = []
        if system_prompt:
            messages.append({'role': 'system', 'content': system_prompt})
        messages.append({'role': 'user', 'content': user_question})

        if model.get('model_source') == 'api':
            result = call_api_model(model, messages, temperature, max_tokens)
        else:
            result = call_local_model(model, messages, temperature, max_tokens)

        return {
            'model_id': model_id,
            'model_name': model.get('name'),
            'success': result.get('success', False),
            'response': result.get('content', ''),
            'error': result.get('error', ''),
        }

    # Call all models concurrently, with at most 4 worker threads.
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(model_ids), 4)) as executor:
        future_to_model = {executor.submit(call_single_model, mid): mid for mid in model_ids}
        for future in concurrent.futures.as_completed(future_to_model):
            try:
                results.append(future.result())
            except Exception as e:
                # One failing model (e.g. a database error) must not discard
                # the answers of the others; report it as a failed entry.
                results.append({
                    'model_id': future_to_model[future],
                    'success': False,
                    'error': str(e),
                })

    return jsonify({'code': 0, 'data': results})
|
2026-01-29 17:39:06 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@model_chat_bp.route('/trained', methods=['POST'])
def chat_trained_model():
    """Run one chat turn against a trained model by invoking ``llamafactory-cli chat``.

    Reads a JSON object with:
        model_name: name of the model in ``model_manage`` (required); also the
            directory name of the trained weights.
        user_question: the user's message (required).
        train_method: sub-directory of the saves folder, default ``'lora'``.
        system_prompt: optional system message, default ``''``.
        temperature: default ``0.7``.
        max_tokens: default ``2048``.

    Returns:
        A JSON response. On success ``{'code': 0, 'data': {'model_name',
        'train_method', 'response'}}``; on any failure ``{'code': 1,
        'message': <reason>}``.
    """
    # json is used below but is not among the module-level imports of this
    # file, so import it here (the original raised NameError at this point).
    import json

    # silent=True yields None instead of aborting on a missing/invalid JSON
    # body; anything that is not a JSON object is treated as an empty request.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_name = data.get('model_name')
    train_method = data.get('train_method', 'lora')
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # These values end up on a command line; make sure they are numbers.
    try:
        temperature = float(temperature)
        max_tokens = int(max_tokens)
    except (TypeError, ValueError):
        return jsonify({'code': 1, 'message': '参数 temperature 或 max_tokens 格式不正确'})

    # Look up the base model path registered for model_name. The shared
    # connection helper is reused, and the connection is closed even when the
    # query fails.
    try:
        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
            model_result = cursor.fetchone()
        finally:
            conn.close()
    except Exception as e:
        return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})

    if not model_result or not model_result.get('path'):
        return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'})
    base_model_path = model_result['path']

    # Location of the trained weights. Both path components come from the
    # request, so normalise the path and refuse anything that escapes the
    # saves directory (e.g. via "..").
    saves_root = '/app/base/saves'
    trained_model_path = os.path.normpath(f"{saves_root}/{train_method}/{model_name}")
    if not trained_model_path.startswith(saves_root + os.sep):
        return jsonify({'code': 1, 'message': '模型名称或训练方法不合法'})

    if not os.path.exists(trained_model_path):
        return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})

    llamafactory_dir = '/app/base'

    # Build the messages and serialise them; the JSON text is fed to the CLI
    # on standard input.
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})
    messages_json = json.dumps(messages, ensure_ascii=False)

    # Argument-list form (no shell): request data can no longer be interpreted
    # as shell syntax.
    # NOTE(review): --finetuning_type is fixed to "lora" and --template to
    # "llama3" regardless of train_method / the selected model - confirm this
    # is intended.
    cmd = [
        'llamafactory-cli', 'chat',
        '--model_name_or_path', str(base_model_path),
        '--adapter_name_or_path', trained_model_path,
        '--template', 'llama3',
        '--finetuning_type', 'lora',
        '--temperature', str(temperature),
        '--max_tokens', str(max_tokens),
    ]
    env = dict(os.environ, CUDA_VISIBLE_DEVICES='0')

    try:
        result = subprocess.run(
            cmd,
            input=messages_json + '\n',  # same bytes `echo '<json>' |` used to supply
            capture_output=True,
            text=True,
            timeout=120,
            cwd=llamafactory_dir,
            env=env,
        )

        output = result.stdout
        error = result.stderr

        # Extract the assistant reply. The expected output shape is:
        #   <|im_start|>assistant
        #   xxx
        #   <|im_end|>
        # NOTE(review): these are ChatML-style markers while the template is
        # "llama3" - verify against real CLI output.
        start_marker = '<|im_start|>assistant'
        if start_marker in output:
            # Text after the first start marker, up to the first end marker.
            assistant_content = output.split(start_marker)[1].split('<|im_end|>')[0].strip()
        elif result.returncode == 0:
            # No markers: fall back to the whole trimmed output.
            assistant_content = output.strip()
        else:
            return jsonify({'code': 1, 'message': f'推理失败: {error or output}'})

        return jsonify({
            'code': 0,
            'data': {
                'model_name': model_name,
                'train_method': train_method,
                'response': assistant_content,
            },
        })

    except subprocess.TimeoutExpired:
        return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
    except Exception as e:
        return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})
|