575 lines
22 KiB
Python
575 lines
22 KiB
Python
"""
|
||
模型对话 API 路由
|
||
"""
|
||
import os
|
||
import pymysql
|
||
import yaml
|
||
import json
|
||
import requests
|
||
import concurrent.futures
|
||
import subprocess
|
||
from flask import Blueprint, request, jsonify
|
||
|
||
# Project root directory: three directory levels above this file's absolute path.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Blueprint that groups the model-chat routes under the /api/model-chat URL prefix.
model_chat_bp = Blueprint('model_chat', __name__, url_prefix='/api/model-chat')
|
||
|
||
|
||
def get_db_connection():
    """Open a new MySQL connection described by the project's config.yaml.

    The ``database`` section of ``<PROJECT_ROOT>/config.yaml`` is re-read on
    every call. It must provide ``host``, ``port``, ``username``, ``password``
    and ``name``; ``charset`` is optional and defaults to ``utf8mb4``.

    Returns:
        A pymysql connection whose cursors yield rows as dictionaries.
    """
    config_file = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_file, 'r', encoding='utf-8') as handle:
        settings = yaml.safe_load(handle)

    db_settings = settings['database']
    connect_kwargs = {
        'host': db_settings['host'],
        'port': db_settings['port'],
        'user': db_settings['username'],
        'password': db_settings['password'],
        'database': db_settings['name'],
        'charset': db_settings.get('charset', 'utf8mb4'),
        'cursorclass': pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**connect_kwargs)
|
||
|
||
|
||
def generic_get_by_id(table_name, id_val):
    """Fetch a single row from ``table_name`` by its ``id`` column.

    Args:
        table_name: Name of the table to query. It is interpolated into the
            SQL text verbatim (identifiers cannot be bound as parameters), so
            it must only ever come from trusted, hard-coded call sites.
        id_val: Value matched against the ``id`` column; passed as a bound
            query parameter.

    Returns:
        The first matching row as returned by ``cursor.fetchone()``, or
        ``None`` when no row matches.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
            return cursor.fetchone()
        finally:
            # Release the cursor even when the query raises.
            cursor.close()
    finally:
        # Always return the connection; previously it leaked on any error.
        conn.close()
|
||
|
||
|
||
def call_api_model(model_config, messages, temperature, max_tokens):
    """Call a remote chat model through an OpenAI-compatible HTTP API.

    Args:
        model_config: Mapping read via ``.get``; uses the keys ``api_url``,
            ``api_key`` and ``model_name``.
        messages: Chat messages sent as the ``messages`` field of the payload.
        temperature: Sampling temperature sent as ``temperature``.
        max_tokens: Generation limit sent as ``max_tokens``.

    Returns:
        ``{'success': True, 'content': <text>}`` when the response contains at
        least one choice, otherwise ``{'success': False, 'error': <reason>}``.
        This function does not raise for HTTP, network or response-format
        problems; they are reported through the ``error`` field.
    """
    api_url = model_config.get('api_url')
    api_key = model_config.get('api_key')
    model_name = model_config.get('model_name', '')

    # Without a URL the substring test below would raise TypeError on None.
    if not api_url:
        return {'success': False, 'error': 'API地址未配置'}

    # Accept either a full endpoint (…/v1/chat/completions) or a base URL
    # (…/v1); the suffix is appended only when it is not already present.
    if '/chat/completions' in api_url:
        full_url = api_url
    else:
        full_url = api_url.rstrip('/') + '/chat/completions'

    headers = {'Content-Type': 'application/json'}
    if api_key:
        # Avoid sending a literal "Bearer None" when no key is configured.
        headers['Authorization'] = f'Bearer {api_key}'

    payload = {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens
    }

    try:
        response = requests.post(full_url, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        return {'success': False, 'error': str(e)}
    except ValueError as e:
        # Body was not valid JSON (older requests versions raise plain ValueError).
        return {'success': False, 'error': str(e)}

    choices = result.get('choices') if isinstance(result, dict) else None
    if isinstance(choices, list) and choices:
        first = choices[0]
        message = first.get('message') if isinstance(first, dict) else None
        if isinstance(message, dict):
            return {'success': True, 'content': message.get('content', '')}
    return {'success': False, 'error': 'API返回格式异常'}
|
||
|
||
|
||
def call_local_model(model_config, messages, temperature, max_tokens):
    """Call a locally served model through its OpenAI-compatible HTTP API.

    Args:
        model_config: Mapping read via ``.get``; the ``path`` key holds the
            full endpoint URL that is POSTed to as-is, and ``model_name`` is
            sent as the ``model`` field.
        messages: Chat messages sent as the ``messages`` field of the payload.
        temperature: Sampling temperature sent as ``temperature``.
        max_tokens: Generation limit sent as ``max_tokens``.

    Returns:
        ``{'success': True, 'content': <text>}`` when the response contains at
        least one choice, otherwise ``{'success': False, 'error': <reason>}``.
        HTTP, network and response-format problems are reported through the
        ``error`` field rather than raised.
    """
    api_url = model_config.get('path')  # for local models, `path` stores the API URL
    model_name = model_config.get('model_name', '')

    if not api_url:
        return {'success': False, 'error': '本地模型API地址未配置'}

    headers = {'Content-Type': 'application/json'}

    payload = {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens
    }

    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        return {'success': False, 'error': str(e)}
    except ValueError as e:
        # Body was not valid JSON (older requests versions raise plain ValueError).
        return {'success': False, 'error': str(e)}

    choices = result.get('choices') if isinstance(result, dict) else None
    if isinstance(choices, list) and choices:
        first = choices[0]
        message = first.get('message') if isinstance(first, dict) else None
        if isinstance(message, dict):
            return {'success': True, 'content': message.get('content', '')}
    return {'success': False, 'error': 'API返回格式异常'}
|
||
|
||
|
||
@model_chat_bp.route('', methods=['POST'])
def model_chat():
    """Single-model chat endpoint (POST /api/model-chat).

    Reads ``model_id``, ``user_question`` and the optional ``system_prompt``,
    ``temperature`` (default 0.7) and ``max_tokens`` (default 2048) from the
    JSON body, loads the model row from ``model_manage`` and dispatches to the
    API or local model caller depending on the row's ``model_source``.

    Returns:
        A JSON response: ``{'code': 0, 'data': {...}}`` on success or
        ``{'code': 1, 'message': <reason>}`` on any failure.
    """
    # A missing, malformed or non-object body is treated as an empty request so
    # the validation below answers with a normal error instead of a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_id = data.get('model_id')
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_id:
        return jsonify({'code': 1, 'message': '缺少模型ID'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # Load the model configuration; report DB/config failures in the same
    # envelope the other endpoints of this blueprint use.
    try:
        model = generic_get_by_id('model_manage', model_id)
    except Exception as e:
        return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
    if not model:
        return jsonify({'code': 1, 'message': '模型不存在'})

    # Build the conversation: optional system prompt followed by the question.
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})

    # Dispatch on the model source; anything other than 'api' is treated as local.
    if model.get('model_source') == 'api':
        result = call_api_model(model, messages, temperature, max_tokens)
    else:
        result = call_local_model(model, messages, temperature, max_tokens)

    if not result.get('success'):
        return jsonify({'code': 1, 'message': result.get('error', '调用失败')})

    return jsonify({
        'code': 0,
        'data': {
            'model_id': model_id,
            'model_name': model.get('name'),
            'response': result['content']
        }
    })
|
||
|
||
|
||
@model_chat_bp.route('/batch', methods=['POST'])
def model_chat_batch():
    """Batch chat endpoint (POST /api/model-chat/batch).

    Sends the same question to every model in ``model_ids`` concurrently
    (at most 4 worker threads) and returns one result entry per model, in
    completion order. A failure for one model is reported in that model's
    entry and does not abort the rest of the batch.

    Returns:
        ``{'code': 0, 'data': [<entry>, ...]}`` where each entry has
        ``model_id``, ``success`` and, depending on the outcome,
        ``model_name``, ``response`` and ``error``; or
        ``{'code': 1, 'message': <reason>}`` when the request is invalid.
    """
    # A missing, malformed or non-object body is treated as an empty request.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_ids = data.get('model_ids', [])
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    # A non-list value (e.g. a string) would otherwise be iterated element by
    # element and queried as bogus IDs.
    if not model_ids or not isinstance(model_ids, (list, tuple)):
        return jsonify({'code': 1, 'message': '缺少模型ID列表'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    def call_single_model(model_id):
        """Look up one model and run the shared question against it."""
        model = generic_get_by_id('model_manage', model_id)
        if not model:
            return {'model_id': model_id, 'success': False, 'error': '模型不存在'}

        messages = []
        if system_prompt:
            messages.append({'role': 'system', 'content': system_prompt})
        messages.append({'role': 'user', 'content': user_question})

        if model.get('model_source') == 'api':
            result = call_api_model(model, messages, temperature, max_tokens)
        else:
            result = call_local_model(model, messages, temperature, max_tokens)

        return {
            'model_id': model_id,
            'model_name': model.get('name'),
            'success': result.get('success', False),
            'response': result.get('content', ''),
            'error': result.get('error', '')
        }

    # Fan out to all models concurrently.
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(model_ids), 4)) as executor:
        future_to_model = {executor.submit(call_single_model, mid): mid for mid in model_ids}
        for future in concurrent.futures.as_completed(future_to_model):
            try:
                results.append(future.result())
            except Exception as e:
                # Isolate the failure (e.g. a DB error) to the model it belongs to.
                results.append({
                    'model_id': future_to_model[future],
                    'success': False,
                    'response': '',
                    'error': str(e)
                })

    return jsonify({'code': 0, 'data': results})
|
||
|
||
|
||
@model_chat_bp.route('/trained/preload', methods=['POST'])
def preload_trained_model():
    """Warm up a fine-tuned model (POST /api/model-chat/trained/preload).

    Request JSON fields: ``model_name`` (required), ``train_method``
    (default ``'lora'``) and ``base_model_path`` (optional). When no base
    model path is supplied it is resolved from the ``fine_tune`` table and
    then, as a fallback, from ``model_manage``. The trained weights are
    expected under ``/app/base/saves/<train_method>/<model_name>``. A helper
    script is then run in a subprocess that builds an ``llmtuner.ChatModel``
    and performs one short chat call.

    Request values are handed to the helper script as JSON on stdin and are
    never interpolated into Python source, so they cannot inject code.

    Returns:
        ``{'code': 0, 'message': ..., 'data': {...}}`` on success or
        ``{'code': 1, 'message': <reason>}`` on any failure.
    """
    import logging
    import sys
    import tempfile

    logger = logging.getLogger(__name__)

    # A missing, malformed or non-object body is treated as an empty request.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_name = data.get('model_name')
    train_method = data.get('train_method', 'lora')  # training method: lora or full
    base_model_path = data.get('base_model_path')  # base model path supplied by the client

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})

    logger.info(f"[PRELOAD] 开始预加载模型: {model_name}, 方法: {train_method}")
    logger.info(f"[PRELOAD] 前端传递的基座模型路径: {base_model_path}")

    config_path = os.path.join(PROJECT_ROOT, 'config.yaml')
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except Exception as e:
        return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'})

    # Prefer the client-supplied base model path; otherwise resolve it from the DB.
    if not base_model_path:
        conn = None
        try:
            db_config = config['database']
            conn = pymysql.connect(
                host=db_config['host'],
                port=db_config['port'],
                user=db_config['username'],
                password=db_config['password'],
                database=db_config['name'],
                charset=db_config.get('charset', 'utf8mb4'),
                cursorclass=pymysql.cursors.DictCursor
            )
            cursor = conn.cursor()

            # First choice: the fine-tune task that produced this model.
            logger.info(f"[PRELOAD] 尝试从fine_tune表查询模型: {model_name}")
            cursor.execute("""
                SELECT base_model, output_model_name FROM fine_tune
                WHERE output_model_name LIKE %s OR output_model_name LIKE %s
                LIMIT 1
            """, (f'%/{model_name}', f'%{model_name}%'))
            ft_result = cursor.fetchone()
            logger.info(f"[PRELOAD] fine_tune查询结果: {ft_result}")

            if ft_result and ft_result.get('base_model'):
                base_model_val = ft_result['base_model']
                logger.info(f"[PRELOAD] base_model_val: {base_model_val}")
                if str(base_model_val).isdigit():
                    # A numeric value is a model_manage ID; resolve it to a path.
                    cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
                    model_result = cursor.fetchone()
                    logger.info(f"[PRELOAD] model_manage查询结果(数字ID): {model_result}")
                    if model_result:
                        base_model_path = model_result.get('path')
                else:
                    # Otherwise the value is already a path.
                    base_model_path = base_model_val

            # Fallback: look the model up by name in model_manage.
            if not base_model_path:
                logger.info(f"[PRELOAD] 尝试从model_manage表查询...")
                cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
                model_result = cursor.fetchone()
                logger.info(f"[PRELOAD] model_manage查询结果: {model_result}")
                if model_result:
                    base_model_path = model_result.get('path')
        except Exception as e:
            logger.error(f"[PRELOAD] 查询模型配置失败: {e}")
            return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
        finally:
            # Close the connection on every path; it used to leak on errors.
            if conn is not None:
                conn.close()

        if not base_model_path:
            logger.error(f"[PRELOAD] 未找到模型 {model_name} 的基座模型配置")
            return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})
    else:
        logger.info(f"[PRELOAD] 使用前端传递的基座模型路径: {base_model_path}")

    # Location of the trained weights.
    saves_root = '/app/base/saves'
    trained_model_path = f"{saves_root}/{train_method}/{model_name}"

    # Reject values such as "../.." that would resolve outside the saves directory.
    normalized_root = os.path.normpath(saves_root)
    normalized_target = os.path.normpath(trained_model_path)
    if normalized_target == normalized_root or \
            os.path.commonpath([normalized_root, normalized_target]) != normalized_root:
        return jsonify({'code': 1, 'message': '非法的模型名称或训练方法'})

    # If the directory is missing, no adapter/weights file inside it can exist either.
    if not os.path.exists(trained_model_path):
        logger.warning(f"[PRELOAD] 训练模型路径不存在: {trained_model_path}")
        return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})

    work_dir = '/app/base'
    # Warm-up conversation: a trivial greeting is enough to force a model load.
    warmup_messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Hello'}
    ]

    # Anything other than 'lora' is treated as full fine-tuning.
    finetuning_type = 'lora' if train_method == 'lora' else 'full'

    # Parameters travel as data (JSON on stdin), never as generated source code.
    script_params = {
        'model_args': {
            'model_name_or_path': base_model_path,
            'adapter_name_or_path': trained_model_path,
            'template': 'llama3',
            'finetuning_type': finetuning_type,
            'temperature': 0.1,
            'max_new_tokens': 1
        },
        'messages': warmup_messages
    }

    preload_script = '''
import sys
import json
import logging
logging.basicConfig(level=logging.WARNING)

from llmtuner import ChatModel


def main():
    params = json.load(sys.stdin)
    chat_model = ChatModel(params["model_args"])

    # Run one inference so the model is actually loaded.
    response = chat_model.chat(params["messages"])
    print("Model loaded successfully")


if __name__ == "__main__":
    main()
'''

    script_path = None
    try:
        # A unique file name keeps concurrent requests from overwriting each other.
        with tempfile.NamedTemporaryFile(
            mode='w', encoding='utf-8', suffix='.py',
            prefix='temp_preload_', dir=work_dir, delete=False
        ) as f:
            f.write(preload_script)
            script_path = f.name

        # Point PYTHONPATH at the first candidate directory that contains llmtuner.
        env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
        for path in ['/app/base', '/app', '/app/base/src/llamafactory']:
            if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
                env['PYTHONPATH'] = path
                logger.info(f"[PRELOAD] 设置 PYTHONPATH={path}")
                break

        logger.info(f"[PRELOAD] 执行预热脚本...")
        result = subprocess.run(
            [sys.executable, script_path],
            input=json.dumps(script_params),
            capture_output=True,
            text=True,
            timeout=180,
            cwd=work_dir,
            env=env
        )

        logger.info(f"[PRELOAD] 命令返回码: {result.returncode}")
        logger.info(f"[PRELOAD] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
        logger.info(f"[PRELOAD] stderr: {result.stderr[:500] if result.stderr else 'empty'}")

        if result.returncode == 0:
            return jsonify({
                'code': 0,
                'message': '模型预加载成功',
                'data': {
                    'model_name': model_name,
                    'train_method': train_method,
                    'base_model': base_model_path
                }
            })

        error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
        if not error_msg:
            error_msg = f'命令执行失败,返回码: {result.returncode}'
        return jsonify({'code': 1, 'message': f'预加载失败: {error_msg}'})

    except subprocess.TimeoutExpired:
        logger.error("[PRELOAD] 预加载超时")
        return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'})
    except Exception as e:
        logger.error(f"[PRELOAD] 预加载异常: {str(e)}")
        return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'})
    finally:
        # Remove the helper script on every path, including timeout and errors.
        if script_path and os.path.exists(script_path):
            try:
                os.remove(script_path)
            except OSError as e:
                logger.warning(f"[PRELOAD] failed to remove temp script {script_path}: {e}")
|
||
|
||
|
||
@model_chat_bp.route('/trained', methods=['POST'])
def chat_trained_model():
    """Chat with a fine-tuned model (POST /api/model-chat/trained).

    Request JSON fields: ``model_name`` and ``user_question`` (required),
    ``train_method`` (default ``'lora'``), ``base_model_path``,
    ``system_prompt``, ``temperature`` (default 0.7) and ``max_tokens``
    (default 2048). When no base model path is supplied it is resolved from
    the ``fine_tune`` table and then, as a fallback, from ``model_manage``.
    Inference runs in a subprocess that builds an ``llmtuner.ChatModel`` and
    prints the chat result; the subprocess's stdout becomes the response.

    Request values are handed to the helper script as JSON on stdin and are
    never interpolated into Python source, so they cannot inject code.

    Returns:
        ``{'code': 0, 'data': {...}}`` on success or
        ``{'code': 1, 'message': <reason>}`` on any failure.
    """
    import sys
    import tempfile

    # A missing, malformed or non-object body is treated as an empty request.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_name = data.get('model_name')
    train_method = data.get('train_method', 'lora')  # training method: lora or full
    base_model_path = data.get('base_model_path')  # base model path supplied by the client
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # These values used to be pasted into generated code; coerce them to plain
    # numbers so only numeric data ever reaches the model arguments.
    try:
        temperature = float(temperature)
        max_tokens = int(max_tokens)
    except (TypeError, ValueError):
        return jsonify({'code': 1, 'message': 'temperature 或 max_tokens 参数格式错误'})

    config_path = os.path.join(PROJECT_ROOT, 'config.yaml')
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except Exception as e:
        return jsonify({'code': 1, 'message': f'读取配置失败: {str(e)}'})

    # Prefer the client-supplied base model path; otherwise resolve it from the DB.
    if not base_model_path:
        conn = None
        try:
            db_config = config['database']
            conn = pymysql.connect(
                host=db_config['host'],
                port=db_config['port'],
                user=db_config['username'],
                password=db_config['password'],
                database=db_config['name'],
                charset=db_config.get('charset', 'utf8mb4'),
                cursorclass=pymysql.cursors.DictCursor
            )
            cursor = conn.cursor()

            # First choice: the fine-tune task that produced this model.
            cursor.execute("""
                SELECT base_model, output_model_name FROM fine_tune
                WHERE output_model_name LIKE %s OR output_model_name LIKE %s
                LIMIT 1
            """, (f'%/{model_name}', f'%{model_name}%'))
            ft_result = cursor.fetchone()

            if ft_result and ft_result.get('base_model'):
                base_model_val = ft_result['base_model']
                if str(base_model_val).isdigit():
                    # A numeric value is a model_manage ID; resolve it to a path.
                    cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
                    model_result = cursor.fetchone()
                    if model_result:
                        base_model_path = model_result.get('path')
                else:
                    # Otherwise the value is already a path.
                    base_model_path = base_model_val

            # Fallback: look the model up by name in model_manage.
            if not base_model_path:
                cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
                model_result = cursor.fetchone()
                if model_result:
                    base_model_path = model_result.get('path')
        except Exception as e:
            return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
        finally:
            # Close the connection on every path; it used to leak on errors.
            if conn is not None:
                conn.close()

        if not base_model_path:
            return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})

    # Location of the trained weights.
    saves_root = '/app/base/saves'
    trained_model_path = f"{saves_root}/{train_method}/{model_name}"

    # Reject values such as "../.." that would resolve outside the saves directory.
    normalized_root = os.path.normpath(saves_root)
    normalized_target = os.path.normpath(trained_model_path)
    if normalized_target == normalized_root or \
            os.path.commonpath([normalized_root, normalized_target]) != normalized_root:
        return jsonify({'code': 1, 'message': '非法的模型名称或训练方法'})

    # If the directory is missing, no adapter/weights file inside it can exist either.
    if not os.path.exists(trained_model_path):
        return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})

    # Build the conversation: optional system prompt followed by the question.
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})

    # Parameters travel as data (JSON on stdin), never as generated source code.
    # NOTE(review): finetuning_type is fixed to "lora" here regardless of
    # train_method, whereas the preload endpoint derives it from train_method —
    # confirm which is intended for fully fine-tuned models.
    script_params = {
        'model_args': {
            'model_name_or_path': base_model_path,
            'adapter_name_or_path': trained_model_path,
            'template': 'llama3',
            'finetuning_type': 'lora',
            'temperature': temperature,
            'max_new_tokens': max_tokens
        },
        'messages': messages
    }

    inference_script = '''
import sys
import json
import logging
logging.basicConfig(level=logging.WARNING)

from llmtuner import ChatModel


def main():
    params = json.load(sys.stdin)
    chat_model = ChatModel(params["model_args"])

    response = chat_model.chat(params["messages"])
    print(response)


if __name__ == "__main__":
    main()
'''

    work_dir = '/app/base'
    script_path = None
    try:
        # A unique file name keeps concurrent requests from overwriting each other.
        with tempfile.NamedTemporaryFile(
            mode='w', encoding='utf-8', suffix='.py',
            prefix='temp_inference_', dir=work_dir, delete=False
        ) as f:
            f.write(inference_script)
            script_path = f.name

        # Point PYTHONPATH at the first candidate directory that contains llmtuner.
        env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
        for path in ['/app/base', '/app', '/app/base/src/llamafactory']:
            if os.path.exists(path) and os.path.exists(os.path.join(path, 'llmtuner')):
                env['PYTHONPATH'] = path
                break

        # Run the inference script in a separate interpreter.
        result = subprocess.run(
            [sys.executable, script_path],
            input=json.dumps(script_params),
            capture_output=True,
            text=True,
            timeout=180,
            cwd=work_dir,
            env=env
        )

        if result.returncode == 0:
            assistant_content = result.stdout.strip()
            return jsonify({
                'code': 0,
                'data': {
                    'model_name': model_name,
                    'train_method': train_method,
                    'response': assistant_content
                }
            })

        error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
        return jsonify({'code': 1, 'message': f'推理失败: {error_msg}'})

    except subprocess.TimeoutExpired:
        return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
    except Exception as e:
        return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})
    finally:
        # Remove the helper script on every path, including timeout and errors.
        if script_path and os.path.exists(script_path):
            try:
                os.remove(script_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask the response.
                pass
|