模型微调已经调通

增加了参数预览
This commit is contained in:
2026-01-28 10:31:09 +08:00
parent 8a638b6372
commit a560d24e2f
8 changed files with 898 additions and 96 deletions

View File

@@ -6,6 +6,7 @@ from .model_manage import model_manage_bp
from .model_chat import model_chat_bp
from .dimension import dimension_bp
from .logs import logs_bp
from .fine_tune import fine_tune_bp
# 注册所有蓝图
def register_blueprints(app):
@@ -15,3 +16,4 @@ def register_blueprints(app):
app.register_blueprint(model_chat_bp)
app.register_blueprint(dimension_bp)
app.register_blueprint(logs_bp)
app.register_blueprint(fine_tune_bp)

View File

@@ -5,6 +5,7 @@ import io
import os
import time
import zipfile
import json
from flask import Blueprint, request, jsonify, send_from_directory, Response
from werkzeug.utils import secure_filename
@@ -52,6 +53,45 @@ def allowed_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def update_dataset_info_json(dataset_name=None, actual_filename=None, remove_filename=None):
"""更新 datasets/dataset_info.json 文件
Args:
dataset_name: 数据集名称(用于作为 key
actual_filename: 实际保存的文件名(含时间戳前缀),用于 file_name
remove_filename: 要移除的文件名为None表示不移除
"""
info_path = os.path.join(DATASET_FOLDER, 'dataset_info.json')
# 读取现有配置
dataset_info = {}
if os.path.exists(info_path):
try:
with open(info_path, 'r', encoding='utf-8') as f:
dataset_info = json.load(f)
except Exception as e:
print(f"读取 dataset_info.json 失败: {e}")
# 移除旧条目(根据移除的文件名)
if remove_filename:
key = os.path.splitext(remove_filename)[0]
if key in dataset_info:
del dataset_info[key]
print(f"从 dataset_info.json 移除: {key}")
# 添加新条目
if dataset_name and actual_filename:
dataset_info[dataset_name] = {"file_name": actual_filename}
print(f"更新 dataset_info.json: {dataset_name} -> {actual_filename}")
# 写入文件
try:
with open(info_path, 'w', encoding='utf-8') as f:
json.dump(dataset_info, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"写入 dataset_info.json 失败: {e}")
def generic_get_by_id(table_name, id_val):
"""通用按ID查询"""
conn = get_db_connection()
@@ -149,17 +189,40 @@ def delete_dataset(id):
"""删除数据集"""
conn = get_db_connection()
cursor = conn.cursor()
# 获取文件路径列表
cursor.execute("SELECT file_path FROM dataset_files WHERE dataset_id = %s", (id,))
# 获取数据集名称(用于从 dataset_info.json 移除条目)
cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (id,))
dataset_result = cursor.fetchone()
dataset_name = dataset_result['name'] if dataset_result else None
# 获取文件信息列表(包含原始文件名)
cursor.execute("SELECT file_name, file_path FROM dataset_files WHERE dataset_id = %s", (id,))
files = cursor.fetchall()
# 删除文件
for f in files:
file_path = f.get('file_path')
if file_path and os.path.exists(file_path):
try:
os.remove(file_path)
except Exception as e:
print(f"删除文件失败: {file_path}, {e}")
# 尝试多个可能的路径
paths_to_try = []
if file_path:
paths_to_try.append(file_path)
# 尝试 PROJECT_ROOT 相对路径
rel_path = file_path.replace('/app/base', PROJECT_ROOT, 1) if file_path.startswith('/app/base') else None
if rel_path:
paths_to_try.append(rel_path)
for path in paths_to_try:
if path and os.path.exists(path):
try:
os.remove(path)
print(f"已删除文件: {path}")
break
except Exception as e:
print(f"删除文件失败: {path}, {e}")
# 使用数据集名称从 dataset_info.json 移除条目
if dataset_name:
update_dataset_info_json(remove_filename=dataset_name)
# 删除数据库记录
cursor.execute("DELETE FROM dataset_files WHERE dataset_id = %s", (id,))
cursor.execute("DELETE FROM dataset_manage WHERE id = %s", (id,))
@@ -218,6 +281,12 @@ def upload_dataset_file(dataset_id):
file.save(file_path)
file_size = os.path.getsize(file_path)
# 获取数据集名称用于作为 dataset_info.json 的 key
dataset_name = dataset.get('name') if dataset else None
# 更新 dataset_info.json使用数据集名称作为 key实际保存的文件名作为 file_name
update_dataset_info_json(dataset_name=dataset_name, actual_filename=new_filename)
# 获取文件扩展名(安全处理无扩展名的情况)
parts = filename.rsplit('.', 1)
ext = parts[1].lower() if len(parts) > 1 else 'unknown'
@@ -304,8 +373,8 @@ def delete_dataset_file(file_id):
conn = get_db_connection()
cursor = conn.cursor()
# 获取文件信息
cursor.execute("SELECT dataset_id, file_path FROM dataset_files WHERE id = %s", (file_id,))
# 获取文件信息(包含 dataset_id
cursor.execute("SELECT dataset_id, file_name, file_path FROM dataset_files WHERE id = %s", (file_id,))
file_info = cursor.fetchone()
if not file_info:
@@ -313,6 +382,13 @@ def delete_dataset_file(file_id):
conn.close()
return jsonify({'code': 1, 'message': '文件不存在'})
dataset_id = file_info['dataset_id']
# 获取数据集名称(用于从 dataset_info.json 移除条目)
cursor.execute("SELECT name FROM dataset_manage WHERE id = %s", (dataset_id,))
dataset_result = cursor.fetchone()
dataset_name = dataset_result['name'] if dataset_result else None
# 删除物理文件
file_path = file_info['file_path']
if file_path and os.path.exists(file_path):
@@ -324,8 +400,11 @@ def delete_dataset_file(file_id):
# 删除数据库记录
cursor.execute("DELETE FROM dataset_files WHERE id = %s", (file_id,))
# 使用数据集名称从 dataset_info.json 移除条目
if dataset_name:
update_dataset_info_json(remove_filename=dataset_name)
# 更新数据集的文件数量和大小
dataset_id = file_info['dataset_id']
cursor.execute("SELECT COUNT(*) as count, SUM(file_size) as total_size FROM dataset_files WHERE dataset_id = %s", (dataset_id,))
result = cursor.fetchone()
file_count = result['count'] or 0

443
src/api/fine_tune.py Normal file
View File

@@ -0,0 +1,443 @@
"""
精调训练 API 路由
调用 llamafactory-cli 执行训练任务
"""
import os
import subprocess
import json
import threading
import time
from flask import Blueprint, request, jsonify
import logging
logger = logging.getLogger(__name__)
train_logger = logging.getLogger('train') # 专门的训练日志 logger输出到 train.log
# 创建蓝图
fine_tune_bp = Blueprint('fine_tune', __name__, url_prefix='/api/fine-tune')
# 训练类型映射
TRAIN_TYPE_MAP = {
'SFT': 'sft',
'DPO': 'dpo',
'CPT': 'cpt'
}
# 训练方法映射
FINETUNING_TYPE_MAP = {
'lora': 'lora',
'full': 'full'
}
@fine_tune_bp.route('/start', methods=['POST'])
def start_training():
"""启动 llamafactory 训练任务"""
try:
data = request.json
train_logger.info(f"[TRAIN] ========== 开始训练任务 ==========")
train_logger.info(f"[TRAIN] 收到启动训练请求: base_model={data.get('base_model')}, train_dataset_id={data.get('train_dataset_id')}")
# 必填参数验证
required_fields = ['base_model', 'template', 'train_dataset_id']
for field in required_fields:
if not data.get(field):
return jsonify({'code': 1, 'message': f'缺少必要参数: {field}'})
# 获取模型信息
model_path = data.get('base_model')
# 尝试转换为整数
try:
model_id = int(model_path) if str(model_path).isdigit() else None
except (ValueError, TypeError):
model_id = None
if model_id:
# 如果是 model_id需要获取模型路径
from .model_manage import get_db_connection
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT id, name, path FROM model_manage WHERE id = %s", (model_id,))
result = cursor.fetchone()
conn.close()
logger.info(f"模型查询结果: {result}")
if result and result.get('path'):
model_path = result['path']
logger.info(f"从数据库获取的模型路径: {model_path}")
else:
return jsonify({'code': 1, 'message': '模型不存在或路径为空'})
elif not model_path:
return jsonify({'code': 1, 'message': f'模型路径为空'})
train_logger.info(f"[TRAIN] 模型路径: {model_path}")
# 设置工作目录为 llamafactory 目录
llamafactory_dir = '/app/src/llamafactory'
# 处理数据集文件:将数据集复制到 llamafactory 的 datasets 目录
dataset_id = data.get('train_dataset_id')
try:
dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
except (ValueError, TypeError):
dataset_id_int = None
llamafactory_datasets_dir = os.path.join(llamafactory_dir, 'datasets')
os.makedirs(llamafactory_datasets_dir, exist_ok=True)
# 获取数据集名称(用于 --dataset 参数)
dataset_key = None
if dataset_id_int:
from .datasets import get_db_connection as get_dataset_conn
conn = get_dataset_conn()
cursor = conn.cursor()
cursor.execute("SELECT dm.name FROM dataset_manage dm WHERE dm.id = %s", (dataset_id_int,))
dataset_result = cursor.fetchone()
conn.close()
dataset_key = dataset_result['name'] if dataset_result else None
if dataset_key:
# 从 dataset_info.json 读取实际文件名
src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
actual_file_name = None
if os.path.exists(src_info_json):
import json as json_lib
with open(src_info_json, 'r', encoding='utf-8') as f:
dataset_info = json_lib.load(f)
if dataset_key in dataset_info:
actual_file_name = dataset_info[dataset_key].get('file_name')
train_logger.info(f"[TRAIN] 从 dataset_info.json 获取文件名: {dataset_key} -> {actual_file_name}")
# 复制数据集文件到 llamafactory 目录
if actual_file_name:
src_file = os.path.join('/app/base', 'datasets', actual_file_name)
dst_file = os.path.join(llamafactory_datasets_dir, actual_file_name)
if os.path.exists(src_file):
import shutil
shutil.copy2(src_file, dst_file)
train_logger.info(f"[TRAIN] 复制数据集文件: {src_file} -> {dst_file}")
else:
train_logger.warning(f"[TRAIN] 数据集文件不存在: {src_file}")
# 复制 dataset_info.json 到 llamafactory datasets 目录
src_info_json = os.path.join('/app/base', 'datasets', 'dataset_info.json')
dst_info_json = os.path.join(llamafactory_datasets_dir, 'dataset_info.json')
try:
if os.path.exists(src_info_json):
shutil.copy2(src_info_json, dst_info_json)
train_logger.info(f"[TRAIN] 已复制 dataset_info.json 到 llamafactory 目录")
else:
train_logger.warning(f"[TRAIN] dataset_info.json 不存在: {src_info_json}")
except Exception as e:
train_logger.warning(f"[TRAIN] 复制 dataset_info.json 失败: {e}")
# 获取选中的 GPU 索引
gpus = data.get('gpus', [])
if gpus:
gpu_ids = [gpu.get('id', '').replace('gpu', '') for gpu in gpus]
gpu_ids = [g for g in gpu_ids if g.isdigit()]
cuda_devices = ','.join(gpu_ids)
else:
cuda_devices = '0'
# 设置环境变量
env = os.environ.copy()
env['CUDA_VISIBLE_DEVICES'] = cuda_devices
env['TF_CPP_MIN_LOG_LEVEL'] = '2' # 减少 TensorFlow 日志
# 构建 llamafactory-cli 命令(传入数据集名称用于 --dataset 参数)
cmd = build_train_command(data, model_path, dataset_key)
cmd_str = ' '.join(cmd)
train_logger.info(f"[TRAIN] 执行训练命令: {cmd_str}")
# 在返回的命令中显示 GPU 配置
cmd_str_with_gpu = f"CUDA_VISIBLE_DEVICES={cuda_devices} {cmd_str}"
# 生成训练日志文件路径(按日期分目录)
from datetime import datetime
today = datetime.now().strftime('%Y-%m-%d')
task_id_str = str(data.get('task_id', 'unknown'))
log_dir = os.path.join(llamafactory_dir, 'logs', today)
train_output_log = os.path.join(log_dir, f'train_{task_id_str}.log')
# 确保日志目录存在
os.makedirs(log_dir, exist_ok=True)
train_logger.info(f"[TRAIN] 启动训练进程...")
# 使用线程在后台运行训练进程
def run_training():
with open(train_output_log, 'w', encoding='utf-8') as log_file:
process = subprocess.Popen(
cmd,
cwd=llamafactory_dir,
stdout=log_file,
stderr=subprocess.STDOUT,
env=env
)
train_logger.info(f"[TRAIN] 训练进程 PID: {process.pid}")
# 等待进程完成
process.wait()
train_logger.info(f"[TRAIN] 训练进程已结束,退出码: {process.returncode}")
# 更新任务状态
final_status = 'completed' if process.returncode == 0 else 'failed'
update_fine_tune_status(data.get('task_id'), final_status, process.pid)
# 启动后台线程
training_thread = threading.Thread(target=run_training, daemon=True)
training_thread.start()
# 立即返回,不等待进程完成
pid = None # 此时还不知道实际 PID稍后可从日志获取
train_logger.info(f"[TRAIN] 训练任务已在后台启动")
train_logger.info(f"[TRAIN] 训练日志输出到: {train_output_log}")
# 更新任务状态为运行中
update_fine_tune_status(data.get('task_id'), 'running', 0)
return jsonify({
'code': 0,
'message': f'训练任务已启动 (GPU: {cuda_devices})',
'data': {
'task_id': data.get('task_id'),
'gpu_ids': cuda_devices,
'command': cmd_str_with_gpu,
'log_file': train_output_log
}
})
except Exception as e:
import traceback
train_logger.error(f"[TRAIN] 启动训练任务失败: {e}")
train_logger.error(f"[TRAIN] 详细错误: {traceback.format_exc()}")
return jsonify({'code': 1, 'message': str(e)})
def build_train_command(data, model_path, dataset_name=None):
"""构建 llamafactory-cli train 命令"""
# llamafactory-cli 路径(已在系统 PATH 中)
cmd = ['llamafactory-cli', 'train']
# 训练阶段
train_type = data.get('train_type', 'SFT')
cmd.extend(['--stage', TRAIN_TYPE_MAP.get(train_type, 'sft')])
cmd.append('--do_train')
# 模型路径
cmd.extend(['--model_name_or_path', model_path])
# 数据集 - 使用数据集名称dataset_manage.name不是实际文件名
if dataset_name:
cmd.extend(['--dataset', dataset_name])
train_logger.info(f"[TRAIN] 使用数据集名称: {dataset_name}")
else:
# 回退到原有逻辑
dataset_id = data.get('train_dataset_id')
try:
dataset_id_int = int(dataset_id) if str(dataset_id).isdigit() else None
except (ValueError, TypeError):
dataset_id_int = None
if dataset_id_int:
dataset_name = get_dataset_name(dataset_id_int)
train_logger.info(f"[TRAIN] 从数据库获取的数据集名称: {dataset_name}")
else:
dataset_name = dataset_id
cmd.extend(['--dataset', dataset_name])
# 数据集目录
cmd.extend(['--dataset_dir', './datasets']) # llamafactory 工作目录下的 datasets 目录
# 模板
template = data.get('template')
cmd.extend(['--template', template])
# 训练方法
train_method = data.get('train_method', 'lora')
cmd.extend(['--finetuning_type', FINETUNING_TYPE_MAP.get(train_method, 'lora')])
# 输出目录
output_dir = data.get('output_model_name', f"./saves/{template}/{train_method}")
if not output_dir.startswith('./'):
output_dir = f"./saves/{output_dir}"
cmd.extend(['--output_dir', output_dir])
# 常用参数
cmd.extend([
'--overwrite_cache',
'--overwrite_output_dir',
'--cutoff_len', str(data.get('max_length', 512)),
'--preprocessing_num_workers', '16',
'--per_device_train_batch_size', str(data.get('batch_size', 1)),
'--per_device_eval_batch_size', '1',
'--gradient_accumulation_steps', str(data.get('gradient_accumulation_steps', 8)),
'--lr_scheduler_type', data.get('lr_scheduler_type', 'cosine'),
'--logging_steps', '50',
'--warmup_steps', str(data.get('warmup_steps', 20)),
'--save_steps', '100',
'--eval_steps', str(data.get('eval_steps', 100)),
])
# 学习率
cmd.extend(['--learning_rate', str(data.get('learning_rate', 0.0001))])
# 训练轮数
cmd.extend(['--num_train_epochs', str(data.get('n_epochs', 1.0))])
# 验证集比例
val_ratio = data.get('valid_ratio', 0)
if val_ratio > 0:
cmd.extend(['--val_size', str(val_ratio / 100)])
# 最大样本数
if data.get('max_samples'):
cmd.extend(['--max_samples', str(data.get('max_samples'))])
# 其他选项
if data.get('plot_loss'):
cmd.append('--plot_loss')
if data.get('fp16'):
cmd.append('--fp16')
if data.get('load_best_model_at_end'):
cmd.append('--load_best_model_at_end')
return cmd
def get_dataset_name(dataset_id):
"""根据数据集 ID 获取数据集名称"""
try:
from .datasets import get_db_connection
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT id, name FROM dataset_manage WHERE id = %s", (dataset_id,))
result = cursor.fetchone()
conn.close()
logger.info(f"数据集查询结果: {result}")
if result and result.get('name'):
return result['name']
logger.warning(f"未找到数据集 ID={dataset_id},使用默认值")
return 'default'
except Exception as e:
logger.error(f"查询数据集失败: {e}")
return 'default'
def update_fine_tune_status(task_id, status, pid=None):
"""更新训练任务状态"""
try:
from .model_manage import get_db_connection
conn = get_db_connection()
cursor = conn.cursor()
if status == 'running' and pid:
cursor.execute(
"UPDATE fine_tune SET status = %s, process_id = %s WHERE id = %s",
(status, pid, task_id)
)
else:
cursor.execute(
"UPDATE fine_tune SET status = %s WHERE id = %s",
(status, task_id)
)
conn.commit()
conn.close()
except Exception as e:
logger.error(f"更新任务状态失败: {e}")
@fine_tune_bp.route('/stop/<int:task_id>', methods=['POST'])
def stop_training(task_id):
"""停止训练任务"""
try:
from .model_manage import get_db_connection
# 获取进程 ID
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT process_id FROM fine_tune WHERE id = %s", (task_id,))
result = cursor.fetchone()
conn.close()
if result and result.get('process_id'):
pid = result['process_id']
try:
# 尝试终止进程
import signal
os.kill(pid, signal.SIGTERM)
logger.info(f"已终止训练进程 PID: {pid}")
except ProcessLookupError:
logger.warning(f"进程 {pid} 不存在")
except PermissionError:
logger.error(f"没有权限终止进程 {pid}")
# 更新状态
update_fine_tune_status(task_id, 'stopped')
return jsonify({'code': 0, 'message': '训练任务已停止'})
except Exception as e:
logger.error(f"停止训练任务失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
@fine_tune_bp.route('/status/<int:task_id>', methods=['GET'])
def get_training_status(task_id):
"""获取训练任务状态"""
try:
from .model_manage import get_db_connection
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
"SELECT id, name, status, progress, process_id FROM fine_tune WHERE id = %s",
(task_id,)
)
result = cursor.fetchone()
conn.close()
if result:
return jsonify({
'code': 0,
'data': {
'task_id': result['id'],
'name': result['name'],
'status': result['status'],
'progress': result['progress'],
'pid': result.get('process_id')
}
})
else:
return jsonify({'code': 1, 'message': '任务不存在'})
except Exception as e:
logger.error(f"获取任务状态失败: {e}")
return jsonify({'code': 1, 'message': str(e)})
def get_db_connection():
"""获取数据库连接"""
import pymysql
import yaml
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
CONFIG = yaml.safe_load(f)
db_config = CONFIG['database']
return pymysql.connect(
host=db_config['host'],
port=db_config['port'],
user=db_config['username'],
password=db_config['password'],
database=db_config['name'],
charset=db_config.get('charset', 'utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)

View File

@@ -6,8 +6,12 @@ import pymysql
import yaml
from flask import Blueprint, request, jsonify
# 获取项目根目录
# 获取项目根目录 - 优先使用环境变量,否则从文件路径计算
MOUNT_BASE = os.environ.get('MOUNT_BASE', '/app/base')
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 如果 PROJECT_ROOT 是 /app 或 /app/src/llamafactory则使用挂载路径
if PROJECT_ROOT in ('/app', '/app/src/llamafactory'):
PROJECT_ROOT = MOUNT_BASE
# 创建蓝图
model_manage_bp = Blueprint('model_manage', __name__, url_prefix='/api/model-manage')

View File

@@ -86,6 +86,15 @@ def setup_logger(name='app'):
datefmt='%H:%M:%S'
))
# 5. 训练日志处理器 - 专门记录训练输出
train_log_path = os.path.join(log_dir, 'train.log')
train_handler = RotatingFileHandler(train_log_path, maxBytes=100*1024*1024, backupCount=5, encoding='utf-8')
train_handler.setLevel(logging.INFO)
train_handler.setFormatter(logging.Formatter(
'[%(asctime)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
))
# 添加处理器到 logger
logger.addHandler(all_handler)
logger.addHandler(error_handler)
@@ -98,6 +107,13 @@ def setup_logger(name='app'):
request_logger.addHandler(request_handler)
request_logger.addHandler(console_handler)
# 为训练日志创建单独的 logger
train_logger = logging.getLogger('train')
train_logger.setLevel(logging.INFO)
train_logger.handlers.clear()
train_logger.addHandler(train_handler)
train_logger.addHandler(console_handler)
return logger
@@ -137,6 +153,7 @@ def init_database():
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(255) NOT NULL,
base_model VARCHAR(255),
template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等',
train_type VARCHAR(50),
train_method VARCHAR(50),
gpus JSON COMMENT 'GPU硬件选择支持多卡训练',
@@ -144,6 +161,7 @@ def init_database():
valid_split VARCHAR(50),
valid_ratio INT DEFAULT 10,
output_model_name VARCHAR(255),
process_id INT COMMENT '训练进程ID',
status VARCHAR(50) DEFAULT 'pending',
progress INT DEFAULT 0,
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
@@ -305,6 +323,44 @@ def init_database():
except Exception:
pass # 列已存在时不输出任何信息
# 为 fine_tune 表添加 template 列
try:
cursor.execute("ALTER TABLE fine_tune ADD COLUMN template VARCHAR(100) COMMENT '训练模板,如 qwen, llama, chatglm 等'")
logger.debug("fine_tune 表添加 template 列成功")
except Exception:
pass # 列已存在时不输出任何信息
# 为 fine_tune 表添加 process_id 列
try:
cursor.execute("ALTER TABLE fine_tune ADD COLUMN process_id INT COMMENT '训练进程ID'")
logger.debug("fine_tune 表添加 process_id 列成功")
except Exception:
pass # 列已存在时不输出任何信息
# 为 fine_tune 表添加训练相关列
columns_to_add = [
("train_dataset_id", "INT COMMENT '训练数据集ID'"),
("valid_dataset_id", "INT COMMENT '验证数据集ID'"),
("eval_steps", "INT DEFAULT 100 COMMENT '评估步数'"),
("lr_scheduler_type", "VARCHAR(50) DEFAULT 'cosine' COMMENT '学习率调度器'"),
("warmup_ratio", "FLOAT DEFAULT 0.05 COMMENT '预热比例'"),
("weight_decay", "FLOAT DEFAULT 0.01 COMMENT '权重衰减'"),
("batch_size", "INT DEFAULT 1 COMMENT '批次大小'"),
("learning_rate", "FLOAT DEFAULT 0.0001 COMMENT '学习率'"),
("n_epochs", "FLOAT DEFAULT 1.0 COMMENT '训练轮数'"),
("max_length", "INT DEFAULT 512 COMMENT '最大长度'"),
("lora_alpha", "VARCHAR(10) DEFAULT '32' COMMENT 'LoRA alpha'"),
("lora_rank", "VARCHAR(10) DEFAULT '8' COMMENT 'LoRA rank'"),
("lora_dropout", "FLOAT DEFAULT 0.1 COMMENT 'LoRA dropout'"),
("valid_ratio", "INT DEFAULT 10 COMMENT '验证集比例'"),
]
for col_name, col_def in columns_to_add:
try:
cursor.execute(f"ALTER TABLE fine_tune ADD COLUMN {col_name} {col_def}")
logger.debug(f"fine_tune 表添加 {col_name} 列成功")
except Exception:
pass # 列已存在时不输出任何信息
# 插入默认管理员用户
cursor.execute("SELECT * FROM users WHERE username = 'admin'")
if not cursor.fetchone():
@@ -323,8 +379,8 @@ def init_database():
app = Flask(__name__)
app.config['SECRET_KEY'] = CONFIG['secret_key']
app.config['CORS_HEADERS'] = 'Content-Type'
# 使用字符串形式的 origins
CORS(app, origins="*", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"], supports_credentials=False)
# 允许所有来源
CORS(app, resources={r"/api/*": {"origins": "*"}}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])
# 注册蓝图
register_blueprints(app)