Files
YG_FT_Platform/src/api/model_chat.py
WIN-JHFT4D3SIVT\caoxiaozhu 85710d865c 1. 增加了合并权重
2. 修改了一些列表展示的bug
2026-01-29 23:10:21 +08:00

575 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
模型对话 API 路由
"""
import os
import pymysql
import yaml
import json
import requests
import concurrent.futures
import subprocess
from flask import Blueprint, request, jsonify
# Project root: the directory two levels above the one containing this file.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Flask blueprint; every route registered on it is prefixed with /api/model-chat.
model_chat_bp = Blueprint('model_chat', __name__, url_prefix='/api/model-chat')
def get_db_connection():
    """Open a new MySQL connection from the ``database`` section of config.yaml.

    The configuration file under ``PROJECT_ROOT`` is re-read on every call.

    Returns:
        A ``pymysql`` connection whose cursors return rows as dictionaries.
    """
    config_file = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_file, 'r', encoding='utf-8') as handle:
        loaded = yaml.safe_load(handle)

    settings = loaded['database']
    connect_kwargs = {
        'host': settings['host'],
        'port': settings['port'],
        'user': settings['username'],
        'password': settings['password'],
        'database': settings['name'],
        # Fall back to utf8mb4 when the config does not name a charset.
        'charset': settings.get('charset', 'utf8mb4'),
        'cursorclass': pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**connect_kwargs)
def generic_get_by_id(table_name, id_val):
    """Fetch a single row from ``table_name`` by its ``id`` column.

    Args:
        table_name: Table to query. It is interpolated into the SQL text
            (identifiers cannot be bound as parameters), so it must come from
            trusted code and never from request data.
        id_val: Value matched against ``id``; bound as a query parameter.

    Returns:
        The row returned by ``cursor.fetchone()``, or ``None`` when nothing
        matches.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
            return cursor.fetchone()
        finally:
            cursor.close()
    finally:
        # Release the connection even when the query raises.
        conn.close()
