Files
YG-Datasets/easy-dataset-main/hooks/useGenerateDataset.js

136 lines
4.4 KiB
JavaScript
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import { useCallback } from 'react';
import { toast } from 'sonner';
import i18n from '@/lib/i18n';
import axios from 'axios';
import { useAtomValue } from 'jotai/index';
import { selectedModelInfoAtom } from '@/lib/store';
import { useTranslation } from 'react-i18next';
/**
 * React hook exposing helpers that ask the backend to generate dataset entries,
 * either for a single question or for a batch of questions, and that report
 * progress and outcome through `sonner` toasts.
 *
 * The currently selected model is read from `selectedModelInfoAtom`; both helpers
 * refuse to send any request while no model is selected.
 *
 * @returns {{ generateSingleDataset: Function, generateMultipleDataset: Function }}
 */
export function useGenerateDataset() {
  const model = useAtomValue(selectedModelInfoAtom);
  const { t } = useTranslation();

  /**
   * Generate a dataset entry for one question.
   *
   * A truthy `imageId` marks the question as an image question and routes the
   * request to the image dataset endpoint; otherwise the text endpoint is used.
   *
   * @param {object} params
   * @param {string} params.projectId - Project whose API route is called.
   * @param {string} params.questionId - Id of the question to answer.
   * @param {string} params.questionInfo - Question text (shown in the toast; sent for image questions).
   * @param {string} [params.imageId] - When truthy, the question is treated as an image question.
   * @param {string} [params.imageName] - Image name sent to the image endpoint.
   * @returns {Promise<*|null>} The response body once the request has finished, or
   *   `null` when no model is selected or the request failed (the failure is
   *   reported through the toast and never rethrown).
   */
  const generateSingleDataset = useCallback(
    async ({ projectId, questionId, questionInfo, imageId, imageName }) => {
      if (!model) {
        toast.error(t('models.configNotFound'));
        return null;
      }

      const language = i18n.language === 'zh-CN' ? '中文' : 'en';
      const isImageQuestion = !!imageId;

      const request = isImageQuestion
        ? axios.post(`/api/projects/${projectId}/images/datasets`, {
            imageName,
            question: { question: questionInfo, id: questionId },
            model,
            language
          })
        : axios.post(`/api/projects/${projectId}/datasets`, {
            questionId,
            model,
            language
          });

      // The toast follows the request through loading -> success / error.
      toast.promise(request, {
        loading: t('datasets.generating'),
        // Every opening 【 is paired with a closing 】 so the text is balanced.
        description: isImageQuestion
          ? `图片:【${imageName}】\n问题:【${questionInfo}】`
          : `问题:【${questionInfo}】`,
        position: 'top-right',
        success: () => '生成数据集成功',
        error: error => t('datasets.generateFailed', { error: error.response?.data?.error })
      });

      // Await the request so that callers who `await` this function resume only
      // after generation has actually finished.
      try {
        const response = await request;
        return response.data;
      } catch {
        // Already surfaced to the user by the toast above; callers get `null`.
        return null;
      }
    },
    [model, t]
  );

  /**
   * Generate dataset entries for several questions concurrently.
   *
   * A single loading toast tracks overall progress, every question gets its own
   * success / error toast, and the loading toast is finally replaced by a summary.
   *
   * @param {string} projectId - Project whose API routes are called.
   * @param {Array<object>} questions - Question records; the code reads `id`,
   *   `question`, and, for image questions, `imageId` / `imageName`.
   * @returns {Promise<Array<object>>} The `Promise.allSettled` results in input
   *   order (fulfilled values are response bodies), or an empty array when no
   *   model is selected.
   */
  const generateMultipleDataset = useCallback(
    async (projectId, questions) => {
      // Same guard as the single-question path: never post without a model.
      if (!model) {
        toast.error(t('models.configNotFound'));
        return [];
      }

      const language = i18n.language === 'zh-CN' ? '中文' : 'en';
      const total = questions.length;
      let completed = 0;
      const progressText = () => `正在处理请求 (${completed}/${total})...`;

      const loadingToastId = toast.loading(progressText(), { position: 'top-right' });

      const processRequest = async question => {
        try {
          const response = question.imageId
            ? await axios.post(`/api/projects/${projectId}/images/datasets`, {
                imageName: question.imageName,
                question,
                model,
                language
              })
            : await axios.post(`/api/projects/${projectId}/datasets`, {
                questionId: question.id,
                model,
                language
              });
          toast.success(`${question.question} 完成`, { position: 'top-right' });
          return response.data;
        } catch (error) {
          toast.error(`${question.question} 失败`, {
            // Prefer the server-provided message, as the single-question path does.
            description: error.response?.data?.error ?? error.message,
            position: 'top-right'
          });
          // Rethrow so Promise.allSettled records this entry as rejected.
          throw error;
        } finally {
          // Count the request exactly once, whatever its outcome, and refresh progress.
          completed++;
          toast.loading(progressText(), { id: loadingToastId });
        }
      };

      // allSettled never rejects, so no surrounding try/catch is needed.
      const results = await Promise.allSettled(questions.map(question => processRequest(question)));
      const succeeded = results.filter(result => result.status === 'fulfilled').length;

      // Replace the progress toast with the final summary.
      toast.success(`全部请求处理完成 (成功: ${succeeded}/${total})`, {
        id: loadingToastId,
        position: 'top-right'
      });
      return results;
    },
    [model, t]
  );

  return { generateSingleDataset, generateMultipleDataset };
}