""" 模型管理 API 路由 """ import os import pymysql import yaml import logging from flask import Blueprint, request, jsonify # 获取模块 logger(继承 main.py 的日志配置) logger = logging.getLogger(__name__) # 获取项目根目录 - 优先使用环境变量,否则从文件路径计算 MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base') PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 如果 PROJECT_ROOT 是 /app 或 /app/src/llamafactory,则使用挂载路径 if PROJECT_ROOT in ('/app', '/app/src/llamafactory'): PROJECT_ROOT = MOUNT_BASE # 创建蓝图 model_manage_bp = Blueprint('model_manage', __name__, url_prefix='/api/model-manage') def get_db_connection(): """获取数据库连接""" CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') with open(CONFIG_PATH, 'r', encoding='utf-8') as f: CONFIG = yaml.safe_load(f) db_config = CONFIG['database'] return pymysql.connect( host=db_config['host'], port=db_config['port'], user=db_config['username'], password=db_config['password'], database=db_config['name'], charset=db_config.get('charset', 'utf8mb4'), cursorclass=pymysql.cursors.DictCursor ) def generic_get_all(table_name, order_by='create_time DESC'): """通用查询所有""" conn = get_db_connection() cursor = conn.cursor() cursor.execute(f"SELECT * FROM {table_name} ORDER BY {order_by}") result = cursor.fetchall() cursor.close() conn.close() return result def get_model_path_by_name(model_name): """根据模型名称查询模型路径(用于获取基座模型路径)""" logger.info(f"[DEBUG get_model_path_by_name] 查询模型: {model_name}") try: conn = get_db_connection() cursor = conn.cursor() # 优先从训练任务表查询基座模型 logger.info(f"[DEBUG get_model_path_by_name] 尝试从fine_tune表查询...") cursor.execute(""" SELECT base_model, output_model_name FROM fine_tune WHERE output_model_name LIKE %s OR output_model_name LIKE %s LIMIT 1 """, (f'%/{model_name}', f'%{model_name}%')) ft_result = cursor.fetchone() logger.info(f"[DEBUG get_model_path_by_name] fine_tune查询结果: {ft_result}") if ft_result and ft_result.get('base_model'): base_model_val = ft_result['base_model'] logger.info(f"[DEBUG get_model_path_by_name] base_model_val: {base_model_val}") # 如果是数字ID,查询模型管理表获取路径 if str(base_model_val).isdigit(): cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) model_result = cursor.fetchone() logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果(数字ID): {model_result}") if model_result: cursor.close() conn.close() return model_result.get('path') else: # 直接是路径 cursor.close() conn.close() return base_model_val # 如果训练任务表没找到,尝试从模型管理表按名称查询 logger.info(f"[DEBUG get_model_path_by_name] 尝试从model_manage表查询...") cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) result = cursor.fetchone() logger.info(f"[DEBUG get_model_path_by_name] model_manage查询结果: {result}") cursor.close() conn.close() if result: return result.get('path') logger.info(f"[DEBUG get_model_path_by_name] 未找到任何匹配,返回None") return None except Exception as e: logger.error(f"[ERROR] 查询模型路径失败: {e}") return None def generic_create(table_name, data): """通用创建""" conn = get_db_connection() cursor = conn.cursor() columns = ', '.join(data.keys()) placeholders = ', '.join(['%s'] * len(data)) sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})" cursor.execute(sql, list(data.values())) conn.commit() new_id = cursor.lastrowid cursor.close() conn.close() return new_id def generic_update(table_name, id_val, data): """通用更新""" conn = get_db_connection() cursor = conn.cursor() set_clause = ', '.join([f"{k} = %s" for k in data.keys()]) sql = f"UPDATE {table_name} SET {set_clause} WHERE id = %s" values = list(data.values()) + [id_val] cursor.execute(sql, values) conn.commit() cursor.close() conn.close() def generic_delete(table_name, id_val): """通用删除""" conn = get_db_connection() cursor = conn.cursor() cursor.execute(f"DELETE FROM {table_name} WHERE id = %s", (id_val,)) conn.commit() cursor.close() conn.close() def generic_get_by_id(table_name, id_val): """通用按ID查询""" conn = get_db_connection() cursor = conn.cursor() cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,)) result = cursor.fetchone() cursor.close() conn.close() return result # ============ 模型管理 CRUD ============ @model_manage_bp.route('', methods=['GET']) def get_model_manage(): """获取所有模型""" return jsonify({'code': 0, 'data': generic_get_all('model_manage')}) @model_manage_bp.route('/', methods=['GET']) def get_model_manage_by_id(id): """获取单个模型""" model = generic_get_by_id('model_manage', id) if model: return jsonify({'code': 0, 'data': model}) return jsonify({'code': 1, 'message': '模型不存在'}) @model_manage_bp.route('/name/', methods=['GET']) def get_model_manage_by_name(model_name): """根据名称获取模型""" logger.info(f"[DEBUG] 按名称查询模型: {model_name}") conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT * FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) model = cursor.fetchone() cursor.close() conn.close() if model: return jsonify({'code': 0, 'data': model}) return jsonify({'code': 1, 'message': '模型不存在'}) @model_manage_bp.route('', methods=['POST']) def create_model_manage(): """创建模型""" data = request.json # 构建插入数据 insert_data = { 'name': data.get('name'), 'type': data.get('type'), 'model_source': data.get('model_source', 'local'), 'description': data.get('description'), 'purpose': data.get('purpose', 'inference') # 默认推理用途 } if data.get('model_source') == 'local': insert_data['path'] = data.get('path', '') else: insert_data['api_url'] = data.get('api_url', '') insert_data['api_key'] = data.get('api_key', '') insert_data['model_name'] = data.get('model_name', '') new_id = generic_create('model_manage', insert_data) return jsonify({'code': 0, 'message': '创建成功', 'id': new_id}) @model_manage_bp.route('/', methods=['PUT']) def update_model_manage(id): """更新模型""" data = request.json generic_update('model_manage', id, data) return jsonify({'code': 0, 'message': '更新成功'}) @model_manage_bp.route('//purpose', methods=['PUT']) def update_model_purpose(id): """更新模型用途""" data = request.json purpose = data.get('purpose') if purpose not in ['training', 'inference', 'evaluation']: return jsonify({'code': 1, 'message': '无效的用途类型'}) generic_update('model_manage', id, {'purpose': purpose}) return jsonify({'code': 0, 'message': '更新成功'}) @model_manage_bp.route('/', methods=['DELETE']) def delete_model_manage(id): """删除模型""" generic_delete('model_manage', id) return jsonify({'code': 0, 'message': '删除成功'}) # ============ 本地模型列表接口 ============ @model_manage_bp.route('/local-models', methods=['GET']) def get_local_models(): """获取本地模型列表(从YG_FT_Base/local_models目录)""" import logging logger = logging.getLogger(__name__) try: # 使用 YG_FT_Base/local_models 目录 base_path = os.path.join(PROJECT_ROOT, 'local_models') models = [] if os.path.exists(base_path): for item in os.listdir(base_path): item_path = os.path.join(base_path, item) if os.path.isdir(item_path): models.append({ 'name': item, 'path': item_path }) return jsonify({ 'code': 0, 'data': { 'models': models, 'base_path': base_path } }) except Exception as e: logger.error(f"获取本地模型列表失败: {e}") return jsonify({'code': 1, 'message': str(e)}) # ============ 已训练模型列表接口 ============ @model_manage_bp.route('/trained-models', methods=['GET']) def get_trained_models(): """获取已训练模型列表(从/app/base/saves目录)""" import logging logger = logging.getLogger(__name__) try: # 多个可能的路径 potential_paths = [ '/app/base/saves', # 容器内路径 os.path.join(PROJECT_ROOT, 'saves'), # 本地开发路径 os.path.join(os.path.dirname(os.path.dirname(PROJECT_ROOT)), 'YG_FT_Base', 'saves'), # 上级目录 ] base_path = None for path in potential_paths: logger.info(f"[DEBUG] 检查路径: {path}, exists: {os.path.exists(path)}") if os.path.exists(path): base_path = path break logger.info(f"[DEBUG] 最终使用的路径: {base_path}") models = [] if base_path and os.path.exists(base_path): logger.info(f"[DEBUG] 遍历目录: {base_path}") try: # 路径结构: /app/base/saves/{train_method}/{model_name}/ # train_method: lora, full, qlora, dpo, cpt 等 # 同时兼容老结构: /app/base/saves/{model_name}/ train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft'] for item in os.listdir(base_path): item_path = os.path.join(base_path, item) if not os.path.isdir(item_path): continue # 情况1: 新结构 {train_method}/{model_name} if item in train_methods: logger.info(f"[DEBUG] 检查训练方法目录: {item}") model_count = 0 for model_name in os.listdir(item_path): model_path = os.path.join(item_path, model_name) if not os.path.isdir(model_path): continue try: files = os.listdir(model_path) has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files) if has_model: logger.info(f"[DEBUG] 找到模型: {item}/{model_name}") # 获取文件创建时间 try: import time stat = os.stat(model_path) create_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime)) except: create_time = None # 查询基座模型路径 base_model_path = get_model_path_by_name(model_name) models.append({ 'name': model_name, 'path': model_path, 'base_model_path': base_model_path, 'create_time': create_time, 'train_methods': [{ 'name': item, 'path': model_path }] }) model_count += 1 except Exception as file_err: logger.error(f"[DEBUG] 读取 {model_path} 失败: {file_err}") logger.info(f"[DEBUG] {item} 找到 {model_count} 个模型") # 情况2: 老结构 {model_name} 直接在 saves 下 else: logger.info(f"[DEBUG] 检查老结构模型目录: {item}") try: files = os.listdir(item_path) has_model = any(f.endswith('.bin') or f.endswith('.safetensors') for f in files) if has_model: logger.info(f"[DEBUG] 找到模型: {item}") # 尝试从 adapter_config.json 推断 train_method inferred_method = 'lora' # 默认 config_file = os.path.join(item_path, 'adapter_config.json') if os.path.exists(config_file): try: import json with open(config_file, 'r', encoding='utf-8') as f: config = json.load(f) if 'peft_type' in config: peft_type = config['peft_type'].lower() if 'lora' in peft_type: inferred_method = 'lora' elif 'full' in peft_type or 'pt' in peft_type: inferred_method = 'full' except: pass # 获取文件创建时间 try: import time stat = os.stat(item_path) create_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime)) except: create_time = None # 查询基座模型路径 base_model_path = get_model_path_by_name(item) models.append({ 'name': item, 'path': item_path, 'base_model_path': base_model_path, 'create_time': create_time, 'train_methods': [{ 'name': inferred_method, 'path': item_path }] }) except Exception as file_err: logger.error(f"[DEBUG] 读取 {item_path} 失败: {file_err}") except Exception as list_err: logger.error(f"[DEBUG] 遍历目录失败: {list_err}") logger.info(f"[DEBUG] 找到 {len(models)} 个已训练模型") # 检查每个模型是否已合并或正在合并 local_trained_path = os.path.join(PROJECT_ROOT, 'local_trained_models') for model in models: model_name = model['name'] merged_path = os.path.join(local_trained_path, model_name) lock_file = os.path.join(local_trained_path, f'.merging_{model_name}.lock') model['merged'] = os.path.exists(merged_path) model['merging'] = os.path.exists(lock_file) logger.info(f"[DEBUG] 模型 {model_name} 已合并: {model['merged']}, 正在合并: {model['merging']}") return jsonify({ 'code': 0, 'data': { 'models': models, 'base_path': base_path or '' } }) except Exception as e: logger.error(f"获取已训练模型列表失败: {e}") return jsonify({'code': 1, 'message': str(e)}) # ============ 合并权重接口 ============ @model_manage_bp.route('/merge', methods=['POST']) def merge_model(): """合并模型权重(将LoRA适配器合并到基座模型)""" import subprocess import sys import logging logger = logging.getLogger(__name__) data = request.json model_name = data.get('model_name') # 模型名称 train_method = data.get('train_method', 'lora') # 训练方法 base_model_path = data.get('base_model_path') # 基座模型路径 if not model_name: return jsonify({'code': 1, 'message': '缺少模型名称'}) logger.info(f"[MERGE] 开始合并模型: {model_name}, 方法: {train_method}") # 如果没有提供基座模型路径,从数据库查询 if not base_model_path: try: conn = get_db_connection() cursor = conn.cursor() # 优先从训练任务表查询 cursor.execute(""" SELECT base_model FROM fine_tune WHERE output_model_name LIKE %s OR output_model_name LIKE %s LIMIT 1 """, (f'%/{model_name}', f'%{model_name}%')) ft_result = cursor.fetchone() if ft_result and ft_result.get('base_model'): base_model_val = ft_result['base_model'] if str(base_model_val).isdigit(): cursor.execute("SELECT path FROM model_manage WHERE id = %s LIMIT 1", (base_model_val,)) model_result = cursor.fetchone() if model_result: base_model_path = model_result.get('path') else: base_model_path = base_model_val # 如果没找到,尝试从模型管理表按名称查询 if not base_model_path: cursor.execute("SELECT path FROM model_manage WHERE name = %s LIMIT 1", (model_name,)) model_result = cursor.fetchone() if model_result: base_model_path = model_result.get('path') conn.close() if not base_model_path: return jsonify({'code': 1, 'message': f'未找到模型 {model_name} 的基座模型配置'}) except Exception as e: logger.error(f"[MERGE] 查询模型配置失败: {e}") return jsonify({'code': 1, 'message': f'查询模型配置失败: {str(e)}'}) # 训练后的模型路径(LoRA适配器) adapter_path = f"/app/base/saves/{train_method}/{model_name}" # 检查路径是否存在 if not os.path.exists(adapter_path): return jsonify({'code': 1, 'message': f'训练模型不存在: {adapter_path}'}) # 合并后的输出路径 output_path = f"/app/base/local_trained_models/{model_name}" # 合并状态锁文件 lock_file = f"/app/base/local_trained_models/.merging_{model_name}.lock" # 创建输出目录 os.makedirs(output_path, exist_ok=True) # 创建锁文件表示正在合并中 try: with open(lock_file, 'w') as f: f.write('merging') work_dir = '/app/base' # 设置环境变量 env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'} # 使用 llamafactory-cli export 命令(假设已在系统 PATH 中,与训练命令一致) cli_cmd = ['llamafactory-cli', 'export'] # 检查 llamafactory-cli 是否存在 try: # 尝试使用 which 命令(Linux/Mac) subprocess.run(['which', 'llamafactory-cli'], capture_output=True, check=True) except (subprocess.CalledProcessError, FileNotFoundError): # Windows 上没有 which 命令,直接尝试执行 logger.info("[MERGE] which 命令不可用,直接尝试执行 llamafactory-cli") # 构建完整命令参数 export_args = [ '--model_name_or_path', base_model_path, '--adapter_name_or_path', adapter_path, '--export_dir', output_path ] logger.info(f"[MERGE] 执行合并命令: {' '.join(cli_cmd)} {' '.join(export_args)}") # 直接执行 llamafactory-cli export 命令 result = subprocess.run( cli_cmd + export_args, capture_output=True, text=True, timeout=600, cwd=work_dir or '/app/base', env=env ) logger.info(f"[MERGE] 命令返回码: {result.returncode}") logger.info(f"[MERGE] stdout: {result.stdout[:500] if result.stdout else 'empty'}") logger.info(f"[MERGE] stderr: {result.stderr[:500] if result.stderr else 'empty'}") # 等待输出目录完全创建 import time max_wait = 5 # 最多等待5秒 waited = 0 while not os.path.exists(output_path) and waited < max_wait: time.sleep(0.5) waited += 0.5 # 无论成功失败,都删除锁文件 if os.path.exists(lock_file): os.remove(lock_file) if result.returncode == 0: # 确保目录存在才返回成功 if os.path.exists(output_path): return jsonify({ 'code': 0, 'message': f'模型权重已成功合并到 {output_path}', 'data': { 'model_name': model_name, 'output_path': output_path } }) else: return jsonify({'code': 1, 'message': '合并失败:输出目录未创建'}) else: error_msg = result.stderr.strip() if result.stderr else result.stdout.strip() if not error_msg: error_msg = f'命令执行失败,返回码: {result.returncode}' return jsonify({'code': 1, 'message': f'合并失败: {error_msg}'}) except subprocess.TimeoutExpired: logger.error("[MERGE] 合并超时") # 删除锁文件 if os.path.exists(lock_file): os.remove(lock_file) return jsonify({'code': 1, 'message': '合并超时,请稍后重试'}) except Exception as e: logger.error(f"[MERGE] 合并异常: {str(e)}") return jsonify({'code': 1, 'message': f'合并异常: {str(e)}'}) # ============ 删除已训练模型接口 ============ @model_manage_bp.route('/trained-models/', methods=['DELETE']) def delete_trained_model(model_name): """删除已训练模型 type=merged: 删除合并模型(local_trained_models目录) type=lora: 删除权重(saves目录下的lora等权重文件) """ import shutil import logging logger = logging.getLogger(__name__) # 获取删除类型参数 delete_type = request.args.get('type', 'merged') # 默认删除合并模型 try: if delete_type == 'lora': # 删除权重:删除 saves 目录下的权重 saves_path = os.path.join(PROJECT_ROOT, 'saves') train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft'] deleted = False for method in train_methods: weight_path = os.path.join(saves_path, method, model_name) if os.path.exists(weight_path): shutil.rmtree(weight_path) logger.info(f"[DELETE] 已删除权重: {weight_path}") deleted = True if not deleted: # 也可能是老结构,直接在 saves 下的 model_name 目录 old_path = os.path.join(saves_path, model_name) if os.path.exists(old_path): shutil.rmtree(old_path) logger.info(f"[DELETE] 已删除老结构权重: {old_path}") deleted = True if deleted: return jsonify({'code': 0, 'message': '权重已删除'}) else: return jsonify({'code': 1, 'message': f'权重不存在: {model_name}'}) else: # 默认删除合并模型(local_trained_models目录) model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name) if not os.path.exists(model_path): return jsonify({'code': 1, 'message': f'合并模型不存在: {model_name}'}) # 删除目录 shutil.rmtree(model_path) logger.info(f"[DELETE] 已删除合并模型: {model_path}") return jsonify({'code': 0, 'message': '合并模型已删除'}) except Exception as e: logger.error(f"[DELETE] 删除失败: {str(e)}") return jsonify({'code': 1, 'message': f'删除失败: {str(e)}'}) # ============ 导出已训练模型接口 ============ @model_manage_bp.route('/trained-models//export', methods=['GET']) def export_trained_model(model_name): """导出已训练模型(打包成zip下载)""" import shutil import logging from flask import send_file logger = logging.getLogger(__name__) try: # 优先从 local_trained_models 目录查找(合并后的模型) model_path = os.path.join(PROJECT_ROOT, 'local_trained_models', model_name) # 如果本地模型目录不存在,尝试从 saves 目录查找(未合并的模型) if not os.path.exists(model_path): # 查找 saves 目录下的模型 saves_path = os.path.join(PROJECT_ROOT, 'saves') train_methods = ['lora', 'full', 'qlora', 'dpo', 'cpt', 'prefix', 'adapter', 'peft'] for method in train_methods: potential_path = os.path.join(saves_path, method, model_name) if os.path.exists(potential_path): model_path = potential_path logger.info(f"[EXPORT] 从 saves/{method} 目录找到模型: {model_path}") break # 如果还是找不到,返回错误 if not os.path.exists(model_path): return jsonify({'code': 1, 'message': f'模型不存在: {model_name}'}) # 创建临时 zip 文件 zip_path = os.path.join(PROJECT_ROOT, 'temp_exports') os.makedirs(zip_path, exist_ok=True) zip_file = os.path.join(zip_path, f'{model_name}.zip') # 如果已存在先删除 if os.path.exists(zip_file): os.remove(zip_file) # 打包成 zip shutil.make_archive(zip_file[:-4], 'zip', model_path) logger.info(f"[EXPORT] 已打包模型: {zip_file}") # 发送文件给前端 response = send_file( zip_file, as_attachment=True, download_name=f'{model_name}.zip', mimetype='application/zip' ) # 注册回调,删除临时文件 def cleanup(): try: if os.path.exists(zip_file): os.remove(zip_file) logger.info(f"[EXPORT] 已清理临时文件: {zip_file}") except: pass # 使用 after_request 清理 @response.call_on_close def cleanup_after_request(): cleanup() return response except Exception as e: logger.error(f"[EXPORT] 导出模型失败: {str(e)}") return jsonify({'code': 1, 'message': f'导出失败: {str(e)}'})