模型微调已经调通
增加了参数预览
This commit is contained in:
60
src/main.py
60
src/main.py
@@ -86,6 +86,15 @@ def setup_logger(name='app'):
|
||||
datefmt='%H:%M:%S'
|
||||
))
|
||||
|
||||
# 5. 训练日志处理器 - 专门记录训练输出
|
||||
train_log_path = os.path.join(log_dir, 'train.log')
|
||||
train_handler = RotatingFileHandler(train_log_path, maxBytes=100*1024*1024, backupCount=5, encoding='utf-8')
|
||||
train_handler.setLevel(logging.INFO)
|
||||
train_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
))
|
||||
|
||||
# 添加处理器到 logger
|
||||
logger.addHandler(all_handler)
|
||||
logger.addHandler(error_handler)
|
||||
@@ -98,6 +107,13 @@ def setup_logger(name='app'):
|
||||
request_logger.addHandler(request_handler)
|
||||
request_logger.addHandler(console_handler)
|
||||
|
||||
# 为训练日志创建单独的 logger
|
||||
train_logger = logging.getLogger('train')
|
||||
train_logger.setLevel(logging.INFO)
|
||||
train_logger.handlers.clear()
|
||||
train_logger.addHandler(train_handler)
|
||||
train_logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@@ -137,6 +153,7 @@ def init_database():
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
base_model VARCHAR(255),
|
||||
template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等',
|
||||
train_type VARCHAR(50),
|
||||
train_method VARCHAR(50),
|
||||
gpus JSON COMMENT 'GPU硬件选择,支持多卡训练',
|
||||
@@ -144,6 +161,7 @@ def init_database():
|
||||
valid_split VARCHAR(50),
|
||||
valid_ratio INT DEFAULT 10,
|
||||
output_model_name VARCHAR(255),
|
||||
process_id INT COMMENT '训练进程ID',
|
||||
status VARCHAR(50) DEFAULT 'pending',
|
||||
progress INT DEFAULT 0,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
@@ -305,6 +323,44 @@ def init_database():
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加 template 列
|
||||
try:
|
||||
cursor.execute("ALTER TABLE fine_tune ADD COLUMN template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等'")
|
||||
logger.debug("fine_tune 表添加 template 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加 process_id 列
|
||||
try:
|
||||
cursor.execute("ALTER TABLE fine_tune ADD COLUMN process_id INT COMMENT '训练进程ID'")
|
||||
logger.debug("fine_tune 表添加 process_id 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 为 fine_tune 表添加训练相关列
|
||||
columns_to_add = [
|
||||
("train_dataset_id", "INT COMMENT '训练数据集ID'"),
|
||||
("valid_dataset_id", "INT COMMENT '验证数据集ID'"),
|
||||
("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"),
|
||||
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"),
|
||||
("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"),
|
||||
("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"),
|
||||
("batch_size", "INT DEFAULT 1 COMMENT '批次大小'"),
|
||||
("learning_rate", "FLOAT DEFAULT 0.0001 COMMENT '学习率'"),
|
||||
("n_epochs", "FLOAT DEFAULT 1.0 COMMENT '训练轮数'"),
|
||||
("max_length", "INT DEFAULT 512 COMMENT '最大长度'"),
|
||||
("lora_alpha", "VARCHAR(10) DEFAULT '32' COMMENT 'LoRA alpha'"),
|
||||
("lora_rank", "VARCHAR(10) DEFAULT '8' COMMENT 'LoRA rank'"),
|
||||
("lora_dropout", "FLOAT DEFAULT 0.1 COMMENT 'LoRA dropout'"),
|
||||
("valid_ratio", "INT DEFAULT 10 COMMENT '验证集比例'"),
|
||||
]
|
||||
for col_name, col_def in columns_to_add:
|
||||
try:
|
||||
cursor.execute(f"ALTER TABLE fine_tune ADD COLUMN {col_name} {col_def}")
|
||||
logger.debug(f"fine_tune 表添加 {col_name} 列成功")
|
||||
except Exception:
|
||||
pass # 列已存在时不输出任何信息
|
||||
|
||||
# 插入默认管理员用户
|
||||
cursor.execute("SELECT * FROM users WHERE username = 'admin'")
|
||||
if not cursor.fetchone():
|
||||
@@ -323,8 +379,8 @@ def init_database():
|
||||
app = Flask(__name__)
|
||||
app.config['SECRET_KEY'] = CONFIG['secret_key']
|
||||
app.config['CORS_HEADERS'] = 'Content-Type'
|
||||
# 使用字符串形式的 origins
|
||||
CORS(app, origins="*", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"], supports_credentials=False)
|
||||
# 允许所有来源
|
||||
CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])
|
||||
|
||||
# 注册蓝图
|
||||
register_blueprints(app)
|
||||
|
||||
Reference in New Issue
Block a user