'use client'; import { useTranslation } from 'react-i18next'; import { toast } from 'sonner'; import axios from 'axios'; const useDatasetExport = projectId => { const { t } = useTranslation(); // 优化的流式导出 - 使用 WritableStream 避免内存溢出 const exportDatasetsStreaming = async (exportOptions, onProgress) => { try { const batchSize = exportOptions.batchSize || 1000; let offset = 0; let hasMore = true; let totalProcessed = 0; let isFirstBatch = true; // 确定文件格式 const fileFormat = exportOptions.fileFormat || 'json'; const formatType = exportOptions.formatType || 'alpaca'; // 生成文件名 const formatSuffixMap = { alpaca: 'alpaca', multilingualthinking: 'multilingual-thinking', sharegpt: 'sharegpt', custom: 'custom' }; const formatSuffix = formatSuffixMap[formatType] || formatType || 'export'; const balanceSuffix = exportOptions.balanceMode ? '-balanced' : ''; const dateStr = new Date().toISOString().slice(0, 10); const fileName = `datasets-${projectId}-${formatSuffix}${balanceSuffix}-${dateStr}.${fileFormat}`; // 创建可写流 let fileStream; let writer; try { // 使用 showSaveFilePicker API(现代浏览器) if (window.showSaveFilePicker) { const handle = await window.showSaveFilePicker({ suggestedName: fileName, types: [ { description: 'Dataset File', accept: { 'application/json': [`.${fileFormat}`] } } ] }); fileStream = await handle.createWritable(); } else { // 降级方案:使用内存缓冲区(但分块处理) fileStream = null; } } catch (err) { // 用户取消或不支持,使用降级方案 fileStream = null; } // 如果不支持流式写入,使用分块累积方案 let chunks = []; let chunkCount = 0; const MAX_CHUNKS_IN_MEMORY = 5; // 最多在内存中保留5批数据 // 写入文件头(JSON数组开始或CSV表头) if (fileFormat === 'json') { if (fileStream) { await fileStream.write('[\n'); } else { chunks.push('[\n'); } } else if (fileFormat === 'csv') { // 写入CSV表头 const headers = getCSVHeaders(formatType, exportOptions); const headerLine = headers.join(',') + '\n'; if (fileStream) { await fileStream.write(headerLine); } else { chunks.push(headerLine); } } // 分批获取和写入数据 while (hasMore) { const apiUrl = `/api/projects/${projectId}/datasets/export`; const requestBody = { batchMode: true, offset: offset, batchSize: batchSize }; // 如果有选中的数据集 ID,传递 ID 列表 if (exportOptions.selectedIds && exportOptions.selectedIds.length > 0) { requestBody.selectedIds = exportOptions.selectedIds; } else if (exportOptions.confirmedOnly) { requestBody.status = 'confirmed'; } // 检查是否是平衡导出模式 if (exportOptions.balanceMode && exportOptions.balanceConfig) { requestBody.balanceMode = true; requestBody.balanceConfig = exportOptions.balanceConfig; } const response = await axios.post(apiUrl, requestBody); const batchResult = response.data; // 如果需要包含文本块内容,批量查询并填充 if (exportOptions.customFields?.includeChunk && batchResult.data.length > 0) { const chunkNames = batchResult.data.map(item => item.chunkName).filter(name => name); if (chunkNames.length > 0) { try { const chunkResponse = await axios.post(`/api/projects/${projectId}/chunks/batch-content`, { chunkNames }); const chunkContentMap = chunkResponse.data; batchResult.data.forEach(item => { if (item.chunkName && chunkContentMap[item.chunkName]) { item.chunkContent = chunkContentMap[item.chunkName]; } }); } catch (chunkError) { console.error('获取文本块内容失败:', chunkError); } } } // 转换当前批次数据 const formattedBatch = formatDataBatch(batchResult.data, exportOptions); // 写入当前批次 if (fileFormat === 'json') { // 保持与原逻辑一致:JSON 导出为“格式化后的 JSON 数组”(2空格缩进) // 每条记录单独 stringify + 缩进,并在数组级别拼接,避免一次性 stringify 全量数据导致内存暴涨 const batchContent = formattedBatch .map(item => { const pretty = JSON.stringify(item, null, 2); // 将对象的每一行整体再缩进 2 个空格,以符合数组元素缩进 return ' ' + pretty.replace(/\n/g, '\n '); }) .join(',\n'); const content = isFirstBatch ? batchContent : ',\n' + batchContent; if (fileStream) { await fileStream.write(content); } else { chunks.push(content); chunkCount++; } } else if (fileFormat === 'jsonl') { const batchContent = formattedBatch.map(item => JSON.stringify(item)).join('\n') + '\n'; if (fileStream) { await fileStream.write(batchContent); } else { chunks.push(batchContent); chunkCount++; } } else if (fileFormat === 'csv') { const batchContent = formatBatchToCSV(formattedBatch, formatType, exportOptions); if (fileStream) { await fileStream.write(batchContent); } else { chunks.push(batchContent); chunkCount++; } } // 如果使用内存缓冲且累积了足够多的块,触发部分下载 if (!fileStream && chunkCount >= MAX_CHUNKS_IN_MEMORY) { // 这里我们仍然需要等到最后才能下载,但至少限制了内存使用 // 可以考虑使用 Blob 分片 } hasMore = batchResult.hasMore; offset = batchResult.offset; totalProcessed += batchResult.data.length; isFirstBatch = false; // 通知进度更新 if (onProgress) { onProgress({ processed: totalProcessed, currentBatch: batchResult.data.length, hasMore }); } // 避免过快请求 if (hasMore) { await new Promise(resolve => setTimeout(resolve, 50)); } } // 写入文件尾 if (fileFormat === 'json') { if (fileStream) { await fileStream.write('\n]\n'); await fileStream.close(); } else { chunks.push('\n]\n'); } } else { if (fileStream) { await fileStream.close(); } } // 如果使用内存缓冲方案,现在触发下载 if (!fileStream) { downloadFromChunks(chunks, fileName); } toast.success(t('datasets.exportSuccess')); return true; } catch (error) { console.error('Streaming export failed:', error); toast.error(error.message || t('datasets.exportFailed')); return false; } }; // 从内存块下载文件(优化版本,使用 Blob 流) const downloadFromChunks = (chunks, fileName) => { // 使用 Blob 构造函数,它会自动处理大数据 const blob = new Blob(chunks, { type: 'application/octet-stream' }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; a.download = fileName; document.body.appendChild(a); a.click(); document.body.removeChild(a); // 延迟释放 URL,确保下载开始 setTimeout(() => URL.revokeObjectURL(url), 1000); }; // 获取CSV表头 const getCSVHeaders = (formatType, exportOptions) => { if (formatType === 'alpaca') { return ['instruction', 'input', 'output', 'system']; } else if (formatType === 'sharegpt') { return ['messages']; } else if (formatType === 'multilingualthinking') { return ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages']; } else if (formatType === 'custom') { const { questionField, answerField, cotField, includeLabels, includeChunk, questionOnly } = exportOptions.customFields; const headers = [questionField]; if (!questionOnly) { headers.push(answerField); if (exportOptions.includeCOT && cotField) { headers.push(cotField); } } if (includeLabels) headers.push('label'); if (includeChunk) headers.push('chunk'); return headers; } return []; }; // 格式化数据批次 const formatDataBatch = (dataBatch, exportOptions) => { const formatType = exportOptions.formatType || 'alpaca'; if (formatType === 'alpaca') { if (exportOptions.alpacaFieldType === 'instruction') { return dataBatch.map(({ question, answer, cot }) => ({ instruction: question, input: '', output: cot && exportOptions.includeCOT ? `${cot}\n${answer}` : answer, system: exportOptions.systemPrompt || '' })); } else { return dataBatch.map(({ question, answer, cot }) => ({ instruction: exportOptions.customInstruction || '', input: question, output: cot && exportOptions.includeCOT ? `${cot}\n${answer}` : answer, system: exportOptions.systemPrompt || '' })); } } else if (formatType === 'sharegpt') { return dataBatch.map(({ question, answer, cot }) => { const messages = []; if (exportOptions.systemPrompt) { messages.push({ role: 'system', content: exportOptions.systemPrompt }); } messages.push({ role: 'user', content: question }); messages.push({ role: 'assistant', content: cot && exportOptions.includeCOT ? `${cot}\n${answer}` : answer }); return { messages }; }); } else if (formatType === 'multilingualthinking') { return dataBatch.map(({ question, answer, cot }) => ({ reasoning_language: exportOptions.reasoningLanguage || 'English', developer: exportOptions.systemPrompt || '', user: question, analysis: exportOptions.includeCOT && cot ? cot : null, final: answer, messages: [ { content: exportOptions.systemPrompt || '', role: 'system', thinking: null }, { content: question, role: 'user', thinking: null }, { content: answer, role: 'assistant', thinking: exportOptions.includeCOT && cot ? cot : null } ] })); } else if (formatType === 'custom') { const { questionField, answerField, cotField, includeLabels, includeChunk, questionOnly } = exportOptions.customFields; return dataBatch.map(({ question, answer, cot, questionLabel: labels, chunkContent }) => { const item = { [questionField]: question }; if (!questionOnly) { item[answerField] = answer; if (cot && exportOptions.includeCOT && cotField) { item[cotField] = cot; } } if (includeLabels && labels && labels.length > 0) { item.label = labels.split(' ')[1]; } if (includeChunk && chunkContent) { item.chunk = chunkContent; } return item; }); } return dataBatch; }; // 将批次格式化为CSV行 const formatBatchToCSV = (formattedBatch, formatType, exportOptions) => { const headers = getCSVHeaders(formatType, exportOptions); return ( formattedBatch .map(item => { return headers .map(header => { let field = item[header]?.toString() || ''; // 对于复杂对象,转换为JSON字符串 if (typeof item[header] === 'object') { field = JSON.stringify(item[header]); } // CSV转义 if (field.includes(',') || field.includes('\n') || field.includes('"')) { field = `"${field.replace(/"/g, '""')}"`; } return field; }) .join(','); }) .join('\n') + '\n' ); }; // 处理和下载数据的通用函数(保留用于小数据量) const processAndDownloadData = async (dataToExport, exportOptions) => { const formattedData = formatDataBatch(dataToExport, exportOptions); let content; let fileExtension; const fileFormat = exportOptions.fileFormat || 'json'; if (fileFormat === 'jsonl') { content = formattedData.map(item => JSON.stringify(item)).join('\n'); fileExtension = 'jsonl'; } else if (fileFormat === 'csv') { const headers = getCSVHeaders(exportOptions.formatType, exportOptions); const csvRows = [ headers.join(','), ...formattedData.map(item => headers .map(header => { let field = item[header]?.toString() || ''; if (typeof item[header] === 'object') { field = JSON.stringify(item[header]); } if (field.includes(',') || field.includes('\n') || field.includes('"')) { field = `"${field.replace(/"/g, '""')}"`; } return field; }) .join(',') ) ]; content = csvRows.join('\n'); fileExtension = 'csv'; } else { content = JSON.stringify(formattedData, null, 2); fileExtension = 'json'; } const blob = new Blob([content], { type: 'application/json' }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; const formatSuffixMap = { alpaca: 'alpaca', multilingualthinking: 'multilingual-thinking', sharegpt: 'sharegpt', custom: 'custom' }; const formatSuffix = formatSuffixMap[exportOptions.formatType] || exportOptions.formatType || 'export'; const balanceSuffix = exportOptions.balanceMode ? '-balanced' : ''; const dateStr = new Date().toISOString().slice(0, 10); a.download = `datasets-${projectId}-${formatSuffix}${balanceSuffix}-${dateStr}.${fileExtension}`; document.body.appendChild(a); a.click(); document.body.removeChild(a); URL.revokeObjectURL(url); }; // 导出数据集(保持向后兼容的原有功能) const exportDatasets = async exportOptions => { try { const apiUrl = `/api/projects/${projectId}/datasets/export`; const requestBody = {}; if (exportOptions.selectedIds && exportOptions.selectedIds.length > 0) { requestBody.selectedIds = exportOptions.selectedIds; } else if (exportOptions.confirmedOnly) { requestBody.status = 'confirmed'; } if (exportOptions.balanceMode && exportOptions.balanceConfig) { requestBody.balanceMode = true; requestBody.balanceConfig = exportOptions.balanceConfig; } const response = await axios.post(apiUrl, requestBody); let dataToExport = response.data; await processAndDownloadData(dataToExport, exportOptions); toast.success(t('datasets.exportSuccess')); return true; } catch (error) { toast.error(error.message); return false; } }; // 导出平衡数据集 const exportBalancedDataset = async exportOptions => { const balancedOptions = { ...exportOptions, balanceMode: true, balanceConfig: exportOptions.balanceConfig }; return await exportDatasets(balancedOptions); }; return { exportDatasets, exportBalancedDataset, exportDatasetsStreaming }; }; export default useDatasetExport; export { useDatasetExport };