- 更新 Chunks API 端点 - 更新 Datasets API 端点 - 更新 Evaluation API 端点 - 更新 Files API 端点 - 更新 Projects API 端点 - 更新 Questions API 端点 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
134 lines
3.7 KiB
Python
134 lines
3.7 KiB
Python
"""
|
|
Datasets API Router
|
|
"""
|
|
from typing import List, Optional
|
|
from uuid import UUID
|
|
from pydantic import BaseModel, Field
|
|
from fastapi import APIRouter, Depends, Query
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.response import ApiResponse, PaginatedResponse
|
|
from app.core.database import get_db
|
|
from app.core.exceptions import NotFoundException
|
|
from app.core.crud import CRUDBase
|
|
from app.models.models import Dataset
|
|
from app.schemas.dataset import DatasetResponse
|
|
from app.schemas.dataset import DatasetCreateSchema
|
|
|
|
# Router object that the dataset endpoints in this module register on.
router = APIRouter()

# Shared CRUD helper bound to the Dataset model; the handlers below use it
# for lookups, listing and deletion.
dataset_crud = CRUDBase(Dataset)
|
|
|
|
|
|
class ExportRequest(BaseModel):
    """Export request schema"""

    # NOTE: the docstring above is kept verbatim because pydantic exposes a
    # model's docstring as the schema description.
    #
    # Name of the export format. Defaults to "alpaca" when the field is
    # omitted; the pattern restricts it to exactly one of
    # alpaca / sharegpt / llama_factory / json.
    format: str = Field(
        default="alpaca",
        pattern="^(alpaca|sharegpt|llama_factory|json)$",
    )
|
|
|
|
|
|
@router.get("", response_model=ApiResponse)
async def list_datasets(
    project_id: UUID,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
):
    """List datasets for a project"""
    # NOTE: docstring kept verbatim; FastAPI publishes it as the endpoint
    # description in the generated OpenAPI document.
    #
    # page is 1-based (ge=1) and page_size is capped at 100 by the Query
    # validators in the signature.

    # Convert the 1-based page number into a row offset.
    offset = page_size * (page - 1)

    # Filter by the owning project, ordered by created_at, descending.
    # get_multi yields a pair that is unpacked as (rows, total); total is
    # presumably the unpaginated match count -- confirm against CRUDBase.
    rows, total = await dataset_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters={"project_id": project_id},
        order_by="created_at",
        descending=True,
    )

    # Convert each ORM row into its response schema.
    items = []
    for row in rows:
        items.append(DatasetResponse.model_validate(row))

    return PaginatedResponse.ok(
        items=items,
        page=page,
        page_size=page_size,
        total=total,
    )
|
|
|
|
|
|
@router.post("", response_model=ApiResponse)
async def create_dataset(
    project_id: UUID,
    dataset: DatasetCreateSchema,
    db: AsyncSession = Depends(get_db),
):
    """Create a new dataset"""
    # NOTE: docstring kept verbatim; FastAPI publishes it as the endpoint
    # description in the generated OpenAPI document.

    # Combine the request payload with the owning project. project_id is
    # listed last so it overrides any same-named key coming from the payload.
    fields = {**dataset.model_dump(), "project_id": project_id}
    record = Dataset(**fields)

    # Persist the row, then reload it so that record.id can be read back
    # after the commit.
    db.add(record)
    await db.commit()
    await db.refresh(record)

    # Only the new identifier (as a string) is returned to the caller.
    created = {"id": str(record.id)}
    return ApiResponse.ok(data=created, message="Dataset created successfully")
|
|
|
|
|
|
@router.get("/{dataset_id}", response_model=ApiResponse)
async def get_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Get dataset by ID"""
    # NOTE: docstring kept verbatim; FastAPI publishes it as the endpoint
    # description in the generated OpenAPI document.

    record = await dataset_crud.get(db, dataset_id)

    # Nothing came back for this id.
    if not record:
        raise NotFoundException("Dataset", dataset_id)

    # A dataset owned by a different project is reported exactly like a
    # missing one.
    if record.project_id != project_id:
        raise NotFoundException("Dataset", dataset_id)

    payload = DatasetResponse.model_validate(record)
    return ApiResponse.ok(data=payload)
|
|
|
|
|
|
@router.delete("/{dataset_id}", response_model=ApiResponse)
async def delete_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Delete dataset"""
    # NOTE: docstring kept verbatim; FastAPI publishes it as the endpoint
    # description in the generated OpenAPI document.

    record = await dataset_crud.get(db, dataset_id)

    # Nothing came back for this id.
    if not record:
        raise NotFoundException("Dataset", dataset_id)

    # A dataset owned by a different project is reported exactly like a
    # missing one, and is never deleted through this project.
    if record.project_id != project_id:
        raise NotFoundException("Dataset", dataset_id)

    # Ownership confirmed: remove by id through the shared CRUD helper.
    await dataset_crud.delete(db, dataset_id)

    return ApiResponse.ok(message="Dataset deleted successfully")
|
|
|
|
|
|
@router.post("/{dataset_id}/export", response_model=ApiResponse)
async def export_dataset(
    project_id: UUID,
    dataset_id: UUID,
    request: ExportRequest,
    db: AsyncSession = Depends(get_db),
):
    """Export dataset in specified format"""
    # NOTE: docstring kept verbatim; FastAPI publishes it as the endpoint
    # description in the generated OpenAPI document.

    record = await dataset_crud.get(db, dataset_id)

    # Nothing came back for this id.
    if not record:
        raise NotFoundException("Dataset", dataset_id)

    # A dataset owned by a different project is reported exactly like a
    # missing one.
    if record.project_id != project_id:
        raise NotFoundException("Dataset", dataset_id)

    # Placeholder payload: the looked-up record is used only for the checks
    # above, and the exported rows are this single hard-coded sample.
    sample_row = {
        "instruction": "这是一个示例指令",
        "input": "",
        "output": "这是一个示例输出",
    }
    sample_rows = [sample_row]

    # "json" returns the bare list, without a custom message.
    wants_bare_json = request.format == "json"
    if wants_bare_json:
        return ApiResponse.ok(data=sample_rows)

    # Every other format wraps the rows together with the format name.
    envelope = {"data": sample_rows, "format": request.format}
    return ApiResponse.ok(
        data=envelope,
        message="Dataset exported successfully",
    )
|