"""
Model management API routes.

Provides CRUD endpoints for the ``model_manage`` table and two read-only
endpoints that list model directories found on disk.
"""
import logging
import os
import re
from contextlib import closing

import pymysql
import yaml
from flask import Blueprint, jsonify, request

logger = logging.getLogger(__name__)

# Project root: prefer the mount path from the environment; otherwise derive it
# from this file's location (three directory levels up).
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# When the derived root is /app or /app/src/llamafactory, use the mount path
# instead (config.yaml and the model directories are looked up under it).
if PROJECT_ROOT in ('/app', '/app/src/llamafactory'):
    PROJECT_ROOT = MOUNT_BASE

# Blueprint holding every route in this module.
model_manage_bp = Blueprint('model_manage', __name__, url_prefix='/api/model-manage')

# Container path checked first by /trained-models.
CONTAINER_SAVES_PATH = '/app/base/saves'

# Purposes accepted by PUT /<id>/purpose.
VALID_PURPOSES = ('training', 'inference', 'evaluation')

# Entry-name suffixes that mark a directory as containing model weights.
_WEIGHT_SUFFIXES = ('.bin', '.safetensors')

# Table/column names cannot be bound as query parameters, so every identifier
# interpolated into SQL must fully match this pattern first. This blocks SQL
# injection through JSON keys sent to the update endpoint.
_IDENTIFIER_RE = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')

# Error message returned when the request body is not a JSON object.
_INVALID_BODY_MESSAGE = '请求体必须是JSON对象'


def _quote_identifier(name):
    """Validate a SQL identifier and return it wrapped in MySQL backticks.

    Args:
        name: Table or column name.

    Returns:
        str: The name wrapped in backticks, e.g. ``name`` -> ```name```.

    Raises:
        ValueError: ``name`` is not a plain ``[A-Za-z_][A-Za-z0-9_]*`` identifier.
    """
    # fullmatch (not match + '$') so a trailing newline cannot slip through.
    if not isinstance(name, str) or not _IDENTIFIER_RE.fullmatch(name):
        raise ValueError(f'非法的字段名: {name!r}')
    return f'`{name}`'


def get_db_connection():
    """Open a new MySQL connection from the ``database`` section of config.yaml.

    config.yaml is read from PROJECT_ROOT on every call, so edits take effect
    without restarting the service.

    Returns:
        A pymysql connection whose cursors return rows as dicts (DictCursor).
        The caller is responsible for closing it.

    Raises:
        OSError: config.yaml cannot be opened.
        KeyError: a required key is missing from the ``database`` section.
        pymysql.MySQLError: the connection attempt fails.
    """
    config_path = os.path.join(PROJECT_ROOT, 'config.yaml')
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    db_config = config['database']
    return pymysql.connect(
        host=db_config['host'],
        # pymysql requires an int port; tolerate a quoted value in the YAML.
        port=int(db_config['port']),
        user=db_config['username'],
        password=db_config['password'],
        database=db_config['name'],
        charset=db_config.get('charset', 'utf8mb4'),
        cursorclass=pymysql.cursors.DictCursor,
    )


def generic_get_all(table_name, order_by='create_time DESC'):
    """Return every row of ``table_name`` ordered by ``order_by``.

    Args:
        table_name: Table to read (validated as a plain identifier).
        order_by: Raw ORDER BY clause. It is interpolated verbatim, so it must
            come from code, never from request data.

    Returns:
        list[dict]: All rows.
    """
    sql = f"SELECT * FROM {_quote_identifier(table_name)} ORDER BY {order_by}"
    # closing() guarantees the connection is released even if execute raises.
    with closing(get_db_connection()) as conn, conn.cursor() as cursor:
        cursor.execute(sql)
        return cursor.fetchall()


def generic_create(table_name, data):
    """Insert one row built from ``data`` and return its auto-increment id.

    Args:
        table_name: Target table.
        data: Mapping of column name -> value. Column names are validated as
            plain identifiers; values are bound as query parameters.

    Returns:
        The ``lastrowid`` reported by the cursor.

    Raises:
        ValueError: a table or column name is not a valid identifier.
    """
    columns = ', '.join(_quote_identifier(column) for column in data)
    placeholders = ', '.join(['%s'] * len(data))
    sql = f"INSERT INTO {_quote_identifier(table_name)} ({columns}) VALUES ({placeholders})"
    with closing(get_db_connection()) as conn, conn.cursor() as cursor:
        cursor.execute(sql, list(data.values()))
        conn.commit()
        return cursor.lastrowid


def generic_update(table_name, id_val, data):
    """Update the row with ``id = id_val``, setting each column in ``data``.

    An empty ``data`` is a no-op; an empty SET clause would be invalid SQL.

    Raises:
        ValueError: a table or column name is not a valid identifier.
    """
    if not data:
        return
    set_clause = ', '.join(f"{_quote_identifier(column)} = %s" for column in data)
    sql = f"UPDATE {_quote_identifier(table_name)} SET {set_clause} WHERE id = %s"
    values = list(data.values()) + [id_val]
    with closing(get_db_connection()) as conn, conn.cursor() as cursor:
        cursor.execute(sql, values)
        conn.commit()


def generic_delete(table_name, id_val):
    """Delete the row with ``id = id_val`` from ``table_name``."""
    sql = f"DELETE FROM {_quote_identifier(table_name)} WHERE id = %s"
    with closing(get_db_connection()) as conn, conn.cursor() as cursor:
        cursor.execute(sql, (id_val,))
        conn.commit()


def generic_get_by_id(table_name, id_val):
    """Return the row with ``id = id_val`` as a dict, or None if there is none."""
    sql = f"SELECT * FROM {_quote_identifier(table_name)} WHERE id = %s"
    with closing(get_db_connection()) as conn, conn.cursor() as cursor:
        cursor.execute(sql, (id_val,))
        return cursor.fetchone()


def _get_json_body():
    """Return the request body parsed as a JSON object, or None if it is not one."""
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else None


# ============ Model management CRUD ============

@model_manage_bp.route('', methods=['GET'])
def get_model_manage():
    """Return all models, ordered by create_time descending."""
    # NOTE(review): rows include api_key; confirm exposing it to clients is intended.
    return jsonify({'code': 0, 'data': generic_get_all('model_manage')})


@model_manage_bp.route('/<int:id>', methods=['GET'])
def get_model_manage_by_id(id):  # noqa: A002 - name fixed by the URL rule
    """Return a single model by id, or code 1 if it does not exist."""
    model = generic_get_by_id('model_manage', id)
    if model:
        return jsonify({'code': 0, 'data': model})
    return jsonify({'code': 1, 'message': '模型不存在'})


@model_manage_bp.route('', methods=['POST'])
def create_model_manage():
    """Create a model from the JSON body.

    Local models (``model_source`` == 'local', also the default) store
    ``path``; any other source stores ``api_url``/``api_key``/``model_name``.
    """
    data = _get_json_body()
    if data is None:
        return jsonify({'code': 1, 'message': _INVALID_BODY_MESSAGE})

    # Resolve the default once, so the stored source and the branch below agree.
    model_source = data.get('model_source', 'local')
    insert_data = {
        'name': data.get('name'),
        'type': data.get('type'),
        'model_source': model_source,
        'description': data.get('description'),
        'purpose': data.get('purpose', 'inference'),  # default purpose: inference
    }
    if model_source == 'local':
        insert_data['path'] = data.get('path', '')
    else:
        insert_data['api_url'] = data.get('api_url', '')
        insert_data['api_key'] = data.get('api_key', '')
        insert_data['model_name'] = data.get('model_name', '')

    new_id = generic_create('model_manage', insert_data)
    return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})


@model_manage_bp.route('/<int:id>', methods=['PUT'])
def update_model_manage(id):  # noqa: A002 - name fixed by the URL rule
    """Update the given model with the columns present in the JSON body."""
    data = _get_json_body()
    if data is None:
        return jsonify({'code': 1, 'message': _INVALID_BODY_MESSAGE})
    try:
        generic_update('model_manage', id, data)
    except ValueError as e:
        # A JSON key was not a valid column identifier.
        return jsonify({'code': 1, 'message': str(e)})
    return jsonify({'code': 0, 'message': '更新成功'})


@model_manage_bp.route('/<int:id>/purpose', methods=['PUT'])
def update_model_purpose(id):  # noqa: A002 - name fixed by the URL rule
    """Set a model's purpose; only values in VALID_PURPOSES are accepted."""
    data = _get_json_body()
    if data is None:
        return jsonify({'code': 1, 'message': _INVALID_BODY_MESSAGE})
    purpose = data.get('purpose')
    if purpose not in VALID_PURPOSES:
        return jsonify({'code': 1, 'message': '无效的用途类型'})
    generic_update('model_manage', id, {'purpose': purpose})
    return jsonify({'code': 0, 'message': '更新成功'})


@model_manage_bp.route('/<int:id>', methods=['DELETE'])
def delete_model_manage(id):  # noqa: A002 - name fixed by the URL rule
    """Delete the given model."""
    generic_delete('model_manage', id)
    return jsonify({'code': 0, 'message': '删除成功'})


# ============ Local model listing ============

@model_manage_bp.route('/local-models', methods=['GET'])
def get_local_models():
    """List the sub-directories of <PROJECT_ROOT>/local_models as local models.

    Returns:
        JSON ``{'code': 0, 'data': {'models': [{'name', 'path'}, ...],
        'base_path': ...}}``, with an empty list when the directory is missing.
        Returns ``{'code': 1, 'message': ...}`` if listing fails.
    """
    try:
        base_path = os.path.join(PROJECT_ROOT, 'local_models')
        models = []
        if os.path.isdir(base_path):
            for name in sorted(os.listdir(base_path)):
                model_path = os.path.join(base_path, name)
                if os.path.isdir(model_path):
                    models.append({'name': name, 'path': model_path})
        return jsonify({
            'code': 0,
            'data': {
                'models': models,
                'base_path': base_path,
            },
        })
    except Exception as e:
        logger.exception('获取本地模型列表失败: %s', e)
        return jsonify({'code': 1, 'message': str(e)})


# ============ Trained model listing ============

def _has_weight_files(dir_path):
    """Return True if any entry directly in ``dir_path`` ends with a weight suffix."""
    return any(entry.endswith(_WEIGHT_SUFFIXES) for entry in os.listdir(dir_path))


def _list_train_methods(model_dir):
    """Return ``{'name', 'path'}`` for each sub-directory of ``model_dir`` holding weights."""
    methods = []
    for name in sorted(os.listdir(model_dir)):
        sub_path = os.path.join(model_dir, name)
        if os.path.isdir(sub_path) and _has_weight_files(sub_path):
            methods.append({'name': name, 'path': sub_path})
    return methods


@model_manage_bp.route('/trained-models', methods=['GET'])
def get_trained_models():
    """List trained models under /app/base/saves (or <PROJECT_ROOT>/saves locally).

    Each sub-directory of the saves directory is reported as a model. Its
    ``train_methods`` are the sub-directories that directly contain an entry
    ending in ``.bin`` or ``.safetensors``.

    Returns:
        JSON ``{'code': 0, 'data': {'models': [...], 'base_path': ...}}`` on
        success, or ``{'code': 1, 'message': ...}`` if listing fails.
    """
    try:
        # Prefer the container path; fall back to the local-development path.
        local_saves_path = os.path.join(PROJECT_ROOT, 'saves')
        if os.path.exists(CONTAINER_SAVES_PATH):
            base_path = CONTAINER_SAVES_PATH
        else:
            base_path = local_saves_path
        logger.debug('[DEBUG] 已训练模型目录: %s, exists: %s', base_path, os.path.exists(base_path))

        models = []
        if os.path.isdir(base_path):
            for name in sorted(os.listdir(base_path)):
                model_path = os.path.join(base_path, name)
                if os.path.isdir(model_path):
                    models.append({
                        'name': name,
                        'path': model_path,
                        'train_methods': _list_train_methods(model_path),
                    })

        logger.debug('[DEBUG] 找到 %d 个已训练模型', len(models))
        return jsonify({
            'code': 0,
            'data': {
                'models': models,
                'base_path': base_path,
            },
        })
    except Exception as e:
        logger.exception('获取已训练模型列表失败: %s', e)
        return jsonify({'code': 1, 'message': str(e)})