1. 新增了日志系统

2. 新增了新建训练时选择对应 GPU 的功能
This commit is contained in:
2026-01-23 11:07:09 +08:00
parent 7f64362826
commit 730ac6f460
6 changed files with 860 additions and 14 deletions

View File

@@ -7,6 +7,9 @@ import json
import pymysql
import yaml
import time
import logging
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler, RotatingFileHandler
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from werkzeug.utils import secure_filename
@@ -30,6 +33,82 @@ def load_config():
# Configuration returned by load_config(); its 'app' section is read at startup.
CONFIG = load_config()
# ============ Logging configuration ============
# Base directory under which setup_logger() creates one sub-directory per date.
LOG_BASE_DIR = os.path.join(PROJECT_ROOT, 'logs')
def _reset_handlers(target):
    """Detach and close every handler currently attached to *target*."""
    # Iterate over a copy: removeHandler() mutates target.handlers.
    for handler in list(target.handlers):
        target.removeHandler(handler)
        # Closing releases the underlying log file; merely clearing the
        # handler list would leave the file descriptor open.
        handler.close()


def setup_logger(name='app'):
    """Configure the application logger and the shared 'request' logger.

    Log files are written to ``LOG_BASE_DIR/<YYYY-MM-DD>/``, which is created
    if it does not exist:

    * ``all.log``     - DEBUG and above, rotated at midnight, 30 backups.
    * ``error.log``   - ERROR and above, rotated at 10 MB, 10 backups.
    * ``request.log`` - INFO and above from the 'request' logger, rotated at
      10 MB, 10 backups.

    Both loggers also write INFO and above to the console.

    Handlers already attached to either logger are removed *and closed*
    first, so the function can be called repeatedly without leaking open
    log files.

    NOTE: the dated directory is resolved once, when this function is called.
    Files produced by the midnight rollover therefore stay in the directory
    of the day on which the function ran.

    Args:
        name: Name of the application logger to configure.

    Returns:
        The configured ``logging.Logger`` registered under *name*.
    """
    # Directory for the current day's log files.
    today = datetime.now().strftime('%Y-%m-%d')
    log_dir = os.path.join(LOG_BASE_DIR, today)
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    _reset_handlers(logger)

    # all.log and error.log share the same record layout.
    detailed_formatter = logging.Formatter(
        '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # 1. Everything, split at midnight.
    all_handler = TimedRotatingFileHandler(
        os.path.join(log_dir, 'all.log'),
        when='midnight',
        interval=1,
        backupCount=30,
        encoding='utf-8',
    )
    all_handler.setLevel(logging.DEBUG)
    all_handler.setFormatter(detailed_formatter)

    # 2. Errors only, split by size (10 MB, 10 backups).
    error_handler = RotatingFileHandler(
        os.path.join(log_dir, 'error.log'),
        maxBytes=10 * 1024 * 1024,
        backupCount=10,
        encoding='utf-8',
    )
    error_handler.setLevel(logging.ERROR)
    error_handler.setFormatter(detailed_formatter)

    # 3. Request log, split by size (10 MB, 10 backups).
    request_handler = RotatingFileHandler(
        os.path.join(log_dir, 'request.log'),
        maxBytes=10 * 1024 * 1024,
        backupCount=10,
        encoding='utf-8',
    )
    request_handler.setLevel(logging.INFO)
    request_handler.setFormatter(logging.Formatter(
        '[%(asctime)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    ))

    # 4. Console output, shared by both loggers.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(
        '[%(asctime)s] %(levelname)s: %(message)s',
        datefmt='%H:%M:%S'
    ))

    logger.addHandler(all_handler)
    logger.addHandler(error_handler)
    logger.addHandler(console_handler)

    # Requests go through a dedicated logger so they land in request.log.
    request_logger = logging.getLogger('request')
    request_logger.setLevel(logging.INFO)
    _reset_handlers(request_logger)
    request_logger.addHandler(request_handler)
    request_logger.addHandler(console_handler)

    return logger
# Initialise logging at import time and announce startup with a banner.
logger = setup_logger('app')
request_logger = logging.getLogger('request')
_banner_rule = '=' * 50
for _banner_line in (_banner_rule, '服务启动', _banner_rule):
    logger.info(_banner_line)
def get_db_connection():
"""获取数据库连接"""
@@ -47,7 +126,7 @@ def get_db_connection():
def init_database():
"""初始化数据库表"""
print("正在初始化数据库...")
logger.info("正在初始化数据库...")
try:
conn = get_db_connection()
cursor = conn.cursor()
@@ -60,6 +139,7 @@ def init_database():
base_model VARCHAR(255),
train_type VARCHAR(50),
train_method VARCHAR(50),
gpus JSON COMMENT 'GPU硬件选择支持多卡训练',
dataset_id INT,
valid_split VARCHAR(50),
valid_ratio INT DEFAULT 10,
@@ -212,31 +292,31 @@ def init_database():
for i, table_sql in enumerate(tables):
try:
cursor.execute(table_sql)
print(f" {i+1}/{len(tables)} 创建/检查成功")
logger.debug(f"{i+1}/{len(tables)} 创建/检查成功")
except Exception as e:
print(f" {i+1} 创建失败: {e}")
logger.error(f"{i+1} 创建失败: {e}")
# 为已存在的表添加缺失的列(静默处理,不显示重复列的提示)
for table_col in [("model_manage", "purpose"), ("model_eval", "name")]:
for table_col in [("model_manage", "purpose"), ("model_eval", "name"), ("fine_tune", "gpus")]:
try:
table_name, col_name = table_col
cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {col_name} VARCHAR(255) DEFAULT ''")
print(f" {table_name} 表添加 {col_name} 列成功")
except Exception as e:
cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {col_name} JSON")
logger.debug(f"{table_name} 表添加 {col_name} 列成功")
except Exception:
pass # 列已存在时不输出任何信息
# 插入默认管理员用户
cursor.execute("SELECT * FROM users WHERE username = 'admin'")
if not cursor.fetchone():
cursor.execute("INSERT INTO users (username, password, role) VALUES ('admin', 'admin', 'admin')")
print(" 默认管理员用户创建成功")
logger.info("默认管理员用户创建成功")
conn.commit()
cursor.close()
conn.close()
print("数据库初始化完成")
logger.info("数据库初始化完成")
except Exception as e:
print(f"数据库初始化失败: {e}")
logger.error(f"数据库初始化失败: {e}")
raise
@@ -250,6 +330,16 @@ CORS(app, origins="*", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allo
register_blueprints(app)
# ============ 请求日志中间件 ============
@app.after_request
def log_request(response):
    """Write one line per API request to the request logger.

    The line holds the HTTP method, the request path and the response status
    code. Requests to the health-check endpoint are not logged. The response
    is returned unchanged.
    """
    # Skip the health-check endpoint.
    if request.path == '/api/health':
        return response

    summary = f'{request.method} {request.path} - {response.status_code}'
    request_logger.info(summary)
    return response
# ============ 健康检查 ============
@app.route('/api/health', methods=['GET'])
def health_check():
@@ -272,6 +362,7 @@ def health_check():
}
})
except Exception as e:
logger.error(f"健康检查失败: {e}")
return jsonify({'status': 'error', 'code': 1, 'message': str(e)})
@@ -689,4 +780,9 @@ if __name__ == '__main__':
# 启动前先初始化数据库
init_database()
app_config = CONFIG['app']
app.run(host=app_config['host'], port=app_config['port'], debug=app_config.get('debug', True))
host = app_config['host']
port = app_config['port']
debug = app_config.get('debug', True)
logger.info(f'服务启动于 http://{host}:{port}')
logger.info(f'Debug模式: {debug}')
app.run(host=host, port=port, debug=debug)