模型开始训练界面以及查看日志功能完善
This commit is contained in:
@@ -219,10 +219,21 @@
|
||||
<div class="mb-6">
|
||||
<h3 class="text-sm font-semibold text-gray-700 mb-4 pb-2 border-b border-gray-100">基本信息</h3>
|
||||
<div class="mb-4">
|
||||
<label class="block text-sm text-gray-600 mb-3">任务名称</label>
|
||||
<label class="block text-sm text-gray-600 mb-3">
|
||||
任务名称
|
||||
<span class="text-gray-400 text-xs ml-1">(英文、数字、下划线)</span>
|
||||
</label>
|
||||
<div>
|
||||
<input type="text" name="name" class="w-full px-3 py-2 border border-gray-300 rounded-lg text-sm focus:border-primary focus:outline-none" placeholder="请输入任务名称" maxlength="50">
|
||||
<p class="text-xs text-gray-400 mt-1"><span id="nameCount">0</span> / 50</p>
|
||||
<p id="nameFormatError" class="text-xs text-red-500 mt-1 hidden">任务名称只能包含英文、数字和下划线</p>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-sm text-gray-600 mb-3">任务描述</label>
|
||||
<div>
|
||||
<textarea name="description" class="w-full px-3 py-2 border border-gray-300 rounded-lg text-sm focus:border-primary focus:outline-none resize-none" placeholder="请输入任务描述(选填)" maxlength="200" rows="3"></textarea>
|
||||
<p class="text-xs text-gray-400 mt-1"><span id="descriptionCount">0</span> / 200</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -466,16 +477,16 @@
|
||||
</tr>
|
||||
<tr class="hover:bg-blue-50/30 transition-colors">
|
||||
<td class="py-3 px-4">
|
||||
<span class="text-gray-700 font-mono text-sm">eval_steps</span>
|
||||
<span class="text-gray-700 font-mono text-sm">save_steps</span>
|
||||
<span class="text-red-500 ml-1">*</span>
|
||||
</td>
|
||||
<td class="py-3 px-4">
|
||||
<input type="number" name="eval_steps" value="100" min="10" max="10000" class="w-24 px-3 py-1.5 border border-gray-300 rounded-lg text-sm text-center focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20 transition-all">
|
||||
<input type="number" name="save_steps" value="100" min="10" max="10000" class="w-24 px-3 py-1.5 border border-gray-300 rounded-lg text-sm text-center focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20 transition-all">
|
||||
</td>
|
||||
<td class="py-3 px-4 text-xs text-gray-500">
|
||||
<span class="inline-flex items-center px-2 py-0.5 rounded bg-gray-100 text-gray-600 font-mono">[10, 10000]</span>
|
||||
</td>
|
||||
<td class="py-3 px-4 text-xs text-gray-500 leading-relaxed">每训练多少步进行一次模型评估,建议设置为100的倍数</td>
|
||||
<td class="py-3 px-4 text-xs text-gray-500 leading-relaxed">每训练多少步进行一次模型保存,建议设置为100的倍数</td>
|
||||
</tr>
|
||||
<tr class="hover:bg-blue-50/30 transition-colors">
|
||||
<td class="py-3 px-4">
|
||||
@@ -616,14 +627,7 @@
|
||||
<div class="mb-6">
|
||||
<h3 class="text-sm font-semibold text-gray-700 mb-4 pb-2 border-b border-gray-100">训练产出</h3>
|
||||
|
||||
<!-- 模型名称 -->
|
||||
<div class="mb-4">
|
||||
<label class="block text-sm text-gray-600 mb-3">模型名称</label>
|
||||
<div>
|
||||
<input type="text" name="output_model_name" class="w-64 px-3 py-2 border border-gray-300 rounded-lg text-sm focus:border-primary focus:outline-none" placeholder="请输入模型名称" maxlength="50">
|
||||
<p class="text-xs text-gray-400 mt-1"><span id="modelNameCount">0</span> / 50</p>
|
||||
</div>
|
||||
</div>
|
||||
<p class="text-sm text-gray-500 mb-4">训练完成后,模型将保存为: <code class="bg-gray-100 px-2 py-0.5 rounded text-primary" id="modelNamePreview">任务名称</code></p>
|
||||
|
||||
<!-- 训练命令预览 -->
|
||||
<div class="mt-4">
|
||||
@@ -678,16 +682,38 @@
|
||||
});
|
||||
});
|
||||
|
||||
// 任务名称字数统计
|
||||
// 任务名称字数统计和实时预览(只能输入英文、数字、下划线)
|
||||
const nameInput = document.querySelector('input[name="name"]');
|
||||
const nameFormatError = document.getElementById('nameFormatError');
|
||||
const nameRegex = /^[a-zA-Z0-9_]*$/;
|
||||
|
||||
nameInput.addEventListener('input', () => {
|
||||
const value = nameInput.value;
|
||||
// 验证格式
|
||||
if (value.length > 0 && !nameRegex.test(value)) {
|
||||
nameInput.classList.add('border-red-500');
|
||||
nameInput.classList.remove('border-gray-300');
|
||||
nameFormatError.classList.remove('hidden');
|
||||
} else {
|
||||
nameInput.classList.remove('border-red-500');
|
||||
nameInput.classList.add('border-gray-300');
|
||||
nameFormatError.classList.add('hidden');
|
||||
}
|
||||
// 过滤非法字符:只允许英文、数字、下划线
|
||||
const filteredValue = value.replace(/[^a-zA-Z0-9_]/g, '');
|
||||
if (value !== filteredValue) {
|
||||
nameInput.value = filteredValue;
|
||||
}
|
||||
document.getElementById('nameCount').textContent = nameInput.value.length;
|
||||
// 更新模型名称预览
|
||||
document.getElementById('modelNamePreview').textContent = nameInput.value || '任务名称';
|
||||
updateCommandPreview();
|
||||
});
|
||||
|
||||
// 模型名称字数统计
|
||||
const modelNameInput = document.querySelector('input[name="output_model_name"]');
|
||||
modelNameInput.addEventListener('input', () => {
|
||||
document.getElementById('modelNameCount').textContent = modelNameInput.value.length;
|
||||
// 任务描述字数统计
|
||||
const descInput = document.querySelector('textarea[name="description"]');
|
||||
descInput.addEventListener('input', () => {
|
||||
document.getElementById('descriptionCount').textContent = descInput.value.length;
|
||||
});
|
||||
|
||||
// 加载数据集列表
|
||||
@@ -774,7 +800,7 @@
|
||||
'batch_size': 1,
|
||||
'learning_rate': 0.0001,
|
||||
'n_epochs': 1,
|
||||
'eval_steps': 100,
|
||||
'save_steps': 100,
|
||||
'lr_scheduler_type': 'cosine',
|
||||
'max_length': 512,
|
||||
'warmup_ratio': 0.05,
|
||||
@@ -1014,7 +1040,7 @@
|
||||
batch_size: parseInt(formData.get('batch_size')) || 1,
|
||||
learning_rate: parseFloat(formData.get('learning_rate')) || 0.0001,
|
||||
n_epochs: parseFloat(formData.get('n_epochs')) || 1.0,
|
||||
eval_steps: parseInt(formData.get('eval_steps')) || 100,
|
||||
save_steps: parseInt(formData.get('save_steps')) || 100,
|
||||
lr_scheduler_type: formData.get('lr_scheduler_type') || 'cosine',
|
||||
max_length: parseInt(formData.get('max_length')) || 512,
|
||||
warmup_ratio: parseFloat(formData.get('warmup_ratio')) || 0.05,
|
||||
@@ -1024,15 +1050,18 @@
|
||||
lora_rank: formData.get('lora_rank') || '8'
|
||||
};
|
||||
|
||||
const taskName = formData.get('name');
|
||||
|
||||
const data = {
|
||||
name: formData.get('name'),
|
||||
name: taskName,
|
||||
description: formData.get('description'),
|
||||
base_model: formData.get('base_model'),
|
||||
template: formData.get('template'),
|
||||
train_type: formData.get('train_type'),
|
||||
train_method: formData.get('train_method'),
|
||||
gpus: selectedGPUs,
|
||||
train_dataset_id: formData.get('train_dataset_id'),
|
||||
output_model_name: formData.get('output_model_name'),
|
||||
output_model_name: taskName, // 使用任务名称作为模型名称
|
||||
...trainParams,
|
||||
status: 'pending',
|
||||
progress: 0
|
||||
@@ -1042,6 +1071,26 @@
|
||||
showMessage('提示', '请输入任务名称', 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
// 验证任务名称格式
|
||||
const nameRegex = /^[a-zA-Z0-9_]+$/;
|
||||
if (!nameRegex.test(data.name)) {
|
||||
showMessage('提示', '任务名称只能包含英文、数字和下划线', 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
// 检查任务名称是否重复
|
||||
try {
|
||||
const checkResponse = await fetch(`${API_BASE}/fine-tune/check-name?name=${encodeURIComponent(data.name)}`);
|
||||
const checkResult = await checkResponse.json();
|
||||
if (checkResult.code === 0 && checkResult.data.exists) {
|
||||
showMessage('提示', '任务名称已存在,请使用其他名称', 'warning');
|
||||
return;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('检查任务名称失败:', error);
|
||||
}
|
||||
|
||||
if (selectedGPUs.length === 0) {
|
||||
showMessage('提示', '请选择至少一个GPU硬件', 'warning');
|
||||
return;
|
||||
@@ -1060,6 +1109,12 @@
|
||||
}
|
||||
|
||||
try {
|
||||
// 显示加载中状态
|
||||
const submitBtn = document.querySelector('button[onclick="submitForm()"]');
|
||||
const originalText = submitBtn.innerHTML;
|
||||
submitBtn.disabled = true;
|
||||
submitBtn.innerHTML = '<i class="fa fa-spinner fa-spin mr-2"></i>训练任务创建中...';
|
||||
|
||||
// 第一步:创建训练任务记录
|
||||
const createResponse = await fetch(`${API_BASE}/fine-tune`, {
|
||||
method: 'POST',
|
||||
@@ -1068,6 +1123,8 @@
|
||||
});
|
||||
const createResult = await createResponse.json();
|
||||
if (createResult.code !== 0) {
|
||||
submitBtn.disabled = false;
|
||||
submitBtn.innerHTML = originalText;
|
||||
showMessage('错误', createResult.message || '创建任务失败', 'error');
|
||||
return;
|
||||
}
|
||||
@@ -1077,12 +1134,13 @@
|
||||
// 第二步:启动训练
|
||||
const startData = {
|
||||
task_id: taskId,
|
||||
name: data.name, // 任务名称,用于日志文件名和模型名称
|
||||
base_model: data.base_model,
|
||||
template: data.template,
|
||||
train_type: data.train_type,
|
||||
train_method: data.train_method,
|
||||
train_dataset_id: data.train_dataset_id,
|
||||
output_model_name: data.output_model_name,
|
||||
output_model_name: data.name, // 使用任务名称作为模型名称
|
||||
...trainParams
|
||||
};
|
||||
|
||||
@@ -1093,9 +1151,12 @@
|
||||
});
|
||||
const startResult = await startResponse.json();
|
||||
|
||||
// 恢复按钮状态
|
||||
submitBtn.disabled = false;
|
||||
submitBtn.innerHTML = originalText;
|
||||
|
||||
if (startResult.code === 0) {
|
||||
const cmd = startResult.data?.command || '';
|
||||
showMessage('成功', `训练任务已启动!<br><br><code class="text-xs bg-gray-100 p-1 rounded">${cmd}</code>`, 'success', () => {
|
||||
showMessage('成功', '训练任务已启动!', 'success', () => {
|
||||
window.location.href = 'main.html';
|
||||
});
|
||||
} else {
|
||||
@@ -1108,6 +1169,12 @@
|
||||
showMessage('错误', startResult.message || '启动训练失败', 'error');
|
||||
}
|
||||
} catch (error) {
|
||||
// 恢复按钮状态
|
||||
const submitBtn = document.querySelector('button[onclick="submitForm()"]');
|
||||
if (submitBtn) {
|
||||
submitBtn.disabled = false;
|
||||
submitBtn.innerHTML = '开始训练';
|
||||
}
|
||||
showMessage('错误', '操作失败: ' + error.message, 'error');
|
||||
}
|
||||
}
|
||||
@@ -1146,9 +1213,10 @@
|
||||
const trainMethod = formData.get('train_method') || 'lora';
|
||||
const methodMap = { 'lora': 'lora', 'full': 'full' };
|
||||
|
||||
// 获取输出模型名称
|
||||
const outputModelName = formData.get('output_model_name') || `${template}/${trainMethod}`;
|
||||
const outputDir = outputModelName.startsWith('./') ? outputModelName : `./saves/${outputModelName}`;
|
||||
// 获取输出模型名称(使用任务名称)
|
||||
const taskName = formData.get('name') || 'task_name';
|
||||
const outputModelName = taskName;
|
||||
const outputDir = outputModelName.startsWith('/') ? outputModelName : `/app/base/saves/${outputModelName}`;
|
||||
|
||||
// 获取数据集名称
|
||||
const trainDatasetSelect = form.querySelector('select[name="train_dataset_id"]');
|
||||
@@ -1167,7 +1235,7 @@
|
||||
const nEpochs = parseFloat(formData.get('n_epochs')) || 1.0;
|
||||
const maxLength = parseInt(formData.get('max_length')) || 512;
|
||||
const warmupSteps = parseInt(formData.get('warmup_steps')) || 20;
|
||||
const evalSteps = parseInt(formData.get('eval_steps')) || 100;
|
||||
const saveSteps = parseInt(formData.get('save_steps')) || 100;
|
||||
const gradientAccumulationSteps = parseInt(formData.get('gradient_accumulation_steps')) || 8;
|
||||
const lrSchedulerType = formData.get('lr_scheduler_type') || 'cosine';
|
||||
|
||||
@@ -1204,10 +1272,10 @@
|
||||
cmd += ` --lr_scheduler_type ${lrSchedulerType} \\\n`;
|
||||
cmd += ` --logging_steps 50 \\\n`;
|
||||
cmd += ` --warmup_steps ${warmupSteps} \\\n`;
|
||||
cmd += ` --save_steps 100 \\\n`;
|
||||
cmd += ` --eval_steps ${evalSteps} \\\n`;
|
||||
cmd += ` --save_steps ${saveSteps} \\\n`;
|
||||
cmd += ` --learning_rate ${learningRate} \\\n`;
|
||||
cmd += ` --num_train_epochs ${nEpochs}`;
|
||||
cmd += ` --num_train_epochs ${nEpochs} \\\n`;
|
||||
cmd += ` --plot_loss`;
|
||||
|
||||
return cmd;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user