2026-01-20 16:16:13 +08:00
|
|
|
|
"""
|
|
|
|
|
|
模型管理 API 路由
|
|
|
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
|
|
|
import pymysql
|
|
|
|
|
|
import yaml
|
2026-02-02 09:22:52 +08:00
|
|
|
|
import logging
|
2026-01-20 16:16:13 +08:00
|
|
|
|
from flask import Blueprint, request, jsonify
|
|
|
|
|
|
|
2026-02-02 09:22:52 +08:00
|
|
|
|
# Module-level logger; per the original note it inherits the logging
# configuration set up in main.py.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
2026-01-28 10:31:09 +08:00
|
|
|
|
# Resolve the project root directory.
# MOUNT_BASE is the mounted data directory (overridable via the MOUNT_BASE
# environment variable).  PROJECT_ROOT is first derived from this file's
# location (three directory levels up).
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# When the computed root is /app or /app/src/llamafactory, use the mounted
# path instead.
if PROJECT_ROOT in ('/app', '/app/src/llamafactory'):
    PROJECT_ROOT = MOUNT_BASE
|
2026-01-20 16:16:13 +08:00
|
|
|
|
|
|
|
|
|
|
# Flask blueprint holding every route in this module, mounted under
# /api/model-manage.
model_manage_bp = Blueprint('model_manage', __name__, url_prefix='/api/model-manage')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db_connection():
    """Open a new MySQL connection from the settings in ``config.yaml``.

    The file ``<PROJECT_ROOT>/config.yaml`` is re-read on every call and its
    ``database`` section supplies ``host``, ``port``, ``username``,
    ``password``, ``name`` and (optionally) ``charset``.

    Returns:
        A ``pymysql`` connection whose cursors return rows as dictionaries.
        The caller is responsible for closing it.

    Raises:
        OSError: If the configuration file cannot be opened.
        KeyError: If a required key is missing from the ``database`` section.
    """
    config_path = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    db_config = config['database']
    return pymysql.connect(
        host=db_config['host'],
        port=db_config['port'],
        user=db_config['username'],
        password=db_config['password'],
        database=db_config['name'],
        charset=db_config.get('charset', 'utf8mb4'),
        cursorclass=pymysql.cursors.DictCursor,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_get_all(table_name, order_by='create_time DESC'):
    """Return every row of ``table_name``.

    Args:
        table_name: Table to read.  Interpolated directly into the SQL text,
            so it must be a trusted identifier, never user input.
        order_by: ``ORDER BY`` expression.  Also interpolated directly and
            subject to the same restriction.

    Returns:
        The rows as returned by ``cursor.fetchall()``.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {table_name} ORDER BY {order_by}")
            return cursor.fetchall()
    finally:
        # Close the connection even when the query raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-29 17:39:06 +08:00
|
|
|
|
def get_model_path_by_name(model_name):
    """Look up the base-model path associated with a model name.

    Resolution order:

    1. ``fine_tune`` table: the first row whose ``output_model_name`` matches
       ``model_name``.  Its ``base_model`` value is either an all-digit
       ``model_manage.id`` (resolved to that row's ``path``) or is returned
       unchanged as the path itself.
    2. ``model_manage`` table: the row whose ``name`` equals ``model_name``.

    Args:
        model_name: Model name to search for.

    Returns:
        The resolved path, or ``None`` when nothing matches or when any error
        occurs (errors are logged, never raised).
    """
    logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}")
    try:
        conn = get_db_connection()
        try:
            with conn.cursor() as cursor:
                # Prefer the base model recorded on the training task.
                # NOTE(review): the second LIKE pattern subsumes the first and
                # there is no ORDER BY, so which row wins among several
                # partial matches is up to the database.
                logger.info(f"[DEBUG get_model_path_by_name] 尝试从fine_tune表查询...")
                cursor.execute("""
                    SELECT base_model, output_model_name FROM fine_tune
                    WHERE output_model_name LIKE %s OR output_model_name LIKE %s
                    LIMIT 1
                """, (f'%/{model_name}', f'%{model_name}%'))
                ft_result = cursor.fetchone()
                logger.info(f"[DEBUG get_model_path_by_name] fine_tune查询结果: {ft_result}")

                if ft_result and ft_result.get('base_model'):
                    base_model_val = ft_result['base_model']
                    logger.info(f"[DEBUG get_model_path_by_name] base_model_val: {base_model_val}")
                    if not str(base_model_val).isdigit():
                        # The stored value is already a path.
                        return base_model_val
                    # A numeric value is a model_manage id.
                    cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
                    model_result = cursor.fetchone()
                    logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果(数字ID): {model_result}")
                    if model_result:
                        return model_result.get('path')

                # Nothing usable from the training task: try model_manage by name.
                logger.info(f"[DEBUG get_model_path_by_name] 尝试从model_manage表查询...")
                cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
                result = cursor.fetchone()
                logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果: {result}")
        finally:
            # Single cleanup point: runs on every return path and on errors.
            conn.close()

        if result:
            return result.get('path')
        logger.info(f"[DEBUG get_model_path_by_name] 未找到任何匹配,返回None")
        return None
    except Exception as e:
        logger.error(f"[ERROR] 查询模型路径失败: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-20 16:16:13 +08:00
|
|
|
|
def generic_create(table_name, data):
    """Insert one row into ``table_name`` and return its new id.

    Args:
        table_name: Target table.  Interpolated directly into the SQL text,
            so it must be a trusted identifier.
        data: Mapping of column name to value.  Values are bound as query
            parameters; column names are validated because they have to be
            interpolated into the statement.

    Returns:
        ``cursor.lastrowid`` of the inserted row.

    Raises:
        ValueError: If a column name is not a plain (optionally
            backtick-quoted) identifier.
    """
    import re

    columns = list(data.keys())
    for column in columns:
        # Only letters, digits, "_" and "$" are allowed, which rules out
        # whitespace, quotes, commas and parentheses in the SQL text.
        if not isinstance(column, str) or not re.fullmatch(r'(`?)[A-Za-z0-9_$]+\1', column):
            raise ValueError(f'invalid column name: {column!r}')

    placeholders = ', '.join(['%s'] * len(columns))
    sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"

    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, list(data.values()))
            conn.commit()
            return cursor.lastrowid
    finally:
        # Close the connection even when the insert raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_update(table_name, id_val, data):
    """Update the row of ``table_name`` whose ``id`` equals ``id_val``.

    Args:
        table_name: Target table.  Interpolated directly into the SQL text,
            so it must be a trusted identifier.
        id_val: Value matched against the ``id`` column (bound parameter).
        data: Mapping of column name to new value.  Values are bound as
            query parameters; column names are validated because they have
            to be interpolated into the statement.

    Raises:
        ValueError: If ``data`` is empty (the statement would be invalid) or
            a column name is not a plain (optionally backtick-quoted)
            identifier.
    """
    import re

    if not data:
        raise ValueError('no columns to update')

    columns = list(data.keys())
    for column in columns:
        # Column names can come from a request body; reject anything that
        # could change the shape of the statement.
        if not isinstance(column, str) or not re.fullmatch(r'(`?)[A-Za-z0-9_$]+\1', column):
            raise ValueError(f'invalid column name: {column!r}')

    set_clause = ', '.join(f"{column} = %s" for column in columns)
    sql = f"UPDATE {table_name} SET {set_clause} WHERE id = %s"
    values = list(data.values()) + [id_val]

    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, values)
        conn.commit()
    finally:
        # Close the connection even when the update raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_delete(table_name, id_val):
    """Delete the row of ``table_name`` whose ``id`` equals ``id_val``.

    Args:
        table_name: Target table.  Interpolated directly into the SQL text,
            so it must be a trusted identifier.
        id_val: Value matched against the ``id`` column (bound parameter).
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"DELETE FROM {table_name} WHERE id = %s", (id_val,))
        conn.commit()
    finally:
        # Close the connection even when the delete raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generic_get_by_id(table_name, id_val):
    """Fetch the row of ``table_name`` whose ``id`` equals ``id_val``.

    Args:
        table_name: Table to read.  Interpolated directly into the SQL text,
            so it must be a trusted identifier.
        id_val: Value matched against the ``id`` column (bound parameter).

    Returns:
        The row as returned by ``cursor.fetchone()`` (``None`` when there is
        no match).
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
            return cursor.fetchone()
    finally:
        # Close the connection even when the query raises.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ 模型管理 CRUD ============
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('', methods=['GET'])
def get_model_manage():
    """Return every row of ``model_manage`` as ``{'code': 0, 'data': [...]}``."""
    rows = generic_get_all('model_manage')
    return jsonify({'code': 0, 'data': rows})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('/<int:id>', methods=['GET'])
def get_model_manage_by_id(id):
    """Return a single ``model_manage`` row.

    Args:
        id: Primary key taken from the URL.  The name is fixed by the route
            pattern, hence the shadowing of the builtin.

    Returns:
        ``{'code': 0, 'data': row}`` when found, otherwise ``{'code': 1}``
        with a "model does not exist" message.
    """
    model = generic_get_by_id('model_manage', id)
    if not model:
        return jsonify({'code': 1, 'message': '模型不存在'})
    return jsonify({'code': 0, 'data': model})
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-02 09:22:52 +08:00
|
|
|
|
@model_manage_bp.route('/name/<model_name>', methods=['GET'])
def get_model_manage_by_name(model_name):
    """Return the ``model_manage`` row whose ``name`` equals ``model_name``.

    Returns:
        ``{'code': 0, 'data': row}`` when found, otherwise ``{'code': 1}``
        with a "model does not exist" message.
    """
    logger.info(f"[DEBUG] 按名称查询模型: {model_name}")

    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT * FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
            model = cursor.fetchone()
    finally:
        # Close the connection even when the query raises.
        conn.close()

    if not model:
        return jsonify({'code': 1, 'message': '模型不存在'})
    return jsonify({'code': 0, 'data': model})
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-20 16:16:13 +08:00
|
|
|
|
@model_manage_bp.route('', methods=['POST'])
def create_model_manage():
    """Create a ``model_manage`` row from the JSON request body.

    Body fields read: ``name``, ``type``, ``model_source`` (default
    ``'local'``), ``description``, ``purpose`` (default ``'inference'``),
    plus ``path`` for local models or ``api_url`` / ``api_key`` /
    ``model_name`` for every other source.

    Returns:
        ``{'code': 0, 'message': ..., 'id': new_id}`` on success, or
        ``{'code': 1}`` when the body is not a JSON object.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'code': 1, 'message': '请求体必须是 JSON 对象'})

    # Resolve the source once so the stored value and the branch below agree;
    # previously an omitted model_source was stored as 'local' but treated
    # as a remote model.
    model_source = data.get('model_source', 'local')

    insert_data = {
        'name': data.get('name'),
        'type': data.get('type'),
        'model_source': model_source,
        'description': data.get('description'),
        'purpose': data.get('purpose', 'inference'),  # defaults to inference
    }

    if model_source == 'local':
        insert_data['path'] = data.get('path', '')
    else:
        insert_data['api_url'] = data.get('api_url', '')
        insert_data['api_key'] = data.get('api_key', '')
        insert_data['model_name'] = data.get('model_name', '')

    new_id = generic_create('model_manage', insert_data)
    return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('/<int:id>', methods=['PUT'])
def update_model_manage(id):
    """Update a ``model_manage`` row with the fields in the JSON body.

    Every key of the body is written to the column of the same name.

    Args:
        id: Primary key taken from the URL (name fixed by the route pattern).

    Returns:
        ``{'code': 0}`` on success, or ``{'code': 1}`` when the body is
        missing, not a JSON object, or empty.
    """
    data = request.get_json(silent=True)
    # An empty or non-object body would otherwise crash or build an invalid
    # UPDATE statement.
    if not isinstance(data, dict) or not data:
        return jsonify({'code': 1, 'message': '没有需要更新的字段'})

    generic_update('model_manage', id, data)
    return jsonify({'code': 0, 'message': '更新成功'})
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-22 16:46:12 +08:00
|
|
|
|
@model_manage_bp.route('/<int:id>/purpose', methods=['PUT'])
def update_model_purpose(id):
    """Set the ``purpose`` column of a ``model_manage`` row.

    Args:
        id: Primary key taken from the URL (name fixed by the route pattern).

    Returns:
        ``{'code': 0}`` on success, or ``{'code': 1}`` when ``purpose`` is
        not one of ``training``, ``inference`` or ``evaluation``.
    """
    data = request.get_json(silent=True)
    # A missing or non-object body is treated like an invalid purpose
    # instead of raising AttributeError.
    purpose = data.get('purpose') if isinstance(data, dict) else None
    if purpose not in ['training', 'inference', 'evaluation']:
        return jsonify({'code': 1, 'message': '无效的用途类型'})

    generic_update('model_manage', id, {'purpose': purpose})
    return jsonify({'code': 0, 'message': '更新成功'})
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-20 16:16:13 +08:00
|
|
|
|
@model_manage_bp.route('/<int:id>', methods=['DELETE'])
def delete_model_manage(id):
    """Delete a ``model_manage`` row.

    Args:
        id: Primary key taken from the URL (name fixed by the route pattern).

    Returns:
        ``{'code': 0}`` with a success message.
    """
    generic_delete('model_manage', id)
    return jsonify({'code': 0, 'message': '删除成功'})
|
2026-01-26 17:23:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ 本地模型列表接口 ============
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('/local-models', methods=['GET'])
def get_local_models():
    """List the sub-directories of ``<PROJECT_ROOT>/local_models``.

    Returns:
        ``{'code': 0, 'data': {'models': [{'name', 'path'}, ...],
        'base_path': ...}}``; ``models`` is empty when the directory does not
        exist.  Any error yields ``{'code': 1, 'message': str(error)}``.
    """
    try:
        base_path = os.path.join(PROJECT_ROOT, 'local_models')

        models = []
        if os.path.exists(base_path):
            # Sorted so the listing order is deterministic across calls.
            for item in sorted(os.listdir(base_path)):
                item_path = os.path.join(base_path, item)
                if os.path.isdir(item_path):
                    models.append({'name': item, 'path': item_path})

        return jsonify({
            'code': 0,
            'data': {
                'models': models,
                'base_path': base_path,
            },
        })
    except Exception as e:
        logger.error(f"获取本地模型列表失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})
|
2026-01-29 10:36:59 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ 已训练模型列表接口 ============
|
|
|
|
|
|
|
|
|
|
|
|
def _trained_dir_has_weights(dir_path):
    """Return True if ``dir_path`` directly contains a .bin or .safetensors file."""
    return any(name.endswith(('.bin', '.safetensors')) for name in os.listdir(dir_path))


def _trained_dir_mtime(dir_path):
    """Return the directory's modification time as text, or None on failure."""
    import time

    try:
        return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(os.stat(dir_path).st_mtime))
    except (OSError, ValueError, OverflowError):
        return None


def _infer_trained_method(dir_path):
    """Guess the train method of a legacy-layout model from adapter_config.json."""
    import json

    inferred_method = 'lora'  # default when nothing better is known
    config_file = os.path.join(dir_path, 'adapter_config.json')
    if not os.path.exists(config_file):
        return inferred_method
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
        if 'peft_type' in config:
            peft_type = config['peft_type'].lower()
            if 'lora' in peft_type:
                inferred_method = 'lora'
            # NOTE(review): 'pt' is a substring of several PEFT type names
            # (e.g. prompt tuning); confirm that mapping them to 'full' is
            # intended.
            elif 'full' in peft_type or 'pt' in peft_type:
                inferred_method = 'full'
    except Exception as config_err:
        # Best effort: an unreadable or malformed config keeps the default.
        logger.warning(f"[DEBUG] 解析 {config_file} 失败: {config_err}")
    return inferred_method


def _trained_model_entry(name, path, train_method):
    """Build the response entry for one trained-model directory."""
    create_time = _trained_dir_mtime(path)
    # NOTE: each lookup opens its own database connection.
    base_model_path = get_model_path_by_name(name)
    return {
        'name': name,
        'path': path,
        'base_model_path': base_model_path,
        'create_time': create_time,
        'train_methods': [{'name': train_method, 'path': path}],
    }


@model_manage_bp.route('/trained-models', methods=['GET'])
def get_trained_models():
    """List trained models found under the first existing ``saves`` directory.

    Two layouts are recognised:

    * new: ``saves/{train_method}/{model_name}/``
    * legacy: ``saves/{model_name}/``

    A directory counts as a model when it directly contains a ``.bin`` or
    ``.safetensors`` file.  Each entry also reports ``merged`` (a directory
    of the same name exists under ``<PROJECT_ROOT>/local_trained_models``)
    and ``merging`` (a ``.merging_<name>.lock`` file exists there).

    Returns:
        ``{'code': 0, 'data': {'models': [...], 'base_path': ...}}``, or
        ``{'code': 1, 'message': str(error)}`` on unexpected failure.
    """
    try:
        # Candidate locations, checked in order.
        potential_paths = [
            '/app/base/saves',  # path inside the container
            os.path.join(PROJECT_ROOT, 'saves'),  # local development path
            os.path.join(os.path.dirname(os.path.dirname(PROJECT_ROOT)), 'YG_FT_Base', 'saves'),  # parent directory
        ]

        base_path = None
        for path in potential_paths:
            path_exists = os.path.exists(path)
            logger.info(f"[DEBUG] 检查路径: {path}, exists: {path_exists}")
            if path_exists:
                base_path = path
                break

        logger.info(f"[DEBUG] 最终使用的路径: {base_path}")

        models = []
        if base_path and os.path.exists(base_path):
            logger.info(f"[DEBUG] 遍历目录: {base_path}")
            try:
                train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']

                # Sorted so the listing order is deterministic across calls.
                for item in sorted(os.listdir(base_path)):
                    item_path = os.path.join(base_path, item)
                    if not os.path.isdir(item_path):
                        continue

                    if item in train_methods:
                        # Case 1: new layout {train_method}/{model_name}
                        logger.info(f"[DEBUG] 检查训练方法目录: {item}")
                        model_count = 0

                        for model_name in sorted(os.listdir(item_path)):
                            model_path = os.path.join(item_path, model_name)
                            if not os.path.isdir(model_path):
                                continue
                            try:
                                if _trained_dir_has_weights(model_path):
                                    logger.info(f"[DEBUG] 找到模型: {item}/{model_name}")
                                    models.append(_trained_model_entry(model_name, model_path, item))
                                    model_count += 1
                            except Exception as file_err:
                                logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}")

                        logger.info(f"[DEBUG] {item} 找到 {model_count} 个模型")
                    else:
                        # Case 2: legacy layout, {model_name} directly under saves
                        logger.info(f"[DEBUG] 检查老结构模型目录: {item}")
                        try:
                            if _trained_dir_has_weights(item_path):
                                logger.info(f"[DEBUG] 找到模型: {item}")
                                inferred_method = _infer_trained_method(item_path)
                                models.append(_trained_model_entry(item, item_path, inferred_method))
                        except Exception as file_err:
                            logger.error(f"[DEBUG] 读取 {item_path} 失败: {file_err}")
            except Exception as list_err:
                logger.error(f"[DEBUG] 遍历目录失败: {list_err}")

        logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型")

        # Flag models that are already merged or currently being merged.
        local_trained_path = os.path.join(PROJECT_ROOT, 'local_trained_models')
        for model in models:
            model_name = model['name']
            merged_path = os.path.join(local_trained_path, model_name)
            lock_file = os.path.join(local_trained_path, f'.merging_{model_name}.lock')
            model['merged'] = os.path.exists(merged_path)
            model['merging'] = os.path.exists(lock_file)
            logger.info(f"[DEBUG] 模型 {model_name} 已合并: {model['merged']}, 正在合并: {model['merging']}")

        return jsonify({
            'code': 0,
            'data': {
                'models': models,
                'base_path': base_path or '',
            },
        })
    except Exception as e:
        logger.error(f"获取已训练模型列表失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})
|
2026-01-29 23:10:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ 合并权重接口 ============
|
|
|
|
|
|
|
|
|
|
|
|
def _find_base_model_path_for_merge(model_name):
    """Resolve the base-model path for ``model_name``; raises on DB errors."""
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            # Prefer the base model recorded on the training task.
            cursor.execute("""
                SELECT base_model FROM fine_tune
                WHERE output_model_name LIKE %s OR output_model_name LIKE %s
                LIMIT 1
            """, (f'%/{model_name}', f'%{model_name}%'))
            ft_result = cursor.fetchone()

            base_model_path = None
            if ft_result and ft_result.get('base_model'):
                base_model_val = ft_result['base_model']
                if str(base_model_val).isdigit():
                    # A numeric value is a model_manage id.
                    cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
                    model_result = cursor.fetchone()
                    if model_result:
                        base_model_path = model_result.get('path')
                else:
                    base_model_path = base_model_val

            # Fall back to a model_manage row with the same name.
            if not base_model_path:
                cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
                model_result = cursor.fetchone()
                if model_result:
                    base_model_path = model_result.get('path')
            return base_model_path
    finally:
        conn.close()


@model_manage_bp.route('/merge', methods=['POST'])
def merge_model():
    """Merge a trained adapter into its base model via ``llamafactory-cli export``.

    JSON body fields: ``model_name`` (required), ``train_method`` (default
    ``'lora'``) and ``base_model_path`` (optional; looked up in the database
    when omitted).

    The adapter is read from ``/app/base/saves/{train_method}/{model_name}``
    and the merged model is written to
    ``/app/base/local_trained_models/{model_name}``.  While the export runs,
    a ``.merging_{model_name}.lock`` file exists next to the output
    directory; it is removed on every exit path.

    NOTE(review): these paths are hard-coded to ``/app/base`` while the
    listing and delete endpoints use ``PROJECT_ROOT``; confirm they always
    coincide in deployment.

    Returns:
        ``{'code': 0, ...}`` with the output path on success, otherwise
        ``{'code': 1, 'message': ...}``.
    """
    import shutil
    import subprocess
    import time

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    model_name = data.get('model_name')  # model name
    train_method = data.get('train_method', 'lora')  # train method
    base_model_path = data.get('base_model_path')  # base model path

    if not model_name:
        return jsonify({'code': 1, 'message': '缺少模型名称'})

    # Both values become path components; reject anything that could escape
    # the intended directories.
    for component in (model_name, train_method):
        if (not isinstance(component, str)
                or component in ('', '.', '..')
                or os.path.basename(component) != component):
            return jsonify({'code': 1, 'message': '模型名称或训练方法不合法'})

    logger.info(f"[MERGE] 开始合并模型: {model_name}, 方法: {train_method}")

    # Look the base model up in the database when the caller gave none.
    if not base_model_path:
        try:
            base_model_path = _find_base_model_path_for_merge(model_name)
        except Exception as e:
            logger.error(f"[MERGE] 查询模型配置失败: {e}")
            return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'})
        if not base_model_path:
            return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'})

    # Trained model (adapter) location.
    adapter_path = f"/app/base/saves/{train_method}/{model_name}"
    if not os.path.exists(adapter_path):
        return jsonify({'code': 1, 'message': f'训练模型不存在: {adapter_path}'})

    output_path = f"/app/base/local_trained_models/{model_name}"
    lock_file = f"/app/base/local_trained_models/.merging_{model_name}.lock"

    # Remember whether an earlier merge already produced output, so a failed
    # run only cleans up what it created itself.
    output_preexisting = os.path.exists(output_path)
    merged_ok = False

    try:
        # Create only the parent directory.  The export creates the model
        # directory itself, which keeps "directory exists" a meaningful
        # success signal for this endpoint and for the listing endpoint.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # The lock file marks the merge as in progress.
        with open(lock_file, 'w') as f:
            f.write('merging')

        work_dir = '/app/base'
        env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}

        # llamafactory-cli is assumed to be on PATH, as for training.
        cli_cmd = ['llamafactory-cli', 'export']

        # Informational probe only; the export is attempted regardless.
        try:
            subprocess.run(['which', 'llamafactory-cli'], capture_output=True, check=True)
        except (subprocess.CalledProcessError, FileNotFoundError):
            logger.info("[MERGE] which 命令不可用,直接尝试执行 llamafactory-cli")

        export_args = [
            '--model_name_or_path', base_model_path,
            '--adapter_name_or_path', adapter_path,
            '--export_dir', output_path,
        ]

        logger.info(f"[MERGE] 执行合并命令: {' '.join(cli_cmd)} {' '.join(export_args)}")

        result = subprocess.run(
            cli_cmd + export_args,
            capture_output=True,
            text=True,
            timeout=600,
            cwd=work_dir or '/app/base',
            env=env,
        )

        logger.info(f"[MERGE] 命令返回码: {result.returncode}")
        logger.info(f"[MERGE] stdout: {result.stdout[:500] if result.stdout else 'empty'}")
        logger.info(f"[MERGE] stderr: {result.stderr[:500] if result.stderr else 'empty'}")

        if result.returncode != 0:
            error_msg = result.stderr.strip() if result.stderr else result.stdout.strip()
            if not error_msg:
                error_msg = f'命令执行失败,返回码: {result.returncode}'
            return jsonify({'code': 1, 'message': f'合并失败: {error_msg}'})

        # Give the output directory up to 5 seconds to become visible.
        max_wait = 5
        waited = 0
        while not os.path.exists(output_path) and waited < max_wait:
            time.sleep(0.5)
            waited += 0.5

        if not os.path.exists(output_path):
            return jsonify({'code': 1, 'message': '合并失败:输出目录未创建'})

        merged_ok = True
        return jsonify({
            'code': 0,
            'message': f'模型权重已成功合并到 {output_path}',
            'data': {
                'model_name': model_name,
                'output_path': output_path,
            },
        })
    except subprocess.TimeoutExpired:
        logger.error("[MERGE] 合并超时")
        return jsonify({'code': 1, 'message': '合并超时,请稍后重试'})
    except Exception as e:
        logger.error(f"[MERGE] 合并异常: {str(e)}")
        return jsonify({'code': 1, 'message': f'合并异常: {str(e)}'})
    finally:
        # Always drop the lock, otherwise the model is reported as "merging"
        # forever after an unexpected error.
        try:
            if os.path.exists(lock_file):
                os.remove(lock_file)
        except OSError as cleanup_err:
            logger.warning(f"[MERGE] 删除锁文件失败: {cleanup_err}")
        # Remove partial output from a failed run so it is not reported as
        # a merged model; output from an earlier merge is left untouched.
        if not merged_ok and not output_preexisting:
            shutil.rmtree(output_path, ignore_errors=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ 删除已训练模型接口 ============
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('/trained-models/<model_name>', methods=['DELETE'])
def delete_trained_model(model_name):
    """Delete a trained model from disk.

    The ``type`` query parameter selects what is removed:

    * ``merged`` (default): the merged model directory
      ``<PROJECT_ROOT>/local_trained_models/{model_name}``.
    * ``lora``: the weight directories ``<PROJECT_ROOT>/saves/{method}/
      {model_name}`` for every known train method, falling back to the
      legacy ``<PROJECT_ROOT>/saves/{model_name}`` when none exists.

    Returns:
        ``{'code': 0, 'message': ...}`` on success, otherwise
        ``{'code': 1, 'message': ...}``.
    """
    import shutil

    # model_name becomes a path component passed to rmtree; a value such as
    # ".." would delete a parent directory.
    if model_name in ('.', '..') or os.path.basename(model_name) != model_name:
        return jsonify({'code': 1, 'message': f'非法的模型名称: {model_name}'})

    delete_type = request.args.get('type', 'merged')  # merged model by default

    try:
        if delete_type == 'lora':
            # Remove the weights stored under the saves directory.
            saves_path = os.path.join(PROJECT_ROOT, 'saves')
            train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']

            deleted = False
            for method in train_methods:
                weight_path = os.path.join(saves_path, method, model_name)
                if os.path.exists(weight_path):
                    shutil.rmtree(weight_path)
                    logger.info(f"[DELETE] 已删除权重: {weight_path}")
                    deleted = True

            # Legacy layout: model directory directly under saves.  A name
            # equal to a train method is skipped, because that directory
            # holds every model trained with the method.
            if not deleted and model_name not in train_methods:
                old_path = os.path.join(saves_path, model_name)
                if os.path.exists(old_path):
                    shutil.rmtree(old_path)
                    logger.info(f"[DELETE] 已删除老结构权重: {old_path}")
                    deleted = True

            if deleted:
                return jsonify({'code': 0, 'message': '权重已删除'})
            return jsonify({'code': 1, 'message': f'权重不存在: {model_name}'})

        # Default: remove the merged model directory.
        model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)
        if not os.path.exists(model_path):
            return jsonify({'code': 1, 'message': f'合并模型不存在: {model_name}'})

        shutil.rmtree(model_path)
        logger.info(f"[DELETE] 已删除合并模型: {model_path}")
        return jsonify({'code': 0, 'message': '合并模型已删除'})
    except Exception as e:
        logger.error(f"[DELETE] 删除失败: {str(e)}")
        return jsonify({'code': 1, 'message': f'删除失败: {str(e)}'})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============ 导出已训练模型接口 ============
|
|
|
|
|
|
|
|
|
|
|
|
@model_manage_bp.route('/trained-models/<model_name>/export', methods=['GET'])
def export_trained_model(model_name):
    """Zip a trained model and send it as a download.

    Lookup order: the merged model under
    ``<PROJECT_ROOT>/local_trained_models``, then
    ``<PROJECT_ROOT>/saves/{method}/{model_name}`` for each known train
    method, then the legacy ``<PROJECT_ROOT>/saves/{model_name}``.

    The archive is written to ``<PROJECT_ROOT>/temp_exports/{model_name}.zip``
    and removed once the response is closed.

    Returns:
        The zip file response, or ``{'code': 1, 'message': ...}`` when the
        model is not found or the export fails.
    """
    import shutil
    from flask import send_file

    # model_name becomes a path component; refuse values that could escape
    # the model directories.
    if model_name in ('.', '..') or os.path.basename(model_name) != model_name:
        return jsonify({'code': 1, 'message': f'非法的模型名称: {model_name}'})

    def remove_zip(path):
        """Best-effort removal of the temporary archive."""
        try:
            if os.path.exists(path):
                os.remove(path)
                logger.info(f"[EXPORT] 已清理临时文件: {path}")
        except OSError:
            pass

    try:
        # Prefer the merged model.
        model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name)

        # Otherwise look for unmerged weights under saves.
        if not os.path.exists(model_path):
            saves_path = os.path.join(PROJECT_ROOT, 'saves')
            train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft']

            for method in train_methods:
                potential_path = os.path.join(saves_path, method, model_name)
                if os.path.exists(potential_path):
                    model_path = potential_path
                    logger.info(f"[EXPORT] 从 saves/{method} 目录找到模型: {model_path}")
                    break
            else:
                # Legacy layout, as handled by the listing and delete
                # endpoints.  A name equal to a train method is skipped
                # because that directory holds many models.
                legacy_path = os.path.join(saves_path, model_name)
                if model_name not in train_methods and os.path.isdir(legacy_path):
                    model_path = legacy_path
                    logger.info(f"[EXPORT] 从 saves 目录找到老结构模型: {model_path}")

        if not os.path.exists(model_path):
            return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'})

        # Build the temporary zip file.
        zip_dir = os.path.join(PROJECT_ROOT, 'temp_exports')
        os.makedirs(zip_dir, exist_ok=True)
        zip_file = os.path.join(zip_dir, f'{model_name}.zip')

        # Drop a stale archive from an earlier export.
        if os.path.exists(zip_file):
            os.remove(zip_file)

        # make_archive appends ".zip" itself, so pass the name without it.
        shutil.make_archive(zip_file[:-4], 'zip', model_path)
        logger.info(f"[EXPORT] 已打包模型: {zip_file}")

        try:
            response = send_file(
                zip_file,
                as_attachment=True,
                download_name=f'{model_name}.zip',
                mimetype='application/zip',
            )
        except Exception:
            # Do not leave the archive behind when the response cannot be built.
            remove_zip(zip_file)
            raise

        # Delete the archive once the response has been fully sent.
        @response.call_on_close
        def cleanup_after_request():
            remove_zip(zip_file)

        return response
    except Exception as e:
        logger.error(f"[EXPORT] 导出模型失败: {str(e)}")
        return jsonify({'code': 1, 'message': f'导出失败: {str(e)}'})
|