127 lines
3.6 KiB
Python
127 lines
3.6 KiB
Python
"""
|
|
Datasets API Router
|
|
"""
|
|
from typing import List, Optional
|
|
from uuid import UUID
|
|
from pydantic import BaseModel
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, func
|
|
from app.core.database import get_db
|
|
from app.models.models import Dataset, Question
|
|
from app.schemas.base import DatasetCreate, DatasetResponse
|
|
|
|
# Router that collects the dataset endpoints declared below.
# NOTE(review): every route takes ``project_id`` as a plain parameter; whether
# FastAPI treats it as a query or path parameter depends on the prefix used
# where this router is included — confirm at the include_router call site.
router = APIRouter()
|
|
|
|
|
|
class ExportRequest(BaseModel):
    """Request body for the dataset export endpoint.

    Attributes:
        format: Name of the requested export format. Defaults to ``"alpaca"``.
    """

    # Requested export format. The trailing comment lists the intended options,
    # but the field is a plain ``str``, so any value is accepted unvalidated.
    format: str = "alpaca"  # alpaca, sharegpt, llama_factory, json
|
|
|
|
|
|
@router.get("/", response_model=dict)
async def list_datasets(project_id: UUID, db: AsyncSession = Depends(get_db)):
    """Return all datasets of ``project_id``, newest ``created_at`` first.

    Args:
        project_id: Project whose datasets are listed.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        ``{"datasets": [...]}`` where each entry is a ``DatasetResponse``.
        ``question_count`` is a placeholder that is always ``0``.
    """
    query = (
        select(Dataset)
        .where(Dataset.project_id == project_id)
        .order_by(Dataset.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()

    responses = []
    for row in rows:
        item = DatasetResponse.model_validate(row)
        # TODO: Count questions in dataset — no dataset/question link exists yet.
        item.question_count = 0
        responses.append(item)

    return {"datasets": responses}
|
|
|
|
|
|
@router.post("/", response_model=dict)
async def create_dataset(
    project_id: UUID,
    dataset: DatasetCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a dataset under ``project_id`` and return its id.

    Args:
        project_id: Project the new dataset is attached to.
        dataset: Validated creation payload; all of its fields are passed
            through to the ``Dataset`` ORM constructor.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        ``{"id": <dataset id as str>}``.
    """
    payload = dataset.model_dump()
    new_dataset = Dataset(project_id=project_id, **payload)

    db.add(new_dataset)
    await db.commit()
    # Reload the row from the database before its ``id`` is read below.
    await db.refresh(new_dataset)

    return {"id": str(new_dataset.id)}
|
|
|
|
|
|
@router.get("/{dataset_id}", response_model=dict)
async def get_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Fetch one dataset, scoped to the project it must belong to.

    Args:
        project_id: Project the dataset must belong to.
        dataset_id: Identifier of the dataset to fetch.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The dataset serialised through ``DatasetResponse``, as a plain ``dict``.

    Raises:
        HTTPException: 404 when no dataset with ``dataset_id`` exists in
            ``project_id``.
    """
    result = await db.execute(
        select(Dataset).where(Dataset.id == dataset_id, Dataset.project_id == project_id)
    )
    dataset = result.scalar_one_or_none()
    if not dataset:
        raise HTTPException(status_code=404, detail="Dataset not found")

    # The route declares ``response_model=dict``. With Pydantic v2, FastAPI
    # validates the returned value against that type directly, and a
    # ``DatasetResponse`` instance is not a mapping, so returning the model
    # itself fails response validation. Dump it to a dict instead.
    return DatasetResponse.model_validate(dataset).model_dump()
|
|
|
|
|
|
@router.delete("/{dataset_id}", response_model=dict)
async def delete_dataset(
    project_id: UUID,
    dataset_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Delete a dataset that belongs to ``project_id``.

    Args:
        project_id: Project the dataset must belong to.
        dataset_id: Identifier of the dataset to delete.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 404 when the dataset is not found in the project.
    """
    lookup = select(Dataset).where(
        Dataset.id == dataset_id,
        Dataset.project_id == project_id,
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Dataset not found")

    await db.delete(target)
    await db.commit()
    return {"message": "Dataset deleted successfully"}
|
|
|
|
|
|
@router.post("/{dataset_id}/export")
async def export_dataset(
    project_id: UUID,
    dataset_id: UUID,
    request: ExportRequest,
    db: AsyncSession = Depends(get_db)
):
    """Export a dataset in the format named by ``request.format`` (placeholder).

    The dataset is looked up only to confirm it exists in ``project_id``. The
    payload is a fixed one-record sample with ``instruction``/``input``/
    ``output`` keys, regardless of the dataset's contents.

    Args:
        project_id: Project the dataset must belong to.
        dataset_id: Identifier of the dataset to export.
        request: Export options; only ``format`` is read.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The bare list of sample records when ``format == "json"``; otherwise
        ``{"data": <records>, "format": <requested format>}``.

    Raises:
        HTTPException: 404 when the dataset is not found in the project.
    """
    # TODO: Implement actual export logic. Questions are not linked to
    # datasets yet, so there is no real data to export.
    owned = (
        await db.execute(
            select(Dataset).where(Dataset.id == dataset_id, Dataset.project_id == project_id)
        )
    ).scalar_one_or_none()
    if not owned:
        raise HTTPException(status_code=404, detail="Dataset not found")

    # Fixed sample record; the Chinese text reads "This is an example
    # instruction" / "This is an example output".
    records = [
        {
            "instruction": "这是一个示例指令",
            "input": "",
            "output": "这是一个示例输出",
        }
    ]

    if request.format != "json":
        return {"data": records, "format": request.format}
    return records
|