修改还了返回按钮的功能

This commit is contained in:
2026-01-29 17:39:06 +08:00
parent d0675aede3
commit 0f98d67e41
3 changed files with 254 additions and 22 deletions

View File

@@ -6,6 +6,7 @@ import pymysql
import yaml
import requests
import concurrent.futures
import subprocess
from flask import Blueprint, request, jsonify
# 获取项目根目录
@@ -208,3 +209,124 @@ def model_chat_batch():
results.append(future.result())
return jsonify({'code': 0, 'data': results})
@model_chat_bp.route('/trained', methods=['POST'])
def chat_trained_model():
"""使用已训练模型进行对话推理"""
import pymysql
import yaml
data = request.json
model_name = data.get('model_name') # 模型名称
train_method = data.get('train_method', 'lora') # 训练方法: lora, full
system_prompt = data.get('system_prompt', '')
user_question = data.get('user_question')
temperature = data.get('temperature', 0.7)
max_tokens = data.get('max_tokens', 2048)
if not model_name:
return jsonify({'code': 1, 'message': '缺少模型名称'})
if not user_question:
return jsonify({'code': 1, 'message': '缺少用户提问'})
# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
# 获取基座模型路径 - 从数据库查询 model_name 对应的路径
try:
db_config = CONFIG['database']
conn = pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
cursor = conn.cursor()
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
model_result = cursor.fetchone()
conn.close()
if not model_result or not model_result.get('path'):
return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'})
base_model_path = model_result['path']
except Exception as e:
return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
# 训练后的模型路径
trained_model_path = f"/app/base/saves/{train_method}/{model_name}"
if not os.path.exists(trained_model_path):
return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})
# 构建 llamafactory-cli chat 命令
work_dir = '/app/base'
llamafactory_dir = '/app/base'
# 准备消息
messages = []
if system_prompt:
messages.append({'role': 'system', 'content': system_prompt})
messages.append({'role': 'user', 'content': user_question})
# 将消息转换为 JSON 字符串
messages_json = json.dumps(messages, ensure_ascii=False)
# 构建 llamafactory-cli chat 命令
full_cmd = f'cd {llamafactory_dir} && export CUDA_VISIBLE_DEVICES=0 && echo \'{messages_json}\' | llamafactory-cli chat --model_name_or_path {base_model_path} --adapter_name_or_path {trained_model_path} --template llama3 --finetuning_type lora --temperature {temperature} --max_tokens {max_tokens}'
try:
# 执行命令
result = subprocess.run(
full_cmd,
shell=True,
capture_output=True,
text=True,
timeout=120,
cwd=work_dir
)
output = result.stdout
error = result.stderr
# 解析输出提取assistant回复
# llamafactory-cli chat 输出格式通常是:
# <|im_start|>assistant
# xxx
# <|im_end|>
assistant_content = ''
if '<|im_start|>assistant' in output:
parts = output.split('<|im_start|>assistant')
if len(parts) > 1:
content_part = parts[1].split('<|im_end|>')[0].strip()
# 移除可能存在的换行前缀
content_part = content_part.lstrip('\n').strip()
assistant_content = content_part
elif result.returncode == 0:
# 如果没有特殊标记,尝试提取最后一部分作为回复
lines = output.strip().split('\n')
assistant_content = '\n'.join(lines).strip()
else:
return jsonify({'code': 1, 'message': f'推理失败: {error or output}'})
return jsonify({
'code': 0,
'data': {
'model_name': model_name,
'train_method': train_method,
'response': assistant_content
}
})
except subprocess.TimeoutExpired:
return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
except Exception as e:
return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})