"""
Datasets API Router
"""
from typing import List, Optional
from uuid import UUID

from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.response import ApiResponse, PaginatedResponse
from app.core.database import get_db
from app.core.exceptions import NotFoundException
from app.core.crud import CRUDBase
from app.models.models import Dataset
from app.schemas.dataset import DatasetCreateSchema, DatasetResponse

router = APIRouter()

# Generic CRUD helper bound to the Dataset model.
dataset_crud = CRUDBase(Dataset)


class ExportRequest(BaseModel):
    """Request body for the dataset export endpoint."""

    # Target export format; the regex restricts it to the supported names.
    format: str = Field("alpaca", pattern="^(alpaca|sharegpt|llama_factory|json)$")


async def _get_project_dataset(
    db: AsyncSession, project_id: UUID, dataset_id: UUID
) -> Dataset:
    """Load a dataset by ID and ensure it belongs to ``project_id``.

    Args:
        db: Active async database session.
        project_id: Project the dataset is expected to belong to.
        dataset_id: Primary key of the dataset to load.

    Returns:
        The matching ``Dataset`` row.

    Raises:
        NotFoundException: If no dataset has ``dataset_id``, or it belongs to
            a different project. Both cases raise the same exception, so a
            dataset in another project is indistinguishable from a missing one.
    """
    dataset = await dataset_crud.get(db, dataset_id)
    if not dataset or dataset.project_id != project_id:
        raise NotFoundException("Dataset", dataset_id)
    return dataset


@router.get("", response_model=ApiResponse)
async def list_datasets(
    project_id: UUID,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
):
    """List a project's datasets, ordered by ``created_at`` descending.

    ``page`` is 1-based and ``page_size`` must be between 1 and 100.
    """
    # Translate the 1-based page number into a row offset.
    skip = (page - 1) * page_size
    datasets, total = await dataset_crud.get_multi(
        db,
        skip=skip,
        limit=page_size,
        filters={"project_id": project_id},
        order_by="created_at",
        descending=True,
    )

    items = [DatasetResponse.model_validate(d) for d in datasets]
    return PaginatedResponse.ok(
        items=items,
        page=page,
        page_size=page_size,
        total=total,
    )


@router.post("", response_model=ApiResponse)
async def create_dataset(
    project_id: UUID,
    dataset: DatasetCreateSchema,
    db: AsyncSession = Depends(get_db),
):
    """Create a new dataset under ``project_id`` and return its ID."""
    # The project comes from the route, not the request body; it overrides
    # any ``project_id`` key the schema dump might contain.
    dataset_dict = dataset.model_dump()
    dataset_dict["project_id"] = project_id

    db_dataset = Dataset(**dataset_dict)
    db.add(db_dataset)
    await db.commit()
    # Reload so server-generated columns (e.g. the primary key) are populated.
    await db.refresh(db_dataset)

    return ApiResponse.ok(
        data={"id": str(db_dataset.id)},
        message="Dataset created successfully",
    )


@router.get("/{dataset_id}", response_model=ApiResponse)
async def get_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Get a dataset by ID.

    Raises ``NotFoundException`` if it does not exist in ``project_id``.
    """
    dataset = await _get_project_dataset(db, project_id, dataset_id)
    return ApiResponse.ok(data=DatasetResponse.model_validate(dataset))


@router.delete("/{dataset_id}", response_model=ApiResponse)
async def delete_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Delete a dataset.

    Raises ``NotFoundException`` if it does not exist in ``project_id``.
    """
    # Verify ownership before deleting so one project cannot delete
    # another project's dataset by ID.
    await _get_project_dataset(db, project_id, dataset_id)
    await dataset_crud.delete(db, dataset_id)
    return ApiResponse.ok(message="Dataset deleted successfully")


@router.post("/{dataset_id}/export", response_model=ApiResponse)
async def export_dataset(
    project_id: UUID,
    dataset_id: UUID,
    request: ExportRequest,
    db: AsyncSession = Depends(get_db),
):
    """Export a dataset in the requested format.

    For ``json`` the response data is the bare record list. For the other
    formats it is wrapped as ``{"data": [...], "format": <format>}``.
    Raises ``NotFoundException`` if the dataset does not exist in
    ``project_id``.
    """
    await _get_project_dataset(db, project_id, dataset_id)

    # TODO(review): placeholder payload. This is a fixed sample record that
    # does not depend on the dataset's contents or on ``request.format``
    # (the format only selects the response envelope below). Replace with the
    # dataset's real items converted to the requested format.
    sample_data = [
        {
            "instruction": "这是一个示例指令",
            "input": "",
            "output": "这是一个示例输出",
        }
    ]

    if request.format == "json":
        return ApiResponse.ok(data=sample_data)

    return ApiResponse.ok(
        data={"data": sample_data, "format": request.format},
        message="Dataset exported successfully",
    )