修改还了返回按钮的功能

This commit is contained in:
2026-01-29 17:39:06 +08:00
parent d0675aede3
commit 0f98d67e41
3 changed files with 254 additions and 22 deletions

View File

@@ -45,6 +45,49 @@ def generic_get_all(table_name, order_by='create_time DESC'):
return result
def get_model_path_by_name(model_name):
"""根据模型名称查询模型路径(用于获取基座模型路径)"""
try:
conn = get_db_connection()
cursor = conn.cursor()
# 优先从训练任务表查询基座模型
cursor.execute("""
SELECT base_model FROM fine_tune
WHERE output_model_name LIKE %s OR output_model_name LIKE %s
LIMIT 1
""", (f'%/{model_name}', f'%{model_name}%'))
ft_result = cursor.fetchone()
if ft_result and ft_result.get('base_model'):
base_model_val = ft_result['base_model']
# 如果是数字ID查询模型管理表获取路径
if str(base_model_val).isdigit():
cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
model_result = cursor.fetchone()
if model_result:
cursor.close()
conn.close()
return model_result.get('path')
else:
# 直接是路径
cursor.close()
conn.close()
return base_model_val
# 如果训练任务表没找到,尝试从模型管理表按名称查询
cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
return result.get('path')
return None
except Exception as e:
logger.error(f"[ERROR] 查询模型路径失败: {e}")
return None
def generic_create(table_name, data):
"""通用创建"""
conn = get_db_connection()
@@ -226,42 +269,108 @@ def get_trained_models():
try:
# 路径结构: /app/base/saves/{train_method}/{model_name}/
# train_method: lora, full, qlora, dpo, cpt 等
# 同时兼容老结构: /app/base/saves/{model_name}/
for train_method in os.listdir(base_path):
train_method_path = os.path.join(base_path, train_method)
if not os.path.isdir(train_method_path):
train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']
for item in os.listdir(base_path):
item_path = os.path.join(base_path, item)
if not os.path.isdir(item_path):
continue
logger.info(f"[DEBUG] 检查训练方法目录: {train_method}")
model_count = 0
# 情况1: 新结构 {train_method}/{model_name}
if item in train_methods:
logger.info(f"[DEBUG] 检查训练方法目录: {item}")
model_count = 0
# 遍历模型文件夹
for model_name in os.listdir(train_method_path):
model_path = os.path.join(train_method_path, model_name)
if not os.path.isdir(model_path):
continue
for model_name in os.listdir(item_path):
model_path = os.path.join(item_path, model_name)
if not os.path.isdir(model_path):
continue
# 检查是否有模型文件
try:
files = os.listdir(model_path)
has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files)
if has_model:
logger.info(f"[DEBUG] 找到模型: {item}/{model_name}")
# 获取文件创建时间
try:
import time
stat = os.stat(model_path)
create_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime))
except:
create_time = None
# 查询基座模型路径
base_model_path = get_model_path_by_name(model_name)
models.append({
'name': model_name,
'path': model_path,
'base_model_path': base_model_path,
'create_time': create_time,
'train_methods': [{
'name': item,
'path': model_path
}]
})
model_count += 1
except Exception as file_err:
logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}")
logger.info(f"[DEBUG] {item} 找到 {model_count} 个模型")
# 情况2: 老结构 {model_name} 直接在 saves 下
else:
logger.info(f"[DEBUG] 检查老结构模型目录: {item}")
try:
files = os.listdir(model_path)
logger.info(f"[DEBUG] {train_method}/{model_name} 文件: {files[:5]}...")
files = os.listdir(item_path)
has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files)
if has_model:
logger.info(f"[DEBUG] 找到模型: {train_method}/{model_name}")
logger.info(f"[DEBUG] 找到模型: {item}")
# 尝试从 adapter_config.json 推断 train_method
inferred_method = 'lora' # 默认
config_file = os.path.join(item_path, 'adapter_config.json')
if os.path.exists(config_file):
try:
import json
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
if 'peft_type' in config:
peft_type = config['peft_type'].lower()
if 'lora' in peft_type:
inferred_method = 'lora'
elif 'full' in peft_type or 'pt' in peft_type:
inferred_method = 'full'
except:
pass
# 获取文件创建时间
try:
import time
stat = os.stat(item_path)
create_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime))
except:
create_time = None
# 查询基座模型路径
base_model_path = get_model_path_by_name(item)
models.append({
'name': model_name,
'path': model_path,
'name': item,
'path': item_path,
'base_model_path': base_model_path,
'create_time': create_time,
'train_methods': [{
'name': train_method,
'path': model_path
'name': inferred_method,
'path': item_path
}]
})
model_count += 1
except Exception as file_err:
logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}")
logger.info(f"[DEBUG] {train_method} 找到 {model_count} 个模型")
logger.error(f"[DEBUG] 读取 {item_path} 失败: {file_err}")
except Exception as list_err:
logger.error(f"[DEBUG] 遍历目录失败: {list_err}")