def call_api_model(model_config, messages, temperature, max_tokens):
    """Call a remote model through an OpenAI-compatible chat-completions API.

    Args:
        model_config: Mapping with ``api_url``, ``api_key`` and ``model_name``.
            ``api_url`` may be either the full ``.../chat/completions`` URL or
            a base URL, in which case ``/chat/completions`` is appended.
        messages: List of ``{'role': ..., 'content': ...}`` chat messages.
        temperature: Sampling temperature forwarded in the payload.
        max_tokens: Token limit forwarded in the payload.

    Returns:
        ``{'success': True, 'content': str}`` on success, otherwise
        ``{'success': False, 'error': str}``. This function does not raise for
        missing configuration, HTTP errors or malformed responses.
    """
    api_url = model_config.get('api_url')
    api_key = model_config.get('api_key')
    model_name = model_config.get('model_name', '')

    if not api_url:
        return {'success': False, 'error': '模型API地址未配置'}

    # Use the URL as-is when it already targets the endpoint; otherwise drop
    # any trailing slash and append the endpoint path.
    if '/chat/completions' in api_url:
        full_url = api_url
    else:
        full_url = api_url.rstrip('/') + '/chat/completions'

    headers = {'Content-Type': 'application/json'}
    if api_key:
        # Avoid sending a literal "Bearer None" when no key is configured.
        headers['Authorization'] = f'Bearer {api_key}'

    payload = {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens,
    }

    try:
        response = requests.post(full_url, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        return {'success': False, 'error': str(e)}
    except ValueError:
        # Older requests versions raise a plain ValueError for a non-JSON body.
        return {'success': False, 'error': 'API返回格式异常'}

    try:
        content = result['choices'][0]['message'].get('content')
    except (KeyError, IndexError, TypeError, AttributeError):
        return {'success': False, 'error': 'API返回格式异常'}

    # The content field may be present but null; normalise it to an empty string.
    return {'success': True, 'content': content or ''}
def call_local_model(model_config, messages, temperature, max_tokens):
    """Call a locally served model through its OpenAI-compatible HTTP API.

    Args:
        model_config: Mapping with ``path`` and ``model_name``. For local
            models the ``path`` field holds the API URL, which is used exactly
            as stored (no endpoint path is appended).
        messages: List of ``{'role': ..., 'content': ...}`` chat messages.
        temperature: Sampling temperature forwarded in the payload.
        max_tokens: Token limit forwarded in the payload.

    Returns:
        ``{'success': True, 'content': str}`` on success, otherwise
        ``{'success': False, 'error': str}``. This function does not raise for
        missing configuration, HTTP errors or malformed responses.
    """
    api_url = model_config.get('path')
    model_name = model_config.get('model_name', '')
    if not api_url:
        return {'success': False, 'error': '本地模型API地址未配置'}

    headers = {'Content-Type': 'application/json'}
    payload = {
        'model': model_name,
        'messages': messages,
        'temperature': temperature,
        'max_tokens': max_tokens,
    }

    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        return {'success': False, 'error': str(e)}
    except ValueError:
        # Older requests versions raise a plain ValueError for a non-JSON body.
        return {'success': False, 'error': 'API返回格式异常'}

    try:
        content = result['choices'][0]['message'].get('content')
    except (KeyError, IndexError, TypeError, AttributeError):
        return {'success': False, 'error': 'API返回格式异常'}

    # The content field may be present but null; normalise it to an empty string.
    return {'success': True, 'content': content or ''}
@model_chat_bp.route('', methods=['POST'])
def model_chat():
    """Run one chat turn against a single configured model.

    Expects a JSON object with ``model_id`` and ``user_question`` (required)
    plus optional ``system_prompt``, ``temperature`` (default 0.7) and
    ``max_tokens`` (default 2048).

    Returns:
        A JSON response with ``code`` 0 and ``data`` (``model_id``,
        ``model_name``, ``response``) on success, or ``code`` 1 and a
        ``message`` on any failure.
    """
    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so the "missing field" responses below cover that case as well.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_id = data.get('model_id')
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_id:
        return jsonify({'code': 1, 'message': '缺少模型ID'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # Look up the model configuration; report DB/config failures in the same
    # code/message envelope as every other error of this endpoint.
    try:
        model = generic_get_by_id('model_manage', model_id)
    except Exception as e:
        return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
    if not model:
        return jsonify({'code': 1, 'message': '模型不存在'})

    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})

    # Models whose source is 'api' go to the remote API; all others are
    # treated as locally served models.
    if model.get('model_source') == 'api':
        result = call_api_model(model, messages, temperature, max_tokens)
    else:
        result = call_local_model(model, messages, temperature, max_tokens)

    if not result.get('success'):
        return jsonify({'code': 1, 'message': result.get('error', '调用失败')})

    return jsonify({
        'code': 0,
        'data': {
            'model_id': model_id,
            'model_name': model.get('name'),
            'response': result['content'],
        },
    })
@model_chat_bp.route('/batch', methods=['POST'])
def model_chat_batch():
    """Send the same question to several models concurrently.

    Expects a JSON object with ``model_ids`` (non-empty list) and
    ``user_question`` (required) plus optional ``system_prompt``,
    ``temperature`` (default 0.7) and ``max_tokens`` (default 2048).

    Returns:
        A JSON response with ``code`` 0 and ``data`` holding one entry per
        model (``model_id``, ``model_name``, ``success``, ``response``,
        ``error``) in completion order, or ``code`` 1 and a ``message`` when
        the request itself is invalid.
    """
    # silent=True yields None instead of raising on a missing or invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_ids = data.get('model_ids', [])
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    # A bare string would otherwise be iterated character by character.
    if not isinstance(model_ids, list) or not model_ids:
        return jsonify({'code': 1, 'message': '缺少模型ID列表'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # The prompt is identical for every model, so build it once. Workers only
    # read this list.
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})

    def call_single_model(model_id):
        """Query one model; always returns a result dict and never raises."""
        outcome = {
            'model_id': model_id,
            'model_name': None,
            'success': False,
            'response': '',
            'error': '',
        }
        try:
            model = generic_get_by_id('model_manage', model_id)
            if not model:
                outcome['error'] = '模型不存在'
                return outcome
            outcome['model_name'] = model.get('name')

            if model.get('model_source') == 'api':
                result = call_api_model(model, messages, temperature, max_tokens)
            else:
                result = call_local_model(model, messages, temperature, max_tokens)

            outcome['success'] = result.get('success', False)
            outcome['response'] = result.get('content', '')
            outcome['error'] = result.get('error', '')
        except Exception as e:
            # One failing model (e.g. a DB error) must not abort the batch.
            outcome['error'] = str(e)
        return outcome

    results = []
    worker_count = min(len(model_ids), 4)
    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
        futures = [executor.submit(call_single_model, mid) for mid in model_ids]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())

    return jsonify({'code': 0, 'data': results})
# Working directory for the inference subprocess; trained outputs are looked
# up under its ``saves`` folder.
_TRAINED_WORK_DIR = '/app/base'
_TRAINED_SAVES_DIR = '/app/base/saves'

# Static runner executed in a child process. Every request-derived value is
# delivered as JSON on stdin; nothing is interpolated into this source text,
# so request data can never become executable code.
_CHAT_RUNNER_SCRIPT = '''\
import json
import logging
import sys

logging.basicConfig(level=logging.WARNING)

from llmtuner import ChatModel


def main():
    job = json.load(sys.stdin)
    chat_model = ChatModel(job["model_args"])
    response = chat_model.chat(job["messages"])
    if job["echo_response"]:
        print(response)
    else:
        print("Model loaded successfully")


if __name__ == "__main__":
    main()
'''


def _trained_model_dir(train_method, model_name):
    """Return the saves path of a trained model, or None if it escapes the saves root."""
    candidate = f"{_TRAINED_SAVES_DIR}/{train_method}/{model_name}"
    root = os.path.normpath(_TRAINED_SAVES_DIR)
    normalized = os.path.normpath(candidate)
    # Reject values such as "../../etc" that would resolve outside the root.
    if not normalized.startswith(root + os.sep):
        return None
    return candidate


def _lookup_base_model_path(model_name, logger, tag):
    """Look up the base-model path for ``model_name`` in the database, or return None."""
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            # First choice: the base model recorded on the fine-tune task.
            logger.info("%s 尝试从fine_tune表查询模型: %s", tag, model_name)
            cursor.execute(
                """
                SELECT base_model, output_model_name FROM fine_tune
                WHERE output_model_name LIKE %s OR output_model_name LIKE %s
                LIMIT 1
                """,
                (f'%/{model_name}', f'%{model_name}%'),
            )
            ft_result = cursor.fetchone()
            logger.info("%s fine_tune查询结果: %s", tag, ft_result)

            base_model_path = None
            if ft_result and ft_result.get('base_model'):
                base_model_val = ft_result['base_model']
                if str(base_model_val).isdigit():
                    # A numeric value is a model_manage id; resolve it to a path.
                    cursor.execute(
                        "SELECT path FROM model_manage WHERE id = %s LIMIT 1",
                        (base_model_val,),
                    )
                    model_result = cursor.fetchone()
                    logger.info("%s model_manage查询结果(数字ID): %s", tag, model_result)
                    if model_result:
                        base_model_path = model_result.get('path')
                else:
                    # Otherwise the stored value is already a path.
                    base_model_path = base_model_val

            # Fallback: look the model up by name in model_manage.
            if not base_model_path:
                logger.info("%s 尝试从model_manage表查询...", tag)
                cursor.execute(
                    "SELECT path FROM model_manage WHERE name = %s LIMIT 1",
                    (model_name,),
                )
                model_result = cursor.fetchone()
                logger.info("%s model_manage查询结果: %s", tag, model_result)
                if model_result:
                    base_model_path = model_result.get('path')
            return base_model_path
        finally:
            cursor.close()
    finally:
        # Release the connection even when a query raises.
        conn.close()


def _resolve_trained_model(model_name, train_method, base_model_path, logger, tag):
    """Return ``((base_model_path, trained_model_path), None)`` or ``(None, error_message)``."""
    trained_model_path = _trained_model_dir(train_method, model_name)
    if trained_model_path is None:
        logger.warning("%s 非法的模型路径参数: %r / %r", tag, train_method, model_name)
        return None, '模型名称或训练方法不合法'

    # Prefer the base-model path supplied by the frontend; otherwise query the DB.
    if base_model_path:
        logger.info("%s 使用前端传递的基座模型路径: %s", tag, base_model_path)
    else:
        try:
            base_model_path = _lookup_base_model_path(model_name, logger, tag)
        except Exception as e:
            logger.error("%s 查询模型配置失败: %s", tag, e)
            return None, f'查询模型配置失败: {str(e)}'
        if not base_model_path:
            logger.error("%s 未找到模型 %s 的基座模型配置", tag, model_name)
            return None, f'未找到模型 {model_name} 的基座模型配置'

    if not os.path.exists(trained_model_path):
        logger.warning("%s 训练模型路径不存在: %s", tag, trained_model_path)
        return None, f'训练模型不存在: {trained_model_path}'

    return (base_model_path, trained_model_path), None


def _run_chat_subprocess(job, logger, tag, timeout=180):
    """Run the static runner script in a child process and return its CompletedProcess."""
    import sys
    import tempfile

    env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
    # Point PYTHONPATH at the first candidate directory containing llmtuner.
    for path in ('/app/base', '/app', '/app/base/src/llamafactory'):
        if os.path.exists(os.path.join(path, 'llmtuner')):
            env['PYTHONPATH'] = path
            logger.info("%s 设置 PYTHONPATH=%s", tag, path)
            break

    script_path = None
    try:
        # A unique file name per request avoids concurrent requests overwriting
        # or deleting each other's script. It stays in the work dir so the
        # child's sys.path[0] is unchanged.
        with tempfile.NamedTemporaryFile(
            'w', encoding='utf-8', prefix='temp_chat_', suffix='.py',
            dir=_TRAINED_WORK_DIR, delete=False,
        ) as script_file:
            script_path = script_file.name
            script_file.write(_CHAT_RUNNER_SCRIPT)

        return subprocess.run(
            [sys.executable, os.path.basename(script_path)],
            # Default json.dumps output is pure ASCII, so the pipe encoding
            # cannot corrupt non-ASCII prompts.
            input=json.dumps(job),
            capture_output=True,
            text=True,
            timeout=timeout,
            cwd=_TRAINED_WORK_DIR,
            env=env,
        )
    finally:
        # Clean up on every exit path, including timeouts.
        if script_path is not None:
            try:
                os.remove(script_path)
            except OSError as e:
                logger.warning("%s 清理临时脚本失败: %s", tag, e)


def _subprocess_error_text(result):
    """Return the most useful failure text from a finished subprocess."""
    error_msg = result.stderr.strip() if result.stderr else (result.stdout or '').strip()
    if not error_msg:
        error_msg = f'命令执行失败,返回码: {result.returncode}'
    return error_msg


@model_chat_bp.route('/trained/preload', methods=['POST'])
def preload_trained_model():
    """Warm up a trained model by running a one-token inference via llmtuner.

    Expects a JSON object with ``model_name`` (required) and optional
    ``train_method`` (default ``'lora'``) and ``base_model_path``. When
    ``base_model_path`` is absent it is looked up in the database.

    Returns:
        A JSON response with ``code`` 0 and ``data`` (``model_name``,
        ``train_method``, ``base_model``) on success, or ``code`` 1 and a
        ``message`` on failure.
    """
    import logging
    logger = logging.getLogger(__name__)
    tag = '[PRELOAD]'

    # silent=True yields None instead of raising on a missing or invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_name = data.get('model_name')
    train_method = data.get('train_method', 'lora')
    base_model_path = data.get('base_model_path')

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})

    logger.info("%s 开始预加载模型: %s, 方法: %s", tag, model_name, train_method)
    logger.info("%s 前端传递的基座模型路径: %s", tag, base_model_path)

    resolved, error = _resolve_trained_model(
        model_name, train_method, base_model_path, logger, tag)
    if error:
        return jsonify({'code': 1, 'message': error})
    base_model_path, trained_model_path = resolved

    job = {
        'model_args': {
            'model_name_or_path': base_model_path,
            'adapter_name_or_path': trained_model_path,
            'template': 'llama3',
            'finetuning_type': 'lora' if train_method == 'lora' else 'full',
            'temperature': 0.1,
            # A single new token is enough to force the model to load.
            'max_new_tokens': 1,
        },
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Hello'},
        ],
        'echo_response': False,
    }

    try:
        logger.info("%s 执行预热脚本...", tag)
        result = _run_chat_subprocess(job, logger, tag)
    except subprocess.TimeoutExpired:
        logger.error("[PRELOAD] 预加载超时")
        return jsonify({'code': 1, 'message': '预加载超时,请稍后重试'})
    except Exception as e:
        logger.error("%s 预加载异常: %s", tag, e)
        return jsonify({'code': 1, 'message': f'预加载异常: {str(e)}'})

    logger.info("%s 命令返回码: %s", tag, result.returncode)
    logger.info("%s stdout: %s", tag, result.stdout[:500] if result.stdout else 'empty')
    logger.info("%s stderr: %s", tag, result.stderr[:500] if result.stderr else 'empty')

    if result.returncode != 0:
        return jsonify({'code': 1, 'message': f'预加载失败: {_subprocess_error_text(result)}'})

    return jsonify({
        'code': 0,
        'message': '模型预加载成功',
        'data': {
            'model_name': model_name,
            'train_method': train_method,
            'base_model': base_model_path,
        },
    })


@model_chat_bp.route('/trained', methods=['POST'])
def chat_trained_model():
    """Run one chat turn against a trained model via llmtuner.

    Expects a JSON object with ``model_name`` and ``user_question`` (required)
    plus optional ``train_method`` (default ``'lora'``), ``base_model_path``,
    ``system_prompt``, ``temperature`` (default 0.7) and ``max_tokens``
    (default 2048). When ``base_model_path`` is absent it is looked up in the
    database.

    Returns:
        A JSON response with ``code`` 0 and ``data`` (``model_name``,
        ``train_method``, ``response``) on success, or ``code`` 1 and a
        ``message`` on failure. ``response`` is the subprocess stdout, i.e.
        whatever ``print(chat_model.chat(...))`` emits.
    """
    import logging
    logger = logging.getLogger(__name__)
    tag = '[CHAT]'

    # silent=True yields None instead of raising on a missing or invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    model_name = data.get('model_name')
    train_method = data.get('train_method', 'lora')
    base_model_path = data.get('base_model_path')
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # Coerce the numeric knobs so only real numbers reach the model arguments.
    try:
        raw_temperature = data.get('temperature')
        raw_max_tokens = data.get('max_tokens')
        temperature = 0.7 if raw_temperature is None else float(raw_temperature)
        max_tokens = 2048 if raw_max_tokens is None else int(raw_max_tokens)
    except (TypeError, ValueError):
        return jsonify({'code': 1, 'message': 'temperature 或 max_tokens 参数格式错误'})

    resolved, error = _resolve_trained_model(
        model_name, train_method, base_model_path, logger, tag)
    if error:
        return jsonify({'code': 1, 'message': error})
    base_model_path, trained_model_path = resolved

    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})

    job = {
        'model_args': {
            'model_name_or_path': base_model_path,
            'adapter_name_or_path': trained_model_path,
            'template': 'llama3',
            # Same mapping as the preload endpoint, so both load the model alike.
            'finetuning_type': 'lora' if train_method == 'lora' else 'full',
            'temperature': temperature,
            'max_new_tokens': max_tokens,
        },
        'messages': messages,
        'echo_response': True,
    }

    try:
        result = _run_chat_subprocess(job, logger, tag)
    except subprocess.TimeoutExpired:
        return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
    except Exception as e:
        return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})

    if result.returncode != 0:
        return jsonify({'code': 1, 'message': f'推理失败: {_subprocess_error_text(result)}'})

    return jsonify({
        'code': 0,
        'data': {
            'model_name': model_name,
            'train_method': train_method,
            'response': result.stdout.strip(),
        },
    })