修改还了返回按钮的功能
This commit is contained in:
@@ -6,6 +6,7 @@ import pymysql
|
||||
import yaml
|
||||
import requests
|
||||
import concurrent.futures
|
||||
import subprocess
|
||||
from flask import Blueprint, request, jsonify
|
||||
|
||||
# 获取项目根目录
|
||||
@@ -208,3 +209,124 @@ def model_chat_batch():
|
||||
results.append(future.result())
|
||||
|
||||
return jsonify({'code': 0, 'data': results})
|
||||
|
||||
|
||||
@model_chat_bp.route('/trained', methods=['POST'])
def chat_trained_model():
    """Run a single chat turn against a trained model via ``llamafactory-cli chat``.

    Request JSON fields:
        model_name: required; looked up by name in ``model_manage`` to get the
            base model path, and used as the directory name under the saves root.
        train_method: sub-directory of the saves root (default ``'lora'``).
        system_prompt: optional system message (default ``''``).
        user_question: required user message.
        temperature: forwarded to the CLI (default ``0.7``).
        max_tokens: forwarded to the CLI (default ``2048``).

    Returns:
        A JSON response: ``{'code': 0, 'data': {'model_name', 'train_method',
        'response'}}`` on success, or ``{'code': 1, 'message': ...}`` on any
        validation, lookup, or inference failure.
    """
    import pymysql
    import yaml

    # A JSON ``null`` body would otherwise crash on ``.get``.
    data = request.json or {}
    model_name = data.get('model_name')  # model name
    train_method = data.get('train_method', 'lora')  # training method: lora, full
    system_prompt = data.get('system_prompt', '')
    user_question = data.get('user_question')
    temperature = data.get('temperature', 0.7)
    max_tokens = data.get('max_tokens', 2048)

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})
    if not user_question:
        return jsonify({'code': 1, 'message': '缺少用户提问'})

    # Locate the project root (three directories above this file) and load config.
    PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        CONFIG = yaml.safe_load(f)

    # Resolve the base model path: query model_manage by model name.
    try:
        db_config = CONFIG['database']
        conn = pymysql.connect(
            host=db_config['host'],
            port=db_config['port'],
            user=db_config['username'],
            password=db_config['password'],
            database=db_config['name'],
            charset=db_config.get('charset', 'utf8mb4'),
            cursorclass=pymysql.cursors.DictCursor
        )
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
            model_result = cursor.fetchone()
        finally:
            # Always release the connection, even when the query raises.
            conn.close()

        if not model_result or not model_result.get('path'):
            return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的配置'})

        base_model_path = model_result['path']
    except Exception as e:
        return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})

    # Path of the trained model on disk.
    saves_root = '/app/base/saves'
    trained_model_path = f"{saves_root}/{train_method}/{model_name}"

    # Both path components come from the request; refuse anything that
    # normalises to a location outside the saves root (e.g. via "..").
    if not os.path.normpath(trained_model_path).startswith(saves_root + '/'):
        return jsonify({'code': 1, 'message': '模型路径不合法'})

    if not os.path.exists(trained_model_path):
        return jsonify({'code': 1, 'message': f'训练模型不存在: {trained_model_path}'})

    llamafactory_dir = '/app/base'

    # Assemble the chat messages.
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': user_question})

    # Serialise the messages to a JSON string (fed to the CLI on stdin).
    messages_json = json.dumps(messages, ensure_ascii=False)

    # Build the llamafactory-cli chat command as an argument list so that no
    # request-supplied value is ever interpreted by a shell.
    # NOTE(review): --finetuning_type is fixed to "lora" and the trained path is
    # always passed as an adapter, regardless of train_method — confirm this is
    # intended for train_method="full".
    cmd = [
        'llamafactory-cli', 'chat',
        '--model_name_or_path', str(base_model_path),
        '--adapter_name_or_path', trained_model_path,
        '--template', 'llama3',
        '--finetuning_type', 'lora',
        '--temperature', str(temperature),
        '--max_tokens', str(max_tokens),
    ]
    env = dict(os.environ, CUDA_VISIBLE_DEVICES='0')

    try:
        # Run the command; the trailing newline mirrors what `echo` emitted.
        result = subprocess.run(
            cmd,
            input=messages_json + '\n',
            capture_output=True,
            encoding='utf-8',
            errors='replace',
            timeout=120,
            cwd=llamafactory_dir,
            env=env
        )

        output = result.stdout
        error = result.stderr

        # Parse the output and extract the assistant reply. The expected
        # marker format is:
        #   <|im_start|>assistant
        #   xxx
        #   <|im_end|>
        start_marker = '<|im_start|>assistant'
        if start_marker in output:
            after_marker = output.split(start_marker)[1]
            assistant_content = after_marker.split('<|im_end|>')[0].strip()
        elif result.returncode == 0:
            # No special markers: fall back to the whole trimmed output.
            assistant_content = output.strip()
        else:
            return jsonify({'code': 1, 'message': f'推理失败: {error or output}'})

        return jsonify({
            'code': 0,
            'data': {
                'model_name': model_name,
                'train_method': train_method,
                'response': assistant_content
            }
        })

    except subprocess.TimeoutExpired:
        return jsonify({'code': 1, 'message': '推理超时,请稍后重试'})
    except Exception as e:
        return jsonify({'code': 1, 'message': f'推理异常: {str(e)}'})
|
||||
|
||||
@@ -45,6 +45,49 @@ def generic_get_all(table_name, order_by='create_time DESC'):
|
||||
return result
|
||||
|
||||
|
||||
def get_model_path_by_name(model_name):
    """Look up the base-model path associated with a model name.

    Lookup order:
      1. ``fine_tune.base_model`` of a row whose ``output_model_name`` matches
         ``model_name``. A purely numeric value is treated as a ``model_manage``
         id and resolved to that row's ``path``; any other value is returned
         as-is (it is taken to be a path already).
      2. ``model_manage.path`` of the row whose ``name`` equals ``model_name``.

    Args:
        model_name: Name of the (trained) model to resolve.

    Returns:
        The resolved path, or ``None`` when nothing matches or any error
        occurs (errors are logged, never raised).
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()

        # Prefer the base model recorded on the fine-tune task table.
        cursor.execute("""
            SELECT base_model FROM fine_tune
            WHERE output_model_name LIKE %s OR output_model_name LIKE %s
            LIMIT 1
        """, (f'%/{model_name}', f'%{model_name}%'))
        ft_result = cursor.fetchone()

        if ft_result and ft_result.get('base_model'):
            base_model_val = ft_result['base_model']
            if not str(base_model_val).isdigit():
                # Not a numeric id: the stored value is the path itself.
                return base_model_val

            # Numeric id: resolve it through the model management table.
            cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
            model_result = cursor.fetchone()
            if model_result:
                return model_result.get('path')

        # Nothing usable on the fine-tune table: fall back to a lookup by name.
        cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
        result = cursor.fetchone()
        return result.get('path') if result else None
    except Exception as e:
        logger.error(f"[ERROR] 查询模型路径失败: {e}")
        return None
    finally:
        # Release DB resources on every exit path, including failed queries.
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as close_err:
                logger.error(f"[ERROR] 关闭数据库连接失败: {close_err}")
|
||||
|
||||
|
||||
def generic_create(table_name, data):
|
||||
"""通用创建"""
|
||||
conn = get_db_connection()
|
||||
@@ -226,42 +269,108 @@ def get_trained_models():
|
||||
try:
|
||||
# 路径结构: /app/base/saves/{train_method}/{model_name}/
|
||||
# train_method: lora, full, qlora, dpo, cpt 等
|
||||
# 同时兼容老结构: /app/base/saves/{model_name}/
|
||||
|
||||
for train_method in os.listdir(base_path):
|
||||
train_method_path = os.path.join(base_path, train_method)
|
||||
if not os.path.isdir(train_method_path):
|
||||
train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']
|
||||
|
||||
for item in os.listdir(base_path):
|
||||
item_path = os.path.join(base_path, item)
|
||||
if not os.path.isdir(item_path):
|
||||
continue
|
||||
|
||||
logger.info(f"[DEBUG] 检查训练方法目录: {train_method}")
|
||||
model_count = 0
|
||||
# 情况1: 新结构 {train_method}/{model_name}
|
||||
if item in train_methods:
|
||||
logger.info(f"[DEBUG] 检查训练方法目录: {item}")
|
||||
model_count = 0
|
||||
|
||||
# 遍历模型文件夹
|
||||
for model_name in os.listdir(train_method_path):
|
||||
model_path = os.path.join(train_method_path, model_name)
|
||||
if not os.path.isdir(model_path):
|
||||
continue
|
||||
for model_name in os.listdir(item_path):
|
||||
model_path = os.path.join(item_path, model_name)
|
||||
if not os.path.isdir(model_path):
|
||||
continue
|
||||
|
||||
# 检查是否有模型文件
|
||||
try:
|
||||
files = os.listdir(model_path)
|
||||
has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files)
|
||||
|
||||
if has_model:
|
||||
logger.info(f"[DEBUG] 找到模型: {item}/{model_name}")
|
||||
# 获取文件创建时间
|
||||
try:
|
||||
import time
|
||||
stat = os.stat(model_path)
|
||||
create_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime))
|
||||
except:
|
||||
create_time = None
|
||||
|
||||
# 查询基座模型路径
|
||||
base_model_path = get_model_path_by_name(model_name)
|
||||
|
||||
models.append({
|
||||
'name': model_name,
|
||||
'path': model_path,
|
||||
'base_model_path': base_model_path,
|
||||
'create_time': create_time,
|
||||
'train_methods': [{
|
||||
'name': item,
|
||||
'path': model_path
|
||||
}]
|
||||
})
|
||||
model_count += 1
|
||||
except Exception as file_err:
|
||||
logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}")
|
||||
|
||||
logger.info(f"[DEBUG] {item} 找到 {model_count} 个模型")
|
||||
|
||||
# 情况2: 老结构 {model_name} 直接在 saves 下
|
||||
else:
|
||||
logger.info(f"[DEBUG] 检查老结构模型目录: {item}")
|
||||
try:
|
||||
files = os.listdir(model_path)
|
||||
logger.info(f"[DEBUG] {train_method}/{model_name} 文件: {files[:5]}...")
|
||||
files = os.listdir(item_path)
|
||||
has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files)
|
||||
|
||||
if has_model:
|
||||
logger.info(f"[DEBUG] 找到模型: {train_method}/{model_name}")
|
||||
logger.info(f"[DEBUG] 找到模型: {item}")
|
||||
|
||||
# 尝试从 adapter_config.json 推断 train_method
|
||||
inferred_method = 'lora' # 默认
|
||||
config_file = os.path.join(item_path, 'adapter_config.json')
|
||||
if os.path.exists(config_file):
|
||||
try:
|
||||
import json
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
if 'peft_type' in config:
|
||||
peft_type = config['peft_type'].lower()
|
||||
if 'lora' in peft_type:
|
||||
inferred_method = 'lora'
|
||||
elif 'full' in peft_type or 'pt' in peft_type:
|
||||
inferred_method = 'full'
|
||||
except:
|
||||
pass
|
||||
|
||||
# 获取文件创建时间
|
||||
try:
|
||||
import time
|
||||
stat = os.stat(item_path)
|
||||
create_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime))
|
||||
except:
|
||||
create_time = None
|
||||
|
||||
# 查询基座模型路径
|
||||
base_model_path = get_model_path_by_name(item)
|
||||
|
||||
models.append({
|
||||
'name': model_name,
|
||||
'path': model_path,
|
||||
'name': item,
|
||||
'path': item_path,
|
||||
'base_model_path': base_model_path,
|
||||
'create_time': create_time,
|
||||
'train_methods': [{
|
||||
'name': train_method,
|
||||
'path': model_path
|
||||
'name': inferred_method,
|
||||
'path': item_path
|
||||
}]
|
||||
})
|
||||
model_count += 1
|
||||
except Exception as file_err:
|
||||
logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}")
|
||||
|
||||
logger.info(f"[DEBUG] {train_method} 找到 {model_count} 个模型")
|
||||
logger.error(f"[DEBUG] 读取 {item_path} 失败: {file_err}")
|
||||
|
||||
except Exception as list_err:
|
||||
logger.error(f"[DEBUG] 遍历目录失败: {list_err}")
|
||||
|
||||
@@ -675,6 +675,7 @@
|
||||
/**
 * Navigate the current window to `backUrl`.
 * `backUrl` is not defined in this function; it is read from the enclosing scope.
 */
function goBack() {
    window.location.assign(backUrl);
}
|
||||
window.goBack = goBack;
|
||||
})();
|
||||
</script>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user