// Source listing metadata: 488 lines, 16 KiB, JavaScript.
'use client';
|
||
|
||
import { useTranslation } from 'react-i18next';
|
||
import { toast } from 'sonner';
|
||
import axios from 'axios';
|
||
|
||
const useDatasetExport = projectId => {
|
||
const { t } = useTranslation();
|
||
|
||
// 优化的流式导出 - 使用 WritableStream 避免内存溢出
|
||
/**
 * Exports the project's datasets batch by batch so the full result set is never
 * serialized in one piece.
 *
 * When `window.showSaveFilePicker` is available the output is written straight into
 * the file the user picks. Otherwise (or when the picker is dismissed or fails) the
 * text is buffered in memory and handed to `downloadFromChunks` at the end.
 *
 * @param {Object} exportOptions - Export settings. Fields read here: `batchSize`,
 *   `fileFormat` ('json' | 'jsonl' | 'csv'), `formatType`, `balanceMode`,
 *   `balanceConfig`, `selectedIds`, `confirmedOnly` and `customFields.includeChunk`.
 *   The object is also passed on to the formatting helpers.
 * @param {Function} [onProgress] - Called after every batch with
 *   `{ processed, currentBatch, hasMore }`.
 * @returns {Promise<boolean>} `true` on success, `false` when any step failed (the
 *   error is logged and shown as a toast, never rethrown).
 */
const exportDatasetsStreaming = async (exportOptions, onProgress) => {
  // Declared outside the try block so the catch handler can abort a half-written file.
  let fileStream = null;

  try {
    const batchSize = exportOptions.batchSize || 1000;
    const fileFormat = exportOptions.fileFormat || 'json';
    const formatType = exportOptions.formatType || 'alpaca';

    // Build the suggested / downloaded file name.
    const formatSuffixMap = {
      alpaca: 'alpaca',
      multilingualthinking: 'multilingual-thinking',
      sharegpt: 'sharegpt',
      custom: 'custom'
    };
    const formatSuffix = formatSuffixMap[formatType] || formatType || 'export';
    const balanceSuffix = exportOptions.balanceMode ? '-balanced' : '';
    const dateStr = new Date().toISOString().slice(0, 10);
    const fileName = `datasets-${projectId}-${formatSuffix}${balanceSuffix}-${dateStr}.${fileFormat}`;

    try {
      if (window.showSaveFilePicker) {
        // Advertise a MIME type that matches the extension instead of always claiming JSON.
        const mimeType = fileFormat === 'csv' ? 'text/csv' : 'application/json';
        const handle = await window.showSaveFilePicker({
          suggestedName: fileName,
          types: [
            {
              description: 'Dataset File',
              accept: {
                [mimeType]: [`.${fileFormat}`]
              }
            }
          ]
        });
        fileStream = await handle.createWritable();
      }
    } catch (pickerError) {
      // Deliberate best-effort: a dismissed or failing picker falls back to the
      // in-memory download below instead of failing the export.
      fileStream = null;
      if (pickerError?.name !== 'AbortError') {
        console.warn('Save file picker unavailable, falling back to in-memory download:', pickerError);
      }
    }

    // Fallback buffer: every piece of output text, in order, when there is no file stream.
    const chunks = [];
    const writeText = async text => {
      if (fileStream) {
        await fileStream.write(text);
      } else {
        chunks.push(text);
      }
    };

    // File prologue: opening bracket of the JSON array, or the CSV header row.
    if (fileFormat === 'json') {
      await writeText('[\n');
    } else if (fileFormat === 'csv') {
      await writeText(`${getCSVHeaders(formatType, exportOptions).join(',')}\n`);
    }

    const apiUrl = `/api/projects/${projectId}/datasets/export`;
    let offset = 0;
    let hasMore = true;
    let totalProcessed = 0;
    // Tracks whether a record was actually emitted, so the JSON separator stays
    // correct even when the server returns an empty batch.
    let hasWrittenRecord = false;

    while (hasMore) {
      const requestBody = {
        batchMode: true,
        offset,
        batchSize
      };

      // An explicit selection takes precedence over the "confirmed only" filter.
      if (exportOptions.selectedIds && exportOptions.selectedIds.length > 0) {
        requestBody.selectedIds = exportOptions.selectedIds;
      } else if (exportOptions.confirmedOnly) {
        requestBody.status = 'confirmed';
      }

      if (exportOptions.balanceMode && exportOptions.balanceConfig) {
        requestBody.balanceMode = true;
        requestBody.balanceConfig = exportOptions.balanceConfig;
      }

      const response = await axios.post(apiUrl, requestBody);
      const batchResult = response.data;

      // Optionally attach the source chunk text to each record of this batch.
      if (exportOptions.customFields?.includeChunk && batchResult.data.length > 0) {
        const chunkNames = batchResult.data.map(item => item.chunkName).filter(Boolean);

        if (chunkNames.length > 0) {
          try {
            const chunkResponse = await axios.post(`/api/projects/${projectId}/chunks/batch-content`, {
              chunkNames
            });
            const chunkContentMap = chunkResponse.data;

            batchResult.data.forEach(item => {
              if (item.chunkName && chunkContentMap[item.chunkName]) {
                item.chunkContent = chunkContentMap[item.chunkName];
              }
            });
          } catch (chunkError) {
            // Best-effort: the export continues without chunk text.
            console.error('获取文本块内容失败:', chunkError);
          }
        }
      }

      const formattedBatch = formatDataBatch(batchResult.data, exportOptions);

      // Empty batches are skipped: writing them would leave a dangling comma in the
      // JSON array or a blank line in JSONL / CSV output.
      if (formattedBatch.length > 0) {
        if (fileFormat === 'json') {
          // Pretty-print each record on its own and indent it by two spaces so the
          // result looks like one 2-space-indented JSON array without ever
          // stringifying the whole data set at once.
          const batchContent = formattedBatch
            .map(item => `  ${JSON.stringify(item, null, 2).replace(/\n/g, '\n  ')}`)
            .join(',\n');
          await writeText(hasWrittenRecord ? `,\n${batchContent}` : batchContent);
        } else if (fileFormat === 'jsonl') {
          await writeText(`${formattedBatch.map(item => JSON.stringify(item)).join('\n')}\n`);
        } else if (fileFormat === 'csv') {
          await writeText(formatBatchToCSV(formattedBatch, formatType, exportOptions));
        }
        hasWrittenRecord = true;
      }

      // NOTE(review): the loop relies on the server advancing `offset` and eventually
      // returning `hasMore: false`; confirm against the export endpoint.
      hasMore = batchResult.hasMore;
      offset = batchResult.offset;
      totalProcessed += batchResult.data.length;

      if (onProgress) {
        onProgress({
          processed: totalProcessed,
          currentBatch: batchResult.data.length,
          hasMore
        });
      }

      // Short pause between requests to avoid hammering the API.
      if (hasMore) {
        await new Promise(resolve => setTimeout(resolve, 50));
      }
    }

    // File epilogue: close the JSON array.
    if (fileFormat === 'json') {
      await writeText('\n]\n');
    }

    if (fileStream) {
      await fileStream.close();
      fileStream = null;
    } else {
      downloadFromChunks(chunks, fileName);
    }

    toast.success(t('datasets.exportSuccess'));
    return true;
  } catch (error) {
    console.error('Streaming export failed:', error);

    // Release the writable stream so a failed export does not keep the file open.
    if (fileStream) {
      try {
        await fileStream.abort();
      } catch (abortError) {
        console.error('Failed to abort export stream:', abortError);
      }
    }

    toast.error(error.message || t('datasets.exportFailed'));
    return false;
  }
};
|
||
|
||
// 从内存块下载文件(优化版本,使用 Blob 流)
|
||
/**
 * Triggers a browser download of the buffered export text.
 *
 * @param {Array<string>} chunks - Pieces of file content, in output order; they are
 *   handed to the `Blob` constructor as-is.
 * @param {string} fileName - Name offered to the user for the downloaded file.
 */
const downloadFromChunks = (chunks, fileName) => {
  const blob = new Blob(chunks, { type: 'application/octet-stream' });
  const url = URL.createObjectURL(blob);

  // A temporary anchor carrying the `download` attribute is clicked programmatically.
  const anchor = document.createElement('a');
  anchor.href = url;
  anchor.download = fileName;
  document.body.appendChild(anchor);
  anchor.click();
  document.body.removeChild(anchor);

  // Revoke the object URL only after a delay so the download has a chance to start.
  setTimeout(() => URL.revokeObjectURL(url), 1000);
};
|
||
|
||
// 获取CSV表头
|
||
/**
 * Returns the ordered CSV column names for an export format.
 *
 * @param {string} [formatType] - 'alpaca', 'sharegpt', 'multilingualthinking' or
 *   'custom'. A missing value falls back to 'alpaca', the same default
 *   `formatDataBatch` applies, so headers and rows stay in sync.
 * @param {Object} exportOptions - Only read for the 'custom' format:
 *   `customFields` (`questionField`, `answerField`, `cotField`, `includeLabels`,
 *   `includeChunk`, `questionOnly`) and `includeCOT`.
 * @returns {Array<string>} A new array of column names; empty for an unknown format.
 */
const getCSVHeaders = (formatType, exportOptions) => {
  switch (formatType || 'alpaca') {
    case 'alpaca':
      return ['instruction', 'input', 'output', 'system'];
    case 'sharegpt':
      return ['messages'];
    case 'multilingualthinking':
      return ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'];
    case 'custom': {
      const { questionField, answerField, cotField, includeLabels, includeChunk, questionOnly } =
        exportOptions.customFields;
      const headers = [questionField];
      // Answer and chain-of-thought columns are dropped for question-only exports.
      if (!questionOnly) {
        headers.push(answerField);
        if (exportOptions.includeCOT && cotField) {
          headers.push(cotField);
        }
      }
      if (includeLabels) {
        headers.push('label');
      }
      if (includeChunk) {
        headers.push('chunk');
      }
      return headers;
    }
    default:
      return [];
  }
};
|
||
|
||
// 格式化数据批次
|
||
/**
 * Converts raw dataset records into the record shape of the selected export format.
 *
 * @param {Array<Object>} dataBatch - Records carrying `question`, `answer`, `cot` and,
 *   for the custom format, `questionLabel` and `chunkContent`.
 * @param {Object} exportOptions - Reads `formatType` (defaults to 'alpaca'),
 *   `alpacaFieldType`, `customInstruction`, `systemPrompt`, `includeCOT`,
 *   `reasoningLanguage` and `customFields`.
 * @returns {Array<Object>} The converted records; for an unrecognized format the input
 *   array itself is returned unchanged.
 */
const formatDataBatch = (dataBatch, exportOptions) => {
  const formatType = exportOptions.formatType || 'alpaca';

  // Prefixes the answer with the chain of thought in <think> tags when COT export is enabled.
  const answerWithThinking = (answer, cot) =>
    cot && exportOptions.includeCOT ? `<think>${cot}</think>\n${answer}` : answer;

  switch (formatType) {
    case 'alpaca': {
      // 'instruction' puts the question into `instruction`; otherwise the question
      // goes into `input` and `instruction` carries the configured custom instruction.
      const questionAsInstruction = exportOptions.alpacaFieldType === 'instruction';
      return dataBatch.map(({ question, answer, cot }) => ({
        instruction: questionAsInstruction ? question : exportOptions.customInstruction || '',
        input: questionAsInstruction ? '' : question,
        output: answerWithThinking(answer, cot),
        system: exportOptions.systemPrompt || ''
      }));
    }

    case 'sharegpt':
      return dataBatch.map(({ question, answer, cot }) => {
        const messages = [];
        // The system turn is only present when a system prompt is configured.
        if (exportOptions.systemPrompt) {
          messages.push({ role: 'system', content: exportOptions.systemPrompt });
        }
        messages.push({ role: 'user', content: question });
        messages.push({ role: 'assistant', content: answerWithThinking(answer, cot) });
        return { messages };
      });

    case 'multilingualthinking':
      return dataBatch.map(({ question, answer, cot }) => {
        // Here the chain of thought is stored in dedicated fields, not inlined in the answer.
        const thinking = exportOptions.includeCOT && cot ? cot : null;
        return {
          reasoning_language: exportOptions.reasoningLanguage || 'English',
          developer: exportOptions.systemPrompt || '',
          user: question,
          analysis: thinking,
          final: answer,
          messages: [
            { content: exportOptions.systemPrompt || '', role: 'system', thinking: null },
            { content: question, role: 'user', thinking: null },
            { content: answer, role: 'assistant', thinking }
          ]
        };
      });

    case 'custom': {
      const { questionField, answerField, cotField, includeLabels, includeChunk, questionOnly } =
        exportOptions.customFields;
      return dataBatch.map(({ question, answer, cot, questionLabel: labels, chunkContent }) => {
        const item = { [questionField]: question };
        if (!questionOnly) {
          item[answerField] = answer;
          if (cot && exportOptions.includeCOT && cotField) {
            item[cotField] = cot;
          }
        }
        if (includeLabels && labels && labels.length > 0) {
          // Keeps only the second space-separated token of the label string.
          // NOTE(review): presumably labels look like "<index> <name>"; a name that
          // itself contains spaces would be truncated - confirm the label format.
          item.label = labels.split(' ')[1];
        }
        if (includeChunk && chunkContent) {
          item.chunk = chunkContent;
        }
        return item;
      });
    }

    default:
      return dataBatch;
  }
};
|
||
|
||
// 将批次格式化为CSV行
|
||
/**
 * Serializes one batch of formatted records into CSV rows (no header row).
 *
 * @param {Array<Object>} formattedBatch - Records produced by `formatDataBatch`.
 * @param {string} formatType - Export format, forwarded to `getCSVHeaders` to decide
 *   the column order.
 * @param {Object} exportOptions - Export settings, forwarded to `getCSVHeaders`.
 * @returns {string} One line per record, each terminated by `\n`; an empty string for
 *   an empty batch so no blank line ends up in the file.
 */
const formatBatchToCSV = (formattedBatch, formatType, exportOptions) => {
  if (formattedBatch.length === 0) {
    return '';
  }

  const headers = getCSVHeaders(formatType, exportOptions);

  // Renders a single cell: null/undefined become empty, objects and arrays become JSON,
  // and anything containing a delimiter, quote or line break is quoted with "" escaping.
  const toCSVField = value => {
    // `typeof null` is 'object', so null must be handled before the JSON branch;
    // otherwise empty cells would be written as the literal text "null".
    if (value === null || value === undefined) {
      return '';
    }
    const text = typeof value === 'object' ? JSON.stringify(value) : String(value);
    return /[",\r\n]/.test(text) ? `"${text.replace(/"/g, '""')}"` : text;
  };

  const rows = formattedBatch.map(item => headers.map(header => toCSVField(item[header])).join(','));
  return `${rows.join('\n')}\n`;
};
|
||
|
||
// 处理和下载数据的通用函数(保留用于小数据量)
|
||
/**
 * Formats a complete (small) data set in memory and triggers its download.
 * Kept for the non-streaming export path.
 *
 * @param {Array<Object>} dataToExport - Raw dataset records, passed to `formatDataBatch`.
 * @param {Object} exportOptions - Reads `fileFormat` ('json' by default, 'jsonl' or
 *   'csv'), `formatType` ('alpaca' by default) and `balanceMode`; also forwarded to
 *   the formatting helpers.
 * @returns {Promise<void>}
 */
const processAndDownloadData = async (dataToExport, exportOptions) => {
  const formattedData = formatDataBatch(dataToExport, exportOptions);
  const fileFormat = exportOptions.fileFormat || 'json';
  // Resolve the default here as well so the CSV headers and the file name match the
  // records, which `formatDataBatch` formats as 'alpaca' when no format is given.
  const formatType = exportOptions.formatType || 'alpaca';

  let content;
  let fileExtension;

  if (fileFormat === 'jsonl') {
    content = formattedData.map(item => JSON.stringify(item)).join('\n');
    fileExtension = 'jsonl';
  } else if (fileFormat === 'csv') {
    const headers = getCSVHeaders(formatType, exportOptions);

    // null/undefined become empty cells (not the text "null"), objects become JSON,
    // and cells containing a delimiter, quote or line break are quoted.
    const toCSVField = value => {
      if (value === null || value === undefined) {
        return '';
      }
      const text = typeof value === 'object' ? JSON.stringify(value) : String(value);
      return /[",\r\n]/.test(text) ? `"${text.replace(/"/g, '""')}"` : text;
    };

    const csvRows = [
      headers.join(','),
      ...formattedData.map(item => headers.map(header => toCSVField(item[header])).join(','))
    ];
    content = csvRows.join('\n');
    fileExtension = 'csv';
  } else {
    content = JSON.stringify(formattedData, null, 2);
    fileExtension = 'json';
  }

  const formatSuffixMap = {
    alpaca: 'alpaca',
    multilingualthinking: 'multilingual-thinking',
    sharegpt: 'sharegpt',
    custom: 'custom'
  };
  const formatSuffix = formatSuffixMap[formatType] || formatType || 'export';
  const balanceSuffix = exportOptions.balanceMode ? '-balanced' : '';
  const dateStr = new Date().toISOString().slice(0, 10);
  const fileName = `datasets-${projectId}-${formatSuffix}${balanceSuffix}-${dateStr}.${fileExtension}`;

  // Reuse the shared download helper: it delays revoking the object URL, whereas
  // revoking it right after click() can cancel the download before it starts.
  downloadFromChunks([content], fileName);
};
|
||
|
||
// 导出数据集(保持向后兼容的原有功能)
|
||
/**
 * Exports datasets with a single request (the original, non-streaming path kept for
 * backward compatibility).
 *
 * @param {Object} exportOptions - Reads `selectedIds`, `confirmedOnly`, `balanceMode`
 *   and `balanceConfig` to build the request; the whole object is forwarded to
 *   `processAndDownloadData`.
 * @returns {Promise<boolean>} `true` on success, `false` on failure (the error is
 *   logged and shown as a toast, never rethrown).
 */
const exportDatasets = async exportOptions => {
  try {
    const apiUrl = `/api/projects/${projectId}/datasets/export`;
    const requestBody = {};

    // An explicit selection takes precedence over the "confirmed only" filter.
    if (exportOptions.selectedIds && exportOptions.selectedIds.length > 0) {
      requestBody.selectedIds = exportOptions.selectedIds;
    } else if (exportOptions.confirmedOnly) {
      requestBody.status = 'confirmed';
    }

    if (exportOptions.balanceMode && exportOptions.balanceConfig) {
      requestBody.balanceMode = true;
      requestBody.balanceConfig = exportOptions.balanceConfig;
    }

    const response = await axios.post(apiUrl, requestBody);
    await processAndDownloadData(response.data, exportOptions);

    toast.success(t('datasets.exportSuccess'));
    return true;
  } catch (error) {
    // Same reporting as the streaming export: log the cause and never show an empty toast.
    console.error('Export failed:', error);
    toast.error(error.message || t('datasets.exportFailed'));
    return false;
  }
};
|
||
|
||
// 导出平衡数据集
|
||
/**
 * Exports a balanced data set: the same as `exportDatasets`, with balance mode forced on.
 *
 * @param {Object} exportOptions - Export settings; `balanceConfig` is passed through
 *   as given. The caller's object is not modified.
 * @returns {Promise<boolean>} The result of `exportDatasets`.
 */
const exportBalancedDataset = async exportOptions => {
  const balancedOptions = {
    ...exportOptions,
    balanceMode: true,
    balanceConfig: exportOptions.balanceConfig
  };
  return exportDatasets(balancedOptions);
};
|
||
|
||
return {
|
||
exportDatasets,
|
||
exportBalancedDataset,
|
||
exportDatasetsStreaming
|
||
};
|
||
};
|
||
|
||
export default useDatasetExport;
|
||
export { useDatasetExport };
|