Files
YG-Datasets/backend/app/api/v1/datasets/__init__.py
Developer db11429290 feat(backend): 更新 API 端点实现
- 更新 Chunks API 端点
- 更新 Datasets API 端点
- 更新 Evaluation API 端点
- 更新 Files API 端点
- 更新 Projects API 端点
- 更新 Questions API 端点

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 17:29:58 +08:00

134 lines
3.7 KiB
Python

"""
Datasets API Router
"""
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.response import ApiResponse, PaginatedResponse
from app.core.database import get_db
from app.core.exceptions import NotFoundException
from app.core.crud import CRUDBase
from app.models.models import Dataset
from app.schemas.dataset import DatasetResponse
from app.schemas.dataset import DatasetCreateSchema
# Router holding the dataset endpoints defined below. `project_id` is not part of
# any path here, so it is presumably supplied by a prefix added where this router
# is included — TODO confirm against the parent router.
router: APIRouter = APIRouter()

# Generic CRUD helper bound to the Dataset ORM model; used by the endpoints below.
dataset_crud: CRUDBase = CRUDBase(Dataset)
class ExportRequest(BaseModel):
    """Request body for the dataset export endpoint.

    ``format`` picks the output layout and defaults to ``"alpaca"``.
    """

    # Validation only accepts the four export layouts named in the regex.
    format: str = Field(
        default="alpaca",
        pattern="^(alpaca|sharegpt|llama_factory|json)$",
    )
@router.get("", response_model=ApiResponse)
async def list_datasets(
    project_id: UUID,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """List one page of a project's datasets.

    ``page`` is 1-based and ``page_size`` is limited to 1..100 by the query
    validators. Results are requested ordered by ``created_at``, descending.
    """
    # Translate the 1-based page number into a row offset.
    offset = page_size * (page - 1)
    rows, row_count = await dataset_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters={"project_id": project_id},
        order_by="created_at",
        descending=True,
    )
    items = list(map(DatasetResponse.model_validate, rows))
    return PaginatedResponse.ok(
        items=items,
        page=page,
        page_size=page_size,
        total=row_count,
    )
@router.post("", response_model=ApiResponse)
async def create_dataset(
    project_id: UUID,
    dataset: DatasetCreateSchema,
    db: AsyncSession = Depends(get_db)
):
    """Create a dataset inside the given project and return its new ID."""
    # The project_id from the route always overrides any value in the body.
    payload = {**dataset.model_dump(), "project_id": project_id}
    new_dataset = Dataset(**payload)

    db.add(new_dataset)
    await db.commit()
    # Reload so server-generated columns such as the primary key are populated.
    await db.refresh(new_dataset)

    return ApiResponse.ok(
        data={"id": str(new_dataset.id)},
        message="Dataset created successfully",
    )
@router.get("/{dataset_id}", response_model=ApiResponse)
async def get_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Return a single dataset, scoped to the given project.

    Raises NotFoundException if the dataset is missing or belongs to a
    different project (both cases produce the same error).
    """
    found = await dataset_crud.get(db, dataset_id)
    if found and found.project_id == project_id:
        return ApiResponse.ok(data=DatasetResponse.model_validate(found))
    raise NotFoundException("Dataset", dataset_id)
@router.delete("/{dataset_id}", response_model=ApiResponse)
async def delete_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Delete a dataset after checking that it belongs to the given project.

    Raises NotFoundException if the dataset is missing or owned by another
    project.
    """
    target = await dataset_crud.get(db, dataset_id)
    if not (target and target.project_id == project_id):
        raise NotFoundException("Dataset", dataset_id)

    await dataset_crud.delete(db, dataset_id)
    return ApiResponse.ok(message="Dataset deleted successfully")
def _alpaca_to_sharegpt(record: dict) -> dict:
    """Convert one Alpaca-style record into a ShareGPT-style conversation record.

    ``instruction`` and, when non-empty, ``input`` are joined with a newline to
    form the human turn; ``output`` becomes the assistant ("gpt") turn.
    """
    prompt = record["instruction"]
    if record.get("input"):
        prompt = f"{prompt}\n{record['input']}"
    return {
        "conversations": [
            {"from": "human", "value": prompt},
            {"from": "gpt", "value": record["output"]},
        ]
    }


@router.post("/{dataset_id}/export", response_model=ApiResponse)
async def export_dataset(
    project_id: UUID,
    dataset_id: UUID,
    request: ExportRequest,
    db: AsyncSession = Depends(get_db)
):
    """Export a dataset in the requested format.

    Raises NotFoundException if the dataset is missing or belongs to a
    different project.
    """
    dataset = await dataset_crud.get(db, dataset_id)
    if not dataset or dataset.project_id != project_id:
        raise NotFoundException("Dataset", dataset_id)

    # NOTE(review): the dataset's real records are never loaded; the endpoint
    # returns this fixed sample. TODO: read the dataset's items once storage is wired up.
    sample_data = [
        {
            "instruction": "这是一个示例指令",
            "input": "",
            "output": "这是一个示例输出"
        }
    ]

    # "json" returns the raw record list with no message or format wrapper.
    if request.format == "json":
        return ApiResponse.ok(data=sample_data)

    # The records are Alpaca-shaped (LLaMA-Factory's default layout too). Only
    # ShareGPT needs conversion, so the payload matches the format it is labelled with.
    if request.format == "sharegpt":
        records = [_alpaca_to_sharegpt(item) for item in sample_data]
    else:
        records = sample_data

    return ApiResponse.ok(
        data={"data": records, "format": request.format},
        message="Dataset exported successfully"
    )