Files
YG_FT_Platform/src/api/model_manage.py

390 lines
15 KiB
Python
Raw Normal View History

"""
模型管理 API 路由
"""
import json
import logging
import os
import time
from contextlib import closing

import pymysql
import yaml
from flask import Blueprint, request, jsonify
# Resolve the project root: prefer the mounted base directory when running in
# the container layout, otherwise derive it from this file's location.
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
# Three directory levels above this file (e.g. <root>/src/api/model_manage.py).
_computed_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# In the container the computed root is /app or /app/src/llamafactory, so the
# mounted base directory (MOUNT_BASE) is used instead.
PROJECT_ROOT = MOUNT_BASE if _computed_root in {'/app', '/app/src/llamafactory'} else _computed_root
# Blueprint for the model-management API; every route below is mounted
# under the /api/model-manage prefix.
model_manage_bp = Blueprint(
    name='model_manage',
    import_name=__name__,
    url_prefix='/api/model-manage',
)
def get_db_connection():
    """Open a fresh PyMySQL connection described by ``config.yaml``.

    ``<PROJECT_ROOT>/config.yaml`` is re-read on every call. Its ``database``
    section supplies ``host``, ``port``, ``username``, ``password``, ``name``
    and an optional ``charset`` (default ``utf8mb4``).

    Returns:
        A new connection using ``DictCursor``, so rows come back as dicts.
        The caller is responsible for closing it.
    """
    config_file = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_file, 'r', encoding='utf-8') as handle:
        full_config = yaml.safe_load(handle)
    settings = full_config['database']
    # Translate config-file key names to pymysql.connect keyword names.
    connect_kwargs = {
        'host': settings['host'],
        'port': settings['port'],
        'user': settings['username'],
        'password': settings['password'],
        'database': settings['name'],
        'charset': settings.get('charset', 'utf8mb4'),
        'cursorclass': pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**connect_kwargs)
def generic_get_all(table_name, order_by='create_time DESC'):
    """Return every row of *table_name* sorted by *order_by*.

    Args:
        table_name: Table to read. It is interpolated into the SQL text, so it
            must be a trusted internal constant, never request data.
        order_by: Raw ``ORDER BY`` clause; same trust requirement.

    Returns:
        All rows from ``cursor.fetchall()`` (dicts, since get_db_connection
        configures ``DictCursor``).
    """
    # closing() releases the cursor and the connection even when execute()
    # raises; previously both leaked on any query error.
    with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(f"SELECT * FROM {table_name} ORDER BY {order_by}")
        return cursor.fetchall()
def _escape_like(value):
    """Escape the LIKE wildcards ``%``/``_`` and the ``\\`` escape char in *value*."""
    # Backslash first, so the escapes added for % and _ are not doubled.
    return str(value).replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')


def get_model_path_by_name(model_name):
    """Look up the base-model path for a trained model by its name.

    Lookup order:
      1. ``fine_tune``: a task whose ``output_model_name`` contains
         *model_name*, preferring names ending in ``/<model_name>``. Its
         ``base_model`` is returned directly unless it is purely numeric, in
         which case it is treated as a ``model_manage.id`` and resolved to that
         row's ``path``.
      2. ``model_manage``: a row whose ``name`` equals *model_name* (also used
         when a numeric ``base_model`` id has no matching row).

    Args:
        model_name: Trained model directory name.

    Returns:
        The path string, or ``None`` when nothing matches or a database error
        occurs (errors are logged, never raised).
    """
    logger = logging.getLogger(__name__)
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        escaped = _escape_like(model_name)
        contains_pattern = f'%{escaped}%'
        suffix_pattern = f'%/{escaped}'
        # Every "/<name>" suffix match is also a "contains" match, so filter on
        # the broader pattern and sort exact suffix matches first; without an
        # ORDER BY, LIMIT 1 would return an arbitrary matching row.
        cursor.execute("""
            SELECT base_model FROM fine_tune
            WHERE output_model_name LIKE %s
            ORDER BY (output_model_name LIKE %s) DESC
            LIMIT 1
        """, (contains_pattern, suffix_pattern))
        ft_result = cursor.fetchone()
        if ft_result and ft_result.get('base_model'):
            base_model_val = ft_result['base_model']
            if not str(base_model_val).isdigit():
                # Non-numeric value is already a path.
                return base_model_val
            cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,))
            model_result = cursor.fetchone()
            if model_result:
                return model_result.get('path')
            # Dangling id: fall through to the name-based lookup.
        cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,))
        result = cursor.fetchone()
        return result.get('path') if result else None
    except Exception as e:
        # Previously `logger` was undefined here, so this handler itself raised
        # NameError instead of returning None.
        logger.error(f"[ERROR] 查询模型路径失败: {e}")
        return None
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
def generic_create(table_name, data):
    """Insert one row into *table_name* and return its auto-increment id.

    Args:
        table_name: Target table; interpolated into SQL, so trusted constants only.
        data: Mapping of column name -> value. Values are bound as query
            parameters; column names are interpolated, so each must be a plain
            identifier.

    Returns:
        ``cursor.lastrowid`` of the insert.

    Raises:
        ValueError: if any column name is not a plain identifier.
    """
    columns = list(data)
    # Column names cannot be bound as parameters; reject anything that could
    # smuggle SQL into the statement text.
    invalid = [c for c in columns if not (isinstance(c, str) and c.isidentifier())]
    if invalid:
        raise ValueError(f"invalid column name(s): {invalid!r}")
    column_sql = ', '.join(f'`{c}`' for c in columns)
    placeholders = ', '.join(['%s'] * len(columns))
    sql = f"INSERT INTO {table_name} ({column_sql}) VALUES ({placeholders})"
    # closing() releases cursor and connection even if execute/commit raises.
    with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(sql, [data[c] for c in columns])
        conn.commit()
        return cursor.lastrowid
def generic_update(table_name, id_val, data):
    """Update the row of *table_name* whose ``id`` is *id_val* with *data*.

    Args:
        table_name: Target table; interpolated into SQL, so trusted constants only.
        id_val: Primary-key value of the row to update.
        data: Mapping of column name -> new value. Values are bound as
            parameters; column names are interpolated, so each must be a plain
            identifier. An empty/None mapping is a no-op.

    Raises:
        ValueError: if any column name is not a plain identifier.
    """
    if not data:
        # An empty SET clause is a SQL syntax error; nothing to do.
        return
    columns = list(data)
    # Callers may pass client-supplied JSON keys; validate before they reach
    # the SQL text.
    invalid = [c for c in columns if not (isinstance(c, str) and c.isidentifier())]
    if invalid:
        raise ValueError(f"invalid column name(s): {invalid!r}")
    set_clause = ', '.join(f'`{c}` = %s' for c in columns)
    sql = f"UPDATE {table_name} SET {set_clause} WHERE id = %s"
    params = [data[c] for c in columns] + [id_val]
    with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(sql, params)
        conn.commit()
def generic_delete(table_name, id_val):
    """Delete the row of *table_name* whose ``id`` is *id_val* and commit.

    Args:
        table_name: Target table; interpolated into SQL, so trusted constants only.
        id_val: Primary-key value, bound as a query parameter.
    """
    # closing() releases cursor and connection even if execute/commit raises.
    with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(f"DELETE FROM {table_name} WHERE id = %s", (id_val,))
        conn.commit()
def generic_get_by_id(table_name, id_val):
    """Fetch the row of *table_name* whose ``id`` is *id_val*.

    Args:
        table_name: Target table; interpolated into SQL, so trusted constants only.
        id_val: Primary-key value, bound as a query parameter.

    Returns:
        The row from ``cursor.fetchone()``, or ``None`` if no row matches.
    """
    # closing() releases cursor and connection even if execute raises.
    with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,))
        return cursor.fetchone()
# ============ Model management CRUD ============
@model_manage_bp.route('', methods=['GET'])
def get_model_manage():
    """List all model_manage rows using generic_get_all's default ordering (create_time DESC)."""
    rows = generic_get_all('model_manage')
    payload = {'code': 0, 'data': rows}
    return jsonify(payload)
@model_manage_bp.route('/<int:id>', methods=['GET'])
def get_model_manage_by_id(id):
    """Return a single model by id, or ``code 1`` when no such row exists."""
    model = generic_get_by_id('model_manage', id)
    if not model:
        return jsonify({'code': 1, 'message': '模型不存在'})
    return jsonify({'code': 0, 'data': model})
@model_manage_bp.route('', methods=['POST'])
def create_model_manage():
    """Create a model record from the JSON request body.

    ``model_source`` defaults to ``'local'`` and ``purpose`` to
    ``'inference'``. Local models store ``path``; any other source stores
    ``api_url``, ``api_key`` and ``model_name`` instead.

    Returns:
        ``{'code': 0, 'id': <new id>}`` on success, or ``code 1`` when the
        body is not a JSON object.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'code': 1, 'message': '无效的请求数据'})
    # Resolve the default once so the stored model_source and the branch below
    # agree. Previously an omitted model_source was saved as 'local' but took
    # the API branch, leaving the local record without a path.
    model_source = data.get('model_source', 'local')
    insert_data = {
        'name': data.get('name'),
        'type': data.get('type'),
        'model_source': model_source,
        'description': data.get('description'),
        'purpose': data.get('purpose', 'inference'),  # inference by default
    }
    if model_source == 'local':
        insert_data['path'] = data.get('path', '')
    else:
        insert_data['api_url'] = data.get('api_url', '')
        insert_data['api_key'] = data.get('api_key', '')
        insert_data['model_name'] = data.get('model_name', '')
    new_id = generic_create('model_manage', insert_data)
    return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})
@model_manage_bp.route('/<int:id>', methods=['PUT'])
def update_model_manage(id):
    """Update a model's columns from the JSON request body.

    Every key of the body is used as a column name and passed to
    generic_update, which places column names directly in the SQL text.

    Returns:
        ``code 0`` on success; ``code 1`` when the body is not a non-empty
        JSON object or contains a key that is not a plain identifier.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data:
        return jsonify({'code': 1, 'message': '无效的请求数据'})
    # Keys are client-controlled and end up as SQL column names; only plain
    # identifiers are allowed, which blocks SQL injection through JSON keys.
    if not all(key.isidentifier() for key in data):
        return jsonify({'code': 1, 'message': '无效的字段名'})
    generic_update('model_manage', id, data)
    return jsonify({'code': 0, 'message': '更新成功'})
@model_manage_bp.route('/<int:id>/purpose', methods=['PUT'])
def update_model_purpose(id):
    """Set a model's ``purpose`` to ``training``, ``inference`` or ``evaluation``.

    Returns:
        ``code 0`` on success, or ``code 1`` when the body is missing, is not
        a JSON object, or carries an unknown purpose.
    """
    data = request.get_json(silent=True)
    # A missing/non-object body previously raised AttributeError on .get().
    purpose = data.get('purpose') if isinstance(data, dict) else None
    if purpose not in ('training', 'inference', 'evaluation'):
        return jsonify({'code': 1, 'message': '无效的用途类型'})
    generic_update('model_manage', id, {'purpose': purpose})
    return jsonify({'code': 0, 'message': '更新成功'})
@model_manage_bp.route('/<int:id>', methods=['DELETE'])
def delete_model_manage(id):
    """Delete the model with the given id.

    Reports success without checking whether a row was actually removed.
    """
    generic_delete('model_manage', id)
    result = {'code': 0, 'message': '删除成功'}
    return jsonify(result)
# ============ Local model listing endpoint ============
@model_manage_bp.route('/local-models', methods=['GET'])
def get_local_models():
    """List the sub-directories of ``<PROJECT_ROOT>/local_models`` as local models.

    Response ``data`` holds ``models`` (each ``{'name', 'path'}``) and the
    scanned ``base_path``. A missing directory yields an empty list. Any error
    is logged and returned as ``code 1`` with the exception text.
    """
    import logging
    logger = logging.getLogger(__name__)
    try:
        base_path = os.path.join(PROJECT_ROOT, 'local_models')
        entries = os.listdir(base_path) if os.path.exists(base_path) else []
        candidates = ((entry, os.path.join(base_path, entry)) for entry in entries)
        models = [
            {'name': entry, 'path': full_path}
            for entry, full_path in candidates
            if os.path.isdir(full_path)
        ]
        return jsonify({'code': 0, 'data': {'models': models, 'base_path': base_path}})
    except Exception as e:
        logger.error(f"获取本地模型列表失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})
# ============ Trained model listing endpoint ============

# Directory names treated as training methods in the new layout
# saves/{train_method}/{model_name}/. Any other sub-directory of saves/ is
# treated as a legacy saves/{model_name}/ model directory.
_TRAIN_METHODS = ('lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft')


def _find_saves_dir(logger):
    """Return the first existing saves directory among the known candidates, or None."""
    potential_paths = [
        '/app/base/saves',  # container path
        os.path.join(PROJECT_ROOT, 'saves'),  # local development path
        # YG_FT_Base/saves two levels above PROJECT_ROOT
        os.path.join(os.path.dirname(os.path.dirname(PROJECT_ROOT)), 'YG_FT_Base', 'saves'),
    ]
    for path in potential_paths:
        exists = os.path.exists(path)
        logger.info(f"[DEBUG] 检查路径: {path}, exists: {exists}")
        if exists:
            return path
    return None


def _has_model_weights(directory):
    """Return True if *directory* directly lists a ``.bin`` or ``.safetensors`` entry (raises OSError if unreadable)."""
    return any(name.endswith(('.bin', '.safetensors')) for name in os.listdir(directory))


def _format_mtime(path):
    """Return *path*'s modification time as local 'YYYY-mm-dd HH:MM:SS', or None if unavailable."""
    try:
        return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(os.stat(path).st_mtime))
    except (OSError, OverflowError, ValueError):
        return None


def _infer_train_method(model_dir):
    """Guess a legacy model's training method from adapter_config.json; defaults to 'lora'."""
    config_file = os.path.join(model_dir, 'adapter_config.json')
    if not os.path.exists(config_file):
        return 'lora'
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError):
        # Unreadable or malformed config: keep the default.
        return 'lora'
    peft_type = config.get('peft_type') if isinstance(config, dict) else None
    if not isinstance(peft_type, str):
        return 'lora'
    peft_type = peft_type.lower()
    if 'lora' in peft_type:
        return 'lora'
    # NOTE(review): 'pt' also matches PEFT types such as PROMPT_TUNING or
    # ADAPTION_PROMPT, which are adapter methods rather than full fine-tuning.
    # Mapping kept as before; confirm the intended behaviour.
    if 'full' in peft_type or 'pt' in peft_type:
        return 'full'
    return 'lora'


def _build_model_entry(name, path, train_method):
    """Assemble the API dict for one trained model, including its base-model path."""
    return {
        'name': name,
        'path': path,
        'base_model_path': get_model_path_by_name(name),
        # Directory mtime, exposed under the historical key 'create_time'.
        'create_time': _format_mtime(path),
        'train_methods': [{'name': train_method, 'path': path}],
    }


def _scan_method_dir(train_method, method_path, logger):
    """Collect entries for every model with weights under a new-layout saves/{train_method}/ dir."""
    logger.info(f"[DEBUG] 检查训练方法目录: {train_method}")
    found = []
    for model_name in sorted(os.listdir(method_path)):
        model_path = os.path.join(method_path, model_name)
        if not os.path.isdir(model_path):
            continue
        try:
            if _has_model_weights(model_path):
                logger.info(f"[DEBUG] 找到模型: {train_method}/{model_name}")
                found.append(_build_model_entry(model_name, model_path, train_method))
        except Exception as file_err:
            # Best effort: one unreadable model directory must not hide the rest.
            logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}")
    logger.info(f"[DEBUG] {train_method} 找到 {len(found)} 个模型")
    return found


def _scan_legacy_dir(model_name, model_path, logger):
    """Return the entry for a legacy saves/{model_name}/ dir, or None if it has no weights or is unreadable."""
    logger.info(f"[DEBUG] 检查老结构模型目录: {model_name}")
    try:
        if not _has_model_weights(model_path):
            return None
        logger.info(f"[DEBUG] 找到模型: {model_name}")
        return _build_model_entry(model_name, model_path, _infer_train_method(model_path))
    except Exception as file_err:
        logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}")
        return None


@model_manage_bp.route('/trained-models', methods=['GET'])
def get_trained_models():
    """List trained models found under the first existing ``saves`` directory.

    Supports both the layout ``saves/{train_method}/{model_name}/`` and the
    legacy ``saves/{model_name}/``. A directory counts as a model when it
    directly contains a ``.bin`` or ``.safetensors`` file. Entries are
    returned in sorted directory order.

    Returns:
        ``{'code': 0, 'data': {'models': [...], 'base_path': str}}``, where
        ``base_path`` is ``''`` when no saves directory exists; ``code 1``
        with the error text on unexpected failure.
    """
    logger = logging.getLogger(__name__)
    try:
        base_path = _find_saves_dir(logger)
        logger.info(f"[DEBUG] 最终使用的路径: {base_path}")
        models = []
        if base_path:
            logger.info(f"[DEBUG] 遍历目录: {base_path}")
            try:
                for item in sorted(os.listdir(base_path)):
                    item_path = os.path.join(base_path, item)
                    if not os.path.isdir(item_path):
                        continue
                    if item in _TRAIN_METHODS:
                        models.extend(_scan_method_dir(item, item_path, logger))
                    else:
                        entry = _scan_legacy_dir(item, item_path, logger)
                        if entry is not None:
                            models.append(entry)
            except Exception as list_err:
                logger.error(f"[DEBUG] 遍历目录失败: {list_err}")
        logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型")
        return jsonify({
            'code': 0,
            'data': {
                'models': models,
                'base_path': base_path or ''
            }
        })
    except Exception as e:
        logger.error(f"获取已训练模型列表失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})