""" 远光软件微调平台 - Flask 后端 API """ import os import sys import json import pymysql import yaml import time import logging from datetime import datetime from logging.handlers import TimedRotatingFileHandler, RotatingFileHandler from flask import Flask, request, jsonify, send_from_directory from flask_cors import CORS from werkzeug.utils import secure_filename # 导入API蓝图 from api import register_blueprints # 获取项目根目录 PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, PROJECT_ROOT) # 加载配置 CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml') def load_config(): """加载配置文件""" with open(CONFIG_PATH, 'r', encoding='utf-8') as f: return yaml.safe_load(f) CONFIG = load_config() # 训练日志路径 TRAINING_LOGS_DIR = CONFIG.get('training_logs_path', '/app/base/training_logs') # ============ 日志系统配置 ============ # 日志目录逻辑: # 1. 优先使用环境变量 LOG_BASE_DIR # 2. 如果是容器环境(存在 /app/base),使用 /app/base/logs # 3. 否则使用本地项目路径 PROJECT_ROOT/logs def get_log_base_dir(): """获取日志基础目录""" # 1. 检查环境变量 if 'LOG_BASE_DIR' in os.environ: return os.environ['LOG_BASE_DIR'] # 2. 检查是否在容器环境中 mount_base = os.environ.get('MOUNT_BASE', '/app/base') if os.path.exists(mount_base): return os.path.join(mount_base, 'logs') # 3. 使用本地项目路径 return os.path.join(PROJECT_ROOT, 'logs') LOG_BASE_DIR = get_log_base_dir() def setup_logger(name='app'): """配置日志系统,按日期分目录存储""" # 创建当天的日志目录 today = datetime.now().strftime('%Y-%m-%d') log_dir = os.path.join(LOG_BASE_DIR, today) os.makedirs(log_dir, exist_ok=True) # 获取或创建 logger logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) # 清除已存在的处理器 logger.handlers.clear() # 1. 全部日志处理器 (TimedRotatingFileHandler - 每天午夜分割) all_log_path = os.path.join(log_dir, 'all.log') all_handler = TimedRotatingFileHandler(all_log_path, when='midnight', interval=1, backupCount=30, encoding='utf-8') all_handler.setLevel(logging.DEBUG) all_handler.setFormatter(logging.Formatter( '[%(asctime)s] %(levelname)s in %(module)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S' )) # 2. 
错误日志处理器 (RotatingFileHandler - 10MB大小分割,保留10个备份) error_log_path = os.path.join(log_dir, 'error.log') error_handler = RotatingFileHandler(error_log_path, maxBytes=10*1024*1024, backupCount=10, encoding='utf-8') error_handler.setLevel(logging.ERROR) error_handler.setFormatter(logging.Formatter( '[%(asctime)s] %(levelname)s in %(module)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S' )) # 3. 请求日志处理器 request_log_path = os.path.join(log_dir, 'request.log') request_handler = RotatingFileHandler(request_log_path, maxBytes=10*1024*1024, backupCount=10, encoding='utf-8') request_handler.setLevel(logging.INFO) request_handler.setFormatter(logging.Formatter( '[%(asctime)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S' )) # 4. 控制台处理器 console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_handler.setFormatter(logging.Formatter( '[%(asctime)s] %(levelname)s: %(message)s', datefmt='%H:%M:%S' )) # 5. 训练日志处理器 - 专门记录训练输出 train_log_path = os.path.join(log_dir, 'train.log') train_handler = RotatingFileHandler(train_log_path, maxBytes=100*1024*1024, backupCount=5, encoding='utf-8') train_handler.setLevel(logging.INFO) train_handler.setFormatter(logging.Formatter( '[%(asctime)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S' )) # 6. 
API日志处理器 - 专门记录API相关日志 api_log_path = os.path.join(log_dir, 'api.log') api_handler = RotatingFileHandler(api_log_path, maxBytes=10*1024*1024, backupCount=10, encoding='utf-8') api_handler.setLevel(logging.DEBUG) api_handler.setFormatter(logging.Formatter( '[%(asctime)s] %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S' )) # 添加处理器到 logger logger.addHandler(all_handler) logger.addHandler(error_handler) logger.addHandler(console_handler) # 为请求日志创建单独的 logger request_logger = logging.getLogger('request') request_logger.setLevel(logging.INFO) request_logger.handlers.clear() request_logger.addHandler(request_handler) request_logger.addHandler(console_handler) # 为训练日志创建单独的 logger train_logger = logging.getLogger('train') train_logger.setLevel(logging.INFO) train_logger.handlers.clear() train_logger.addHandler(train_handler) train_logger.addHandler(console_handler) # 为 logs_api 创建单独的 logger (供 src/api/logs.py 使用) logs_api_logger = logging.getLogger('logs_api') logs_api_logger.setLevel(logging.DEBUG) logs_api_logger.handlers.clear() logs_api_logger.addHandler(api_handler) logs_api_logger.addHandler(console_handler) return logger # 初始化日志系统 logger = setup_logger('app') request_logger = logging.getLogger('request') logger.info('=' * 50) logger.info('服务启动') logger.info('=' * 50) def get_db_connection(): """获取数据库连接""" db_config = CONFIG['database'] return pymysql.connect( host=db_config['host'], port=db_config['port'], user=db_config['username'], password=db_config['password'], database=db_config['name'], charset=db_config.get('charset', 'utf8mb4'), cursorclass=pymysql.cursors.DictCursor ) def init_database(): """初始化数据库表""" logger.info("正在初始化数据库...") try: conn = get_db_connection() cursor = conn.cursor() tables = [ # 精调训练表 """CREATE TABLE IF NOT EXISTS fine_tune ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, base_model VARCHAR(255), template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等', train_type VARCHAR(50), train_method VARCHAR(50), gpus JSON 
COMMENT 'GPU硬件选择,支持多卡训练', dataset_id INT, valid_split VARCHAR(50), valid_ratio INT DEFAULT 10, output_model_name VARCHAR(255), process_id INT COMMENT '训练进程ID', tensorboard_log_dir VARCHAR(255) COMMENT 'TensorBoard日志目录', status VARCHAR(50) DEFAULT 'pending', progress INT DEFAULT 0, create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 我的模型表 """CREATE TABLE IF NOT EXISTS my_models ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, type VARCHAR(100), version VARCHAR(50), description TEXT, create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 模型评测表 """CREATE TABLE IF NOT EXISTS model_eval ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, model_name VARCHAR(255) NOT NULL, dataset VARCHAR(255), metric VARCHAR(100), score DECIMAL(10, 4), status VARCHAR(50) DEFAULT 'completed', create_time DATETIME DEFAULT CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 模型部署表 """CREATE TABLE IF NOT EXISTS model_deploy ( id INT AUTO_INCREMENT PRIMARY KEY, model_name VARCHAR(255) NOT NULL, endpoint VARCHAR(255), instance VARCHAR(100), status VARCHAR(50) DEFAULT 'running', create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 模型对比表 """CREATE TABLE IF NOT EXISTS model_compare ( id INT AUTO_INCREMENT PRIMARY KEY, model_name VARCHAR(255) NOT NULL, description TEXT, models JSON, status VARCHAR(50) DEFAULT 'pending', create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 数据集管理表 """CREATE TABLE IF NOT EXISTS dataset_manage ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, type 
VARCHAR(100), size VARCHAR(50), count INT, description TEXT, file_path VARCHAR(500), file_count INT DEFAULT 0, storage_type VARCHAR(50) DEFAULT 'local', minio_config TEXT, create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 数据集文件表 """CREATE TABLE IF NOT EXISTS dataset_files ( id INT AUTO_INCREMENT PRIMARY KEY, dataset_id INT NOT NULL, file_name VARCHAR(255) NOT NULL, file_path VARCHAR(500) NOT NULL, file_size BIGINT DEFAULT 0, file_type VARCHAR(50), create_time DATETIME DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (dataset_id) REFERENCES dataset_manage(id) ON DELETE CASCADE ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 数据生成表 """CREATE TABLE IF NOT EXISTS data_generate ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, template VARCHAR(255), count INT DEFAULT 0, status VARCHAR(50) DEFAULT 'pending', create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 权限管理表 """CREATE TABLE IF NOT EXISTS permission ( id INT AUTO_INCREMENT PRIMARY KEY, username VARCHAR(100) NOT NULL, role VARCHAR(50) DEFAULT 'user', permissions TEXT, create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 模型管理表 """CREATE TABLE IF NOT EXISTS model_manage ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, type VARCHAR(100), model_source VARCHAR(20) DEFAULT 'local', path VARCHAR(500), api_url VARCHAR(500), api_key VARCHAR(500), model_name VARCHAR(255), description TEXT, purpose VARCHAR(50) DEFAULT 'inference', create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 评测维度表 """CREATE TABLE IF NOT EXISTS model_dimension ( id INT 
AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, type VARCHAR(100), description TEXT, create_time DATETIME DEFAULT CURRENT_TIMESTAMP, update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 系统配置表 """CREATE TABLE IF NOT EXISTS sys_config ( id INT AUTO_INCREMENT PRIMARY KEY, config_key VARCHAR(100) NOT NULL UNIQUE, config_value TEXT, description VARCHAR(255), update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""", # 用户表 """CREATE TABLE IF NOT EXISTS users ( id INT AUTO_INCREMENT PRIMARY KEY, username VARCHAR(100) NOT NULL UNIQUE, password VARCHAR(255) NOT NULL, role VARCHAR(50) DEFAULT 'user', create_time DATETIME DEFAULT CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4""" ] for i, table_sql in enumerate(tables): try: cursor.execute(table_sql) logger.debug(f"表 {i+1}/{len(tables)} 创建/检查成功") except Exception as e: logger.error(f"表 {i+1} 创建失败: {e}") # 为已存在的表添加缺失的列(静默处理,不显示重复列的提示) for table_col in [("model_manage", "purpose"), ("model_eval", "name"), ("fine_tune", "gpus")]: try: table_name, col_name = table_col cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {col_name} JSON") logger.debug(f"{table_name} 表添加 {col_name} 列成功") except Exception: pass # 列已存在时不输出任何信息 # 为 fine_tune 表添加 template 列 try: cursor.execute("ALTER TABLE fine_tune ADD COLUMN template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等'") logger.debug("fine_tune 表添加 template 列成功") except Exception: pass # 列已存在时不输出任何信息 # 为 fine_tune 表添加 process_id 列 try: cursor.execute("ALTER TABLE fine_tune ADD COLUMN process_id INT COMMENT '训练进程ID'") logger.debug("fine_tune 表添加 process_id 列成功") except Exception: pass # 列已存在时不输出任何信息 # 为 fine_tune 表添加训练相关列 columns_to_add = [ ("description", "TEXT COMMENT '任务描述'"), ("train_dataset_id", "INT COMMENT '训练数据集ID'"), ("valid_dataset_id", "INT COMMENT '验证数据集ID'"), ("save_steps", "INT DEFAULT 100 COMMENT '保存步数'"), 
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"), ("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"), ("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"), ("batch_size", "INT DEFAULT 1 COMMENT '批次大小'"), ("learning_rate", "FLOAT DEFAULT 0.0001 COMMENT '学习率'"), ("n_epochs", "FLOAT DEFAULT 1.0 COMMENT '训练轮数'"), ("max_length", "INT DEFAULT 512 COMMENT '最大长度'"), ("lora_alpha", "VARCHAR(10) DEFAULT '32' COMMENT 'LoRA alpha'"), ("lora_rank", "VARCHAR(10) DEFAULT '8' COMMENT 'LoRA rank'"), ("lora_dropout", "FLOAT DEFAULT 0.1 COMMENT 'LoRA dropout'"), ("valid_ratio", "INT DEFAULT 10 COMMENT '验证集比例'"), ] for col_name, col_def in columns_to_add: try: cursor.execute(f"ALTER TABLE fine_tune ADD COLUMN {col_name} {col_def}") logger.debug(f"fine_tune 表添加 {col_name} 列成功") except Exception: pass # 列已存在时不输出任何信息 # 插入默认管理员用户 cursor.execute("SELECT * FROM users WHERE username = 'admin'") if not cursor.fetchone(): cursor.execute("INSERT INTO users (username, password, role) VALUES ('admin', 'admin', 'admin')") logger.info("默认管理员用户创建成功") conn.commit() cursor.close() conn.close() logger.info("数据库初始化完成") except Exception as e: logger.error(f"数据库初始化失败: {e}") raise app = Flask(__name__) app.config['SECRET_KEY'] = CONFIG['secret_key'] app.config['CORS_HEADERS'] = 'Content-Type' # 允许所有来源 - 支持跨域请求 CORS(app, resources={ r"/api/*": { "origins": "*", "methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"], "allow_headers": ["Content-Type", "Authorization", "X-Requested-With"], "expose_headers": ["Content-Length", "Content-Range"], "supports_credentials": False, "max_age": 86400 # 缓存预检请求结果 24 小时 } }, vary_header=True) # 注册蓝图 register_blueprints(app) # ============ 请求日志中间件 ============ @app.after_request def log_request(response): """记录所有API请求""" # 排除健康检查接口 if request.path != '/api/health': request_logger.info(f'{request.method} {request.path} - {response.status_code}') return response # ============ 健康检查 ============ @app.route('/api/health', methods=['GET']) def 
health_check(): """健康检查接口,返回系统监控数据""" import psutil try: cpu_percent = int(psutil.cpu_percent(interval=None)) memory = psutil.virtual_memory() memory_percent = int(memory.percent) disk = psutil.disk_usage('/') disk_percent = int(disk.percent) return jsonify({ 'status': 'ok', 'code': 0, 'data': { 'cpu_percent': cpu_percent, 'memory_percent': memory_percent, 'disk_percent': disk_percent } }) except Exception as e: logger.error(f"健康检查失败: {e}") return jsonify({'status': 'error', 'code': 1, 'message': str(e)}) # ============ 详细系统监控 ============ @app.route('/api/system-info', methods=['GET']) def system_info(): """获取详细系统监控信息""" import psutil import os try: # CPU 信息 cpu_percent = psutil.cpu_percent(interval=None) cpu_counts = psutil.cpu_count() cpu_freq = psutil.cpu_freq() # 内存信息 memory = psutil.virtual_memory() # 磁盘信息 disk = psutil.disk_usage('/') disk_io = psutil.disk_io_counters() # 网络信息 net_io = psutil.net_io_counters() # 系统启动时间 boot_time = psutil.boot_time() uptime_seconds = time.time() - boot_time # GPU 信息 gpu_list = [] try: import pynvml pynvml.nvmlInit() gpu_count = pynvml.nvmlDeviceGetCount() for i in range(gpu_count): try: handle = pynvml.nvmlDeviceGetHandleByIndex(i) name = pynvml.nvmlDeviceGetName(handle) # 获取显存信息 try: mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) memory_used = mem_info.used memory_total = mem_info.total except: memory_used = 0 memory_total = 0 # 获取利用率 try: util = pynvml.nvmlDeviceGetUtilizationRates(handle) gpu_util = util.gpu mem_util = util.memory except: gpu_util = 0 mem_util = 0 # 获取温度 - pynvml 11.x API: 只接受handle参数 try: temp = pynvml.nvmlDeviceGetTemperature(handle) except: temp = 0 # 获取功耗 try: power = pynvml.nvmlDeviceGetPowerUsage(handle) except: power = 0 # 获取风扇转速 (百分比) try: fan_speed = pynvml.nvmlDeviceGetFanSpeed(handle) except: fan_speed = 0 # 获取显卡时钟频率 (MHz) try: clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM) except: clock = 0 # 获取显存时钟频率 (MHz) try: mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, 
pynvml.NVML_CLOCK_MEM) except: mem_clock = 0 # 获取驱动版本信息 try: version = pynvml.nvmlSystemGetDriverVersion() except: version = '' gpu_list.append({ 'name': name.decode() if isinstance(name, bytes) else name, 'memory_used_gb': round(memory_used / (1024**3), 1), 'memory_total_gb': round(memory_total / (1024**3), 1), 'gpu_percent': gpu_util, 'memory_percent': mem_util, 'temperature': temp, 'power_w': round(power / 1000, 1) if power > 0 else 0, 'fan_speed': fan_speed, 'clock_mhz': clock, 'memory_clock_mhz': mem_clock, 'driver_version': version.decode() if isinstance(version, bytes) else version }) except Exception as e: logger.debug(f"获取GPU {i} 信息失败: {e}") continue pynvml.nvmlShutdown() except Exception as e: logger.warning(f"获取GPU信息失败: {e}") gpu_list = [] return jsonify({ 'code': 0, 'data': { 'cpu': { 'percent': cpu_percent, 'cores': cpu_counts, 'frequency_mhz': cpu_freq.current if cpu_freq else 0 }, 'memory': { 'percent': memory.percent, 'used_gb': round(memory.used / (1024**3), 1), 'total_gb': round(memory.total / (1024**3), 1), 'available_gb': round(memory.available / (1024**3), 1), 'cached_gb': round(memory.cached / (1024**3), 1) if hasattr(memory, 'cached') else 0 }, 'disk': { 'percent': disk.percent, 'used_gb': round(disk.used / (1024**3), 0), 'total_gb': round(disk.total / (1024**3), 0), 'read_mb': round(disk_io.read_bytes / (1024**2), 0), 'write_mb': round(disk_io.write_bytes / (1024**2), 0) }, 'network': { 'upload_mb': round(net_io.bytes_sent / (1024**2), 1), 'download_mb': round(net_io.bytes_recv / (1024**2), 1) }, 'system': { 'uptime_seconds': uptime_seconds, 'process_count': len(psutil.pids()) }, 'gpu': gpu_list } }) except Exception as e: logger.error(f"获取系统信息失败: {e}") return jsonify({'code': 1, 'message': str(e)}) # ============ 通用 CRUD 操作 ============ import json def generic_get_all(table_name, order_by='create_time DESC'): """通用查询所有""" conn = get_db_connection() cursor = conn.cursor() cursor.execute(f"SELECT * FROM {table_name} ORDER BY {order_by}") 
result = cursor.fetchall() cursor.close() conn.close() # 自动解析 JSON 字段 for row in result: for key, value in row.items(): if isinstance(value, str) and value.startswith('[') and value.endswith(']'): try: row[key] = json.loads(value) except: pass elif isinstance(value, str) and value.startswith('{') and value.endswith('}'): try: row[key] = json.loads(value) except: pass return result def generic_get_by_id(table_name, id_val): """通用按ID查询""" conn = get_db_connection() cursor = conn.cursor() cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (id_val,)) result = cursor.fetchone() cursor.close() conn.close() # 自动解析 JSON 字段 if result: for key, value in result.items(): if isinstance(value, str) and value.startswith('[') and value.endswith(']'): try: result[key] = json.loads(value) except: pass elif isinstance(value, str) and value.startswith('{') and value.endswith('}'): try: result[key] = json.loads(value) except: pass return result def generic_create(table_name, data): """通用创建""" import json conn = get_db_connection() cursor = conn.cursor() # 处理JSON字段,将列表和字典转为JSON字符串 processed_values = [] for value in data.values(): if isinstance(value, (list, dict)): processed_values.append(json.dumps(value, ensure_ascii=False)) else: processed_values.append(value) columns = ', '.join(data.keys()) placeholders = ', '.join(['%s'] * len(data)) sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})" cursor.execute(sql, processed_values) conn.commit() new_id = cursor.lastrowid cursor.close() conn.close() return new_id def generic_update(table_name, id_val, data): """通用更新""" import json conn = get_db_connection() cursor = conn.cursor() # 处理JSON字段,将列表和字典转为JSON字符串 processed_values = [] for value in data.values(): if isinstance(value, (list, dict)): processed_values.append(json.dumps(value, ensure_ascii=False)) else: processed_values.append(value) set_clause = ', '.join([f"{k} = %s" for k in data.keys()]) sql = f"UPDATE {table_name} SET {set_clause} WHERE id = %s" values = 
processed_values + [id_val] cursor.execute(sql, values) conn.commit() cursor.close() conn.close() def generic_delete(table_name, id_val): """通用删除""" conn = get_db_connection() cursor = conn.cursor() cursor.execute(f"DELETE FROM {table_name} WHERE id = %s", (id_val,)) conn.commit() cursor.close() conn.close() # ============ 登录接口 ============ @app.route('/api/login', methods=['POST']) def login(): data = request.json username = data.get('username') password = data.get('password') conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE username = %s AND password = %s", (username, password)) user = cursor.fetchone() cursor.close() conn.close() if user: return jsonify({'code': 0, 'message': '登录成功', 'data': {'username': user['username'], 'role': user['role']}}) return jsonify({'code': 1, 'message': '用户名或密码错误'}) # ============ 精调训练接口 ============ @app.route('/api/fine-tune', methods=['GET']) def get_fine_tune(): return jsonify({'code': 0, 'data': generic_get_all('fine_tune')}) @app.route('/api/fine-tune/', methods=['GET']) def get_fine_tune_by_id(id): """获取单个训练任务详情""" try: conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT * FROM fine_tune WHERE id = %s", (id,)) task = cursor.fetchone() if not task: cursor.close() conn.close() return jsonify({'code': 1, 'message': '任务不存在'}) # 获取列名并转换为字典(get_db_connection已使用DictCursor,task已是字典) if isinstance(task, dict): task_dict = task else: columns = [desc[0] for desc in cursor.description] task_dict = dict(zip(columns, task)) cursor.close() conn.close() # 处理 datetime 序列化 for key, value in task_dict.items(): if isinstance(value, datetime): task_dict[key] = value.strftime('%Y-%m-%d %H:%M:%S') return jsonify({'code': 0, 'data': task_dict}) except Exception as e: return jsonify({'code': 1, 'message': str(e)}) @app.route('/api/fine-tune/progress/', methods=['GET']) def get_fine_tune_progress(id): """获取训练任务的进度(通过解析日志文件)""" try: # 获取任务信息 conn = get_db_connection() cursor = 
conn.cursor(dictionary=True) cursor.execute("SELECT id, process_id, name, status FROM fine_tune WHERE id = %s", (id,)) task = cursor.fetchone() conn.close() if not task: return jsonify({'code': 1, 'message': '任务不存在'}) process_id = task.get('process_id') task_name = task.get('name', '') if not process_id: return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}}) # 查找日志文件 - 优先使用容器路径,如果不存在则使用本地路径 training_logs_dir = TRAINING_LOGS_DIR if not os.path.exists(training_logs_dir): training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs') if not os.path.exists(training_logs_dir): return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}}) log_file = None # 优先按 process_id 查找 for file_name in os.listdir(training_logs_dir): if file_name.endswith('.log') and file_name.startswith(f'{process_id}_'): log_file = os.path.join(training_logs_dir, file_name) break # 如果没找到,尝试按任务名称查找 if not log_file and task_name: for file_name in os.listdir(training_logs_dir): if file_name.endswith('.log') and task_name in file_name: log_file = os.path.join(training_logs_dir, file_name) break if not log_file or not os.path.exists(log_file): return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}}) # 读取日志文件内容 try: with open(log_file, 'r', encoding='utf-8') as f: content = f.read() except Exception as e: return jsonify({'code': 0, 'data': {'progress': 0, 'status': task.get('status', 'unknown'), 'step': '', 'speed': '', 'eta': ''}}) # 解析进度 progress = 0 step_info = '' speed_info = '' eta_info = '' import re # 处理 Windows 格式的日志(\r 覆盖行),将 \r 替换为换行 content = content.replace('\r', '\n') # 日志格式: " 3%|▎ | 1/33 [00:09<05:10, 9.69s/it]" # 或: " 30%|███ | 10/33 [01:22<03:00, 7.86s/it]" # 匹配 "数字%|进度条| step/total [elapsed', methods=['PUT']) def update_fine_tune(id): data = request.json 
generic_update('fine_tune', id, data) return jsonify({'code': 0, 'message': '更新成功'}) @app.route('/api/fine-tune/', methods=['DELETE']) def delete_fine_tune(id): # 删除前获取任务信息(用于删除日志文件) conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT process_id, name FROM fine_tune WHERE id = %s", (id,)) task_info = cursor.fetchone() conn.close() # 删除相关的日志文件 if task_info and task_info.get('process_id'): from datetime import datetime process_id = task_info['process_id'] task_name = task_info.get('name', 'unknown') # 优先使用容器路径,如果不存在则使用本地路径 training_logs_dir = TRAINING_LOGS_DIR if not os.path.exists(training_logs_dir): training_logs_dir = os.path.join(PROJECT_ROOT, 'training_logs') try: if os.path.exists(training_logs_dir): for file_name in os.listdir(training_logs_dir): # 查找以 PID 开头的日志文件 if file_name.endswith('.log') and file_name.startswith(f'{process_id}_'): log_file = os.path.join(training_logs_dir, file_name) try: os.remove(log_file) print(f"[INFO] 已删除日志文件: {log_file}") except Exception as e: print(f"[WARN] 删除日志文件失败: {log_file}, 错误: {e}") except Exception as e: print(f"[WARN] 查找或删除日志文件时出错: {e}") # 删除数据库记录 generic_delete('fine_tune', id) return jsonify({'code': 0, 'message': '删除成功'}) # ============ 我的模型接口 ============ @app.route('/api/my-models', methods=['GET']) def get_my_models(): return jsonify({'code': 0, 'data': generic_get_all('my_models')}) @app.route('/api/my-models', methods=['POST']) def create_my_model(): data = request.json new_id = generic_create('my_models', data) return jsonify({'code': 0, 'message': '创建成功', 'id': new_id}) @app.route('/api/my-models/', methods=['PUT']) def update_my_model(id): data = request.json generic_update('my_models', id, data) return jsonify({'code': 0, 'message': '更新成功'}) @app.route('/api/my-models/', methods=['DELETE']) def delete_my_model(id): generic_delete('my_models', id) return jsonify({'code': 0, 'message': '删除成功'}) # ============ 模型评测接口 ============ @app.route('/api/model-eval', methods=['GET']) def 
get_model_eval(): return jsonify({'code': 0, 'data': generic_get_all('model_eval')}) @app.route('/api/model-eval', methods=['POST']) def create_model_eval(): data = request.json # 获取模型名称和数据集名称 model_id = data.get('model_id') dataset_id = data.get('dataset_id') eval_dimension = data.get('eval_dimension', '') # 获取模型名称 model_name = '' if model_id: conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT name FROM model_manage WHERE id = %s", (model_id,)) model_result = cursor.fetchone() cursor.close() conn.close() if model_result: model_name = model_result['name'] # 获取数据集名称 dataset_name = '' if dataset_id: conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (dataset_id,)) dataset_result = cursor.fetchone() cursor.close() conn.close() if dataset_result: dataset_name = dataset_result['name'] # 构建插入数据,映射到数据库字段 insert_data = { 'name': data.get('name', ''), 'model_name': model_name, 'dataset': dataset_name, 'metric': eval_dimension, 'status': 'pending' } new_id = generic_create('model_eval', insert_data) return jsonify({'code': 0, 'message': '创建成功', 'id': new_id}) @app.route('/api/model-eval/', methods=['PUT']) def update_model_eval(id): data = request.json # 构建更新数据,映射到数据库字段 update_data = {} if 'name' in data: update_data['name'] = data['name'] if 'model_id' in data: model_id = data['model_id'] conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT name FROM model_manage WHERE id = %s", (model_id,)) model_result = cursor.fetchone() cursor.close() conn.close() update_data['model_name'] = model_result['name'] if model_result else '' if 'dataset_id' in data: dataset_id = data['dataset_id'] conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (dataset_id,)) dataset_result = cursor.fetchone() cursor.close() conn.close() update_data['dataset'] = dataset_result['name'] if dataset_result else '' if 'eval_dimension' in data: 
update_data['metric'] = data['eval_dimension'] if update_data: generic_update('model_eval', id, update_data) return jsonify({'code': 0, 'message': '更新成功'}) @app.route('/api/model-eval/', methods=['DELETE']) def delete_model_eval(id): generic_delete('model_eval', id) return jsonify({'code': 0, 'message': '删除成功'}) # ============ 模型部署接口 ============ @app.route('/api/model-deploy', methods=['GET']) def get_model_deploy(): return jsonify({'code': 0, 'data': generic_get_all('model_deploy')}) @app.route('/api/model-deploy', methods=['POST']) def create_model_deploy(): data = request.json new_id = generic_create('model_deploy', data) return jsonify({'code': 0, 'message': '创建成功', 'id': new_id}) @app.route('/api/model-deploy/', methods=['PUT']) def update_model_deploy(id): data = request.json generic_update('model_deploy', id, data) return jsonify({'code': 0, 'message': '更新成功'}) @app.route('/api/model-deploy/', methods=['DELETE']) def delete_model_deploy(id): generic_delete('model_deploy', id) return jsonify({'code': 0, 'message': '删除成功'}) # ============ 模型对比接口 ============ @app.route('/api/model-compare', methods=['GET']) def get_model_compare(): result = generic_get_all('model_compare') # 确保 models 字段被正确解析为 JSON for row in result: if 'models' in row and isinstance(row['models'], str): try: row['models'] = json.loads(row['models']) logger.debug(f"[model-compare] 解析 models 字段成功: {row['models']}") except Exception as e: logger.error(f"[model-compare] 解析 models 字段失败: {e}, 原始值: {row['models']}") return jsonify({'code': 0, 'data': result}) @app.route('/api/model-compare/', methods=['GET']) def get_model_compare_by_id(id): """获取单个模型对比任务""" result = generic_get_by_id('model_compare', id) if result: # 确保 models 字段被正确解析为 JSON if 'models' in result and isinstance(result['models'], str): try: result['models'] = json.loads(result['models']) except Exception as e: logger.error(f"[model-compare] 解析 models 字段失败: {e}") return jsonify({'code': 0, 'data': result}) return jsonify({'code': 1, 
'message': '任务不存在'}) @app.route('/api/model-compare', methods=['POST']) def create_model_compare(): data = request.json.copy() # 字段映射: name -> model_name if 'name' in data: data['model_name'] = data.pop('name') # 设置默认加载状态 if 'status' not in data: data['status'] = 'pending' new_id = generic_create('model_compare', data) return jsonify({'code': 0, 'message': '创建成功', 'id': new_id}) @app.route('/api/model-compare/', methods=['PUT']) def update_model_compare(id): data = request.json generic_update('model_compare', id, data) return jsonify({'code': 0, 'message': '更新成功'}) @app.route('/api/model-compare/', methods=['DELETE']) def delete_model_compare(id): # 先停止加载的模型服务 try: conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT models, load_status FROM model_compare WHERE id = %s", (id,)) row = cursor.fetchone() cursor.close() conn.close() if row and row[1] == 'loaded': # 停止模型服务 from subprocess import Popen import signal load_status = json.loads(row[0]) if isinstance(row[0], str) else row[0] if isinstance(load_status, dict) and 'processes' in load_status: for proc in load_status['processes']: if 'pid' in proc: try: os.kill(proc['pid'], signal.SIGTERM) except: pass except Exception as e: logger.error(f"[model-compare] 停止模型服务失败: {e}") generic_delete('model_compare', id) return jsonify({'code': 0, 'message': '删除成功'}) # ============ 模型加载接口 ============ import subprocess import signal # 存储加载状态 (生产环境应使用数据库或Redis) model_compare_processes = {} @app.route('/api/model-compare//load', methods=['POST']) def load_model_compare(id): """加载模型 - 启动模型服务""" try: # 获取对比任务 conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT models FROM model_compare WHERE id = %s", (id,)) row = cursor.fetchone() cursor.close() conn.close() if not row: return jsonify({'code': 1, 'message': '任务不存在'}) models = json.loads(row[0]) if isinstance(row[0], str) else row[0] if not models or len(models) < 2: return jsonify({'code': 1, 'message': '模型配置无效'}) # 更新状态为 loading 
generic_update('model_compare', id, {'status': 'loading'}) # 启动模型服务 processes = [] loaded_models = [] for i, model in enumerate(models): model_path = model.get('model_path', '') gpu_id = model.get('gpu_id', 0) port = model.get('port', 7862 + i * 10) if not model_path: continue # 构建启动命令 # 获取模型名称和模板 model_name = model.get('model_name', '') template = detect_template(model_path) cmd = [ sys.executable, '-m', 'llamafactory.api', '--model_name_or_path', model_path, '--template', template or 'default', '--port', str(port), '--gpu_ids', str(gpu_id) ] logger.info(f"[model-compare] 启动模型服务: {' '.join(cmd)}") # 启动进程 proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=PROJECT_ROOT ) processes.append({ 'pid': proc.pid, 'port': port, 'model_name': model_name, 'model_path': model_path }) loaded_models.append({ 'port': port, 'model_name': model_name, 'status': 'starting' }) # 保存进程信息 model_compare_processes[id] = { 'processes': processes, 'loaded_models': loaded_models, 'start_time': datetime.now().timestamp() } # 更新数据库中的加载状态 load_status = { 'processes': processes, 'loaded_models': loaded_models, 'started_at': datetime.now().isoformat() } generic_update('model_compare', id, { 'status': 'loading', 'load_status': json.dumps(load_status, ensure_ascii=False) }) return jsonify({ 'code': 0, 'message': '正在加载模型...', 'data': { 'models': loaded_models, 'check_url': f'/api/model-compare/{id}/load-status' } }) except Exception as e: logger.error(f"[model-compare] 加载模型失败: {e}") generic_update('model_compare', id, {'status': 'pending'}) return jsonify({'code': 1, 'message': f'加载失败: {str(e)}'}) def detect_template(model_path): """根据模型路径检测模板类型""" model_lower = model_path.lower() if 'qwen' in model_lower: return 'qwen' elif 'llama' in model_lower or 'llama' in model_lower: return 'llama' elif 'chatglm' in model_lower: return 'chatglm' elif 'baichuan' in model_lower: return 'baichuan' elif 'mistral' in model_lower: return 'mistral' elif 'yi' in model_lower: return 'yi' 
def _fetch_load_status_row(task_id):
    """Fetch the ``load_status`` row of a model_compare task.

    Returns whatever ``cursor.fetchone()`` yields (a false value when the
    task does not exist). The cursor and the connection are released even
    when the query raises.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(
                "SELECT load_status FROM model_compare WHERE id = %s",
                (task_id,),
            )
            return cursor.fetchone()
        finally:
            cursor.close()
    finally:
        conn.close()


def _parse_load_status(raw):
    """Decode a stored ``load_status`` value into a dict (empty if unset)."""
    if not raw:
        return {}
    parsed = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
    return parsed if isinstance(parsed, dict) else {}


def _pid_is_running(pid):
    """Best-effort check that process *pid* still exists.

    Returns False for a missing or non-positive pid, so that pid 0 or a
    negative pid (which address process groups) is never probed.
    """
    if isinstance(pid, bool) or not isinstance(pid, int) or pid <= 0:
        return False
    if os.name != 'posix':
        # No probe is attempted here; keep the previous behaviour of
        # treating a recorded pid as "still starting".
        return True
    try:
        # If the process is an exited child of this server, this reaps it;
        # an unreaped zombie would otherwise still answer kill(pid, 0).
        reaped_pid, _ = os.waitpid(pid, os.WNOHANG)
        return reaped_pid == 0
    except ChildProcessError:
        pass  # not a child of this process: probe with signal 0 instead
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True  # the process exists but belongs to another user
    return True


# NOTE(review): the rule previously had no URL variable
# ('/api/model-compare//load-status') although the view requires ``id``;
# the ``int`` converter is assumed - confirm against the front-end.
@app.route('/api/model-compare/<int:id>/load-status', methods=['GET'])
def get_load_status(id):
    """Report whether the model services of a compare task are ready.

    Each recorded model is probed at ``http://localhost:<port>/health``:
    HTTP 200 means ``ready``, any other response means ``starting``. When
    the request itself fails, the model is ``starting`` while its recorded
    process is still alive and ``failed`` otherwise. The refreshed state is
    written back to the ``model_compare`` row.

    Args:
        id: Primary key of the ``model_compare`` row.

    Returns:
        A JSON response: ``code`` 0 with ``status``/``models``/``all_ready``
        on success, ``code`` 1 with a message on error.
    """
    try:
        import requests  # imported once, not on every loop iteration

        row = _fetch_load_status_row(id)
        if not row:
            return jsonify({'code': 1, 'message': '任务不存在'})

        load_status = _parse_load_status(row[0])
        loaded_models = load_status.get('loaded_models', [])
        processes = load_status.get('processes', [])

        # Without this guard an empty model list would count as "all ready"
        # and a task that was never loaded would be marked 'loaded'.
        if not loaded_models:
            return jsonify({'code': 1, 'message': '模型尚未加载'})

        all_ready = True
        for index, model in enumerate(loaded_models):
            port = model.get('port')
            try:
                resp = requests.get(f'http://localhost:{port}/health', timeout=2)
            except requests.RequestException:
                # Service not reachable: tell "still booting" from "died".
                pid = processes[index].get('pid') if index < len(processes) else None
                model['status'] = 'starting' if _pid_is_running(pid) else 'failed'
            else:
                model['status'] = 'ready' if resp.status_code == 200 else 'starting'
            if model['status'] != 'ready':
                all_ready = False

        status = 'loaded' if all_ready else 'loading'
        load_status['loaded_models'] = loaded_models
        load_status['all_ready'] = all_ready
        # One update writes both columns.
        generic_update('model_compare', id, {
            'status': status,
            'load_status': json.dumps(load_status, ensure_ascii=False),
        })

        return jsonify({
            'code': 0,
            'data': {
                'status': status,
                'models': loaded_models,
                'all_ready': all_ready,
            }
        })
    except Exception as e:
        logger.error(f"[model-compare] 获取加载状态失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})


# NOTE(review): URL variable restored as for the load-status route above;
# the ``int`` converter is assumed - confirm.
@app.route('/api/model-compare/<int:id>/unload', methods=['POST'])
def unload_model_compare(id):
    """Stop the model services started for a compare task.

    Sends SIGTERM to every recorded process, drops the in-memory entry in
    ``model_compare_processes`` and resets the row to ``pending`` with an
    empty ``load_status``.

    Args:
        id: Primary key of the ``model_compare`` row.

    Returns:
        A JSON response with ``code`` 0 on success or ``code`` 1 and the
        error message on failure.
    """
    try:
        row = _fetch_load_status_row(id)
        load_status = _parse_load_status(row[0]) if row else {}

        # Stop every recorded process.
        for proc in load_status.get('processes', []):
            pid = proc.get('pid')
            # Never signal pid 0 or a negative pid: those address whole
            # process groups, including this server.
            if isinstance(pid, bool) or not isinstance(pid, int) or pid <= 0:
                continue
            try:
                os.kill(pid, signal.SIGTERM)
                logger.info(f"[model-compare] 已停止进程 {pid}")
            except ProcessLookupError:
                pass  # already gone
            except Exception as e:
                logger.warning(f"[model-compare] 停止进程 {pid} 失败: {e}")

        # Drop the in-memory bookkeeping.
        model_compare_processes.pop(id, None)

        # Reset the persisted state.
        generic_update('model_compare', id, {
            'status': 'pending',
            'load_status': None
        })

        return jsonify({'code': 0, 'message': '已停止模型服务'})
    except Exception as e:
        logger.error(f"[model-compare] 停止模型服务失败: {e}")
        return jsonify({'code': 1, 'message': str(e)})
# ============ Data generation endpoints ============
@app.route('/api/data-generate', methods=['GET'])
def get_data_generate():
    """Return every row of the ``data_generate`` table."""
    return jsonify({'code': 0, 'data': generic_get_all('data_generate')})


@app.route('/api/data-generate', methods=['POST'])
def create_data_generate():
    """Create a ``data_generate`` row from the JSON body and return its id."""
    # NOTE(review): the payload is forwarded unvalidated to generic_create.
    data = request.json
    new_id = generic_create('data_generate', data)
    return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})


# NOTE(review): the rule previously ended in '/' with no URL variable
# although the view requires ``id``; ``<int:id>`` is assumed - confirm.
@app.route('/api/data-generate/<int:id>', methods=['PUT'])
def update_data_generate(id):
    """Update the ``data_generate`` row ``id`` with the JSON body."""
    data = request.json
    generic_update('data_generate', id, data)
    return jsonify({'code': 0, 'message': '更新成功'})


# NOTE(review): URL variable restored as for the PUT rule; confirm.
@app.route('/api/data-generate/<int:id>', methods=['DELETE'])
def delete_data_generate(id):
    """Delete the ``data_generate`` row ``id``."""
    generic_delete('data_generate', id)
    return jsonify({'code': 0, 'message': '删除成功'})


# ============ Permission management endpoints ============
@app.route('/api/permission', methods=['GET'])
def get_permission():
    """Return every row of the ``permission`` table."""
    return jsonify({'code': 0, 'data': generic_get_all('permission')})


@app.route('/api/permission', methods=['POST'])
def create_permission():
    """Create a ``permission`` row from the JSON body and return its id."""
    # NOTE(review): the payload is forwarded unvalidated to generic_create.
    data = request.json
    new_id = generic_create('permission', data)
    return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})


# NOTE(review): the rule previously ended in '/' with no URL variable
# although the view requires ``id``; ``<int:id>`` is assumed - confirm.
@app.route('/api/permission/<int:id>', methods=['PUT'])
def update_permission(id):
    """Update the ``permission`` row ``id`` with the JSON body."""
    data = request.json
    generic_update('permission', id, data)
    return jsonify({'code': 0, 'message': '更新成功'})
# NOTE(review): the rule previously ended in '/' with no URL variable
# although the view requires ``id``; ``<int:id>`` is assumed - confirm.
@app.route('/api/permission/<int:id>', methods=['DELETE'])
def delete_permission(id):
    """Delete the ``permission`` row ``id``."""
    generic_delete('permission', id)
    return jsonify({'code': 0, 'message': '删除成功'})


# ============ System configuration endpoints ============
@app.route('/api/sys-config', methods=['GET'])
def get_sys_config():
    """Return every row of the ``sys_config`` table."""
    return jsonify({'code': 0, 'data': generic_get_all('sys_config')})


@app.route('/api/sys-config', methods=['POST'])
def create_sys_config():
    """Create a ``sys_config`` row from the JSON body and return its id."""
    # NOTE(review): the payload is forwarded unvalidated to generic_create.
    data = request.json
    new_id = generic_create('sys_config', data)
    return jsonify({'code': 0, 'message': '创建成功', 'id': new_id})


# NOTE(review): URL variable restored (see the DELETE permission rule
# above in this block); ``<int:id>`` is assumed - confirm.
@app.route('/api/sys-config/<int:id>', methods=['PUT'])
def update_sys_config(id):
    """Update the ``sys_config`` row ``id`` with the JSON body."""
    data = request.json
    generic_update('sys_config', id, data)
    return jsonify({'code': 0, 'message': '更新成功'})


# NOTE(review): URL variable restored as for the PUT rule; confirm.
@app.route('/api/sys-config/<int:id>', methods=['DELETE'])
def delete_sys_config(id):
    """Delete the ``sys_config`` row ``id``."""
    generic_delete('sys_config', id)
    return jsonify({'code': 0, 'message': '删除成功'})


if __name__ == '__main__':
    # Initialise the database before the server starts.
    init_database()

    app_config = CONFIG['app']
    host = app_config['host']
    port = app_config['port']
    # NOTE(review): debug falls back to True when the key is absent from
    # config.yaml; Flask's debug mode should not be exposed in production.
    debug = app_config.get('debug', True)

    logger.info(f'服务启动于 http://{host}:{port}')
    logger.info(f'Debug模式: {debug}')
    app.run(host=host, port=port, debug=debug)