Files
YG-Datasets/backend/app/api/v1/datasets/__init__.py

134 lines
3.7 KiB
Python
Raw Normal View History

2026-03-17 14:36:31 +08:00
"""
Datasets API Router
"""
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, Query
2026-03-17 14:36:31 +08:00
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.response import ApiResponse, PaginatedResponse
2026-03-17 14:36:31 +08:00
from app.core.database import get_db
from app.core.exceptions import NotFoundException
from app.core.crud import CRUDBase
from app.models.models import Dataset
from app.schemas.dataset import DatasetResponse
from app.schemas.dataset import DatasetCreateSchema
2026-03-17 14:36:31 +08:00
# Router on which the dataset endpoints in this module are registered.
router = APIRouter()
# Shared CRUD helper bound to the Dataset model; the handlers in this module
# use it for lookups, listing and deletion.
dataset_crud = CRUDBase(Dataset)
2026-03-17 14:36:31 +08:00
class ExportRequest(BaseModel):
    """Request body that selects the output format of a dataset export."""

    # Accepted values are fixed by the pattern: alpaca (the default),
    # sharegpt, llama_factory or json.
    format: str = Field(
        default="alpaca",
        pattern="^(alpaca|sharegpt|llama_factory|json)$",
    )
2026-03-17 14:36:31 +08:00
@router.get("", response_model=ApiResponse)
async def list_datasets(
    project_id: UUID,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of the datasets that belong to a project.

    Args:
        project_id: Project whose datasets are listed.
        page: 1-based page number (validated as >= 1).
        page_size: Number of items per page (validated as 1..100).
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The value of ``PaginatedResponse.ok`` carrying the validated
        ``DatasetResponse`` items together with the paging metadata.
    """
    # Translate the 1-based page number into a row offset.
    offset = page_size * (page - 1)

    # Filter on the project and request ``created_at`` ordering, descending.
    rows, total = await dataset_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters={"project_id": project_id},
        order_by="created_at",
        descending=True,
    )

    return PaginatedResponse.ok(
        items=list(map(DatasetResponse.model_validate, rows)),
        page=page,
        page_size=page_size,
        total=total,
    )
2026-03-17 14:36:31 +08:00
@router.post("", response_model=ApiResponse)
async def create_dataset(
    project_id: UUID,
    dataset: DatasetCreateSchema,
    db: AsyncSession = Depends(get_db),
):
    """Persist a new dataset under the given project.

    Args:
        project_id: Project the new dataset is attached to.
        dataset: Validated creation payload.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        ``ApiResponse.ok`` whose data holds the new dataset's id as a string.
    """
    # The project id from the request always wins over any same-named key
    # produced by the schema dump.
    payload = {**dataset.model_dump(), "project_id": project_id}
    record = Dataset(**payload)

    db.add(record)
    await db.commit()
    # Reload the row so database-populated attributes are available.
    await db.refresh(record)

    created_id = str(record.id)
    return ApiResponse.ok(
        data={"id": created_id},
        message="Dataset created successfully",
    )
2026-03-17 14:36:31 +08:00
@router.get("/{dataset_id}", response_model=ApiResponse)
async def get_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single dataset, scoped to the given project.

    Args:
        project_id: Project the dataset must belong to.
        dataset_id: Identifier of the dataset to fetch.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        ``ApiResponse.ok`` wrapping the validated ``DatasetResponse``.

    Raises:
        NotFoundException: If no dataset is found, or the one found belongs
            to a different project.
    """
    record = await dataset_crud.get(db, dataset_id)

    # A dataset owned by another project is reported exactly like a missing one.
    if not record or record.project_id != project_id:
        raise NotFoundException("Dataset", dataset_id)

    body = DatasetResponse.model_validate(record)
    return ApiResponse.ok(data=body)
2026-03-17 14:36:31 +08:00
@router.delete("/{dataset_id}", response_model=ApiResponse)
async def delete_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Remove a dataset after checking that it belongs to the project.

    Args:
        project_id: Project the dataset must belong to.
        dataset_id: Identifier of the dataset to delete.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        ``ApiResponse.ok`` carrying only a confirmation message.

    Raises:
        NotFoundException: If no dataset is found, or the one found belongs
            to a different project.
    """
    target = await dataset_crud.get(db, dataset_id)

    # Verify ownership before deleting; a foreign dataset is reported as missing.
    if not target or target.project_id != project_id:
        raise NotFoundException("Dataset", dataset_id)

    await dataset_crud.delete(db, dataset_id)
    return ApiResponse.ok(message="Dataset deleted successfully")
2026-03-17 14:36:31 +08:00
@router.post("/{dataset_id}/export", response_model=ApiResponse)
async def export_dataset(
    project_id: UUID,
    dataset_id: UUID,
    request: ExportRequest,
    db: AsyncSession = Depends(get_db),
):
    """Export a dataset in the requested format.

    NOTE: the payload is a hard-coded sample record; the dataset's own rows
    are not read here. The lookup only validates existence and ownership.

    Args:
        project_id: Project the dataset must belong to.
        dataset_id: Identifier of the dataset to export.
        request: Export options; ``request.format`` selects the response shape.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        For ``json``: ``ApiResponse.ok`` with the bare list of records.
        For any other format: ``ApiResponse.ok`` with a dict holding the
        records under ``"data"`` and the format name under ``"format"``,
        plus a success message.

    Raises:
        NotFoundException: If no dataset is found, or the one found belongs
            to a different project.
    """
    record = await dataset_crud.get(db, dataset_id)
    if not record or record.project_id != project_id:
        raise NotFoundException("Dataset", dataset_id)

    # Placeholder export content shared by every format.
    example_row = {
        "instruction": "这是一个示例指令",
        "input": "",
        "output": "这是一个示例输出",
    }
    rows = [example_row]

    # The json format returns the records directly, without wrapper or message.
    if request.format == "json":
        return ApiResponse.ok(data=rows)

    wrapped = {"data": rows, "format": request.format}
    return ApiResponse.ok(
        data=wrapped,
        message="Dataset exported successfully",
    )