101 lines
2.9 KiB
Python
101 lines
2.9 KiB
Python
|
|
"""
|
||
|
|
Evaluation API Router
|
||
|
|
"""
|
||
|
|
from typing import List, Optional
|
||
|
|
from uuid import UUID
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy import select
|
||
|
|
from app.core.database import get_db
|
||
|
|
from app.models.models import EvalDataset, Task
|
||
|
|
from app.schemas.base import EvalDatasetCreate, EvalDatasetResponse, TaskResponse
|
||
|
|
|
||
|
|
# Router object on which the evaluation endpoints in this module are registered.
router = APIRouter()
|
||
|
|
|
||
|
|
|
||
|
|
class GenerateEvalRequest(BaseModel):
    """Request for generating evaluation dataset"""

    # Name stored on the created dataset record.
    name: str
    # Stored on the dataset record as-is; defaults to "mixed".
    # NOTE(review): the set of accepted values is not visible here — confirm.
    question_type: str = "mixed"
    # Defaults to 50. Presumably the number of questions to generate;
    # verify against the code that consumes this request.
    count: int = 50
|
||
|
|
|
||
|
|
|
||
|
|
class RunEvalRequest(BaseModel):
    """Request for running evaluation"""

    # Pydantic v2 treats the "model_" prefix as a protected namespace and
    # warns about fields such as ``model_config_id``. Clearing the protected
    # namespaces keeps the public field name unchanged while silencing that
    # warning. A plain dict is accepted here, so no extra import is needed.
    model_config = {"protected_namespaces": ()}

    # Optional identifier of a model configuration; omitted means None.
    model_config_id: Optional[UUID] = None
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/", response_model=dict)
async def list_eval_datasets(project_id: UUID, db: AsyncSession = Depends(get_db)):
    """Return the evaluation datasets of one project, newest first.

    Args:
        project_id: Identifier of the project whose datasets are selected.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with a single ``"datasets"`` key holding one
        ``EvalDatasetResponse`` per matching row, ordered by ``created_at``
        descending.
    """
    query = (
        select(EvalDataset)
        .where(EvalDataset.project_id == project_id)
        .order_by(EvalDataset.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()

    serialized = [EvalDatasetResponse.model_validate(row) for row in rows]
    return {"datasets": serialized}
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/", response_model=dict)
async def create_eval_dataset(
    project_id: UUID,
    request: GenerateEvalRequest,
    db: AsyncSession = Depends(get_db),
):
    """Persist a new evaluation dataset record for a project.

    Only ``name`` and ``question_type`` are taken from the request body;
    ``request.count`` is not read by this handler.

    Args:
        project_id: Identifier of the project the dataset belongs to.
        request: Payload describing the dataset to create.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A dict ``{"id": <dataset id as string>}`` for the stored record.
    """
    column_values = {
        "project_id": project_id,
        "name": request.name,
        "question_type": request.question_type,
    }
    new_dataset = EvalDataset(**column_values)

    db.add(new_dataset)
    await db.commit()
    # Re-read the row after the commit before accessing its id.
    await db.refresh(new_dataset)

    dataset_id = str(new_dataset.id)
    return {"id": dataset_id}
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/{eval_id}/evaluate", response_model=dict)
async def run_evaluation(
    project_id: UUID,
    eval_id: UUID,
    request: RunEvalRequest,
    db: AsyncSession = Depends(get_db),
):
    """Register a pending evaluation task for an existing dataset.

    The handler only records the task; it does not launch the evaluation
    itself (see the TODO below), and the ``request`` body is not read here.

    Args:
        project_id: Identifier of the project that must own the dataset.
        eval_id: Identifier of the evaluation dataset to evaluate.
        request: Run options accepted by the endpoint.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with the new task's id (as a string) under ``"task_id"`` and a
        fixed ``"message"`` string.

    Raises:
        HTTPException: 404 when no dataset with ``eval_id`` exists in the
            given project.
    """
    # The dataset must exist and belong to the requested project.
    lookup = select(EvalDataset).where(
        EvalDataset.id == eval_id,
        EvalDataset.project_id == project_id,
    )
    found = (await db.execute(lookup)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="Evaluation dataset not found")

    # Record the evaluation as a pending task.
    pending_task = Task(project_id=project_id, task_type="eval", status="pending")
    db.add(pending_task)
    await db.commit()
    await db.refresh(pending_task)

    # TODO: Start evaluation in background

    response = {"task_id": str(pending_task.id), "message": "Evaluation task started"}
    return response
|
||
|
|
|
||
|
|
|
||
|
|
# The response model is declared as TaskResponse because the handler returns a
# TaskResponse instance; with ``response_model=dict`` FastAPI validates the
# return value as a dictionary, which a Pydantic v2 model instance is not.
@router.get("/results", response_model=TaskResponse)
async def get_eval_results(project_id: UUID, task_id: UUID, db: AsyncSession = Depends(get_db)):
    """Get evaluation results for a task.

    Args:
        project_id: Identifier of the project that must own the task.
        task_id: Identifier of the task to look up.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A ``TaskResponse`` built from the matching ``Task`` row.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists in the given
            project.
    """
    # Filter on both columns so a task from another project is never returned.
    # NOTE(review): the query does not restrict ``task_type``, so any task of
    # the project can be fetched through this endpoint — confirm intent.
    result = await db.execute(
        select(Task).where(Task.id == task_id, Task.project_id == project_id)
    )
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    return TaskResponse.model_validate(task)
|