- 更新 Chunks API 端点 - 更新 Datasets API 端点 - 更新 Evaluation API 端点 - 更新 Files API 端点 - 更新 Projects API 端点 - 更新 Questions API 端点 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
141 lines
3.8 KiB
Python
141 lines
3.8 KiB
Python
"""
|
|
Evaluation API Router
|
|
"""
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
from pydantic import BaseModel, Field
|
|
from fastapi import APIRouter, Depends, Query
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.response import ApiResponse, PaginatedResponse
|
|
from app.core.database import get_db
|
|
from app.core.exceptions import NotFoundException
|
|
from app.core.crud import CRUDBase
|
|
from app.models.models import EvalDataset, Task
|
|
from app.schemas.eval import EvalDatasetResponse, TaskResponse
|
|
from app.schemas.eval import EvalDatasetCreateSchema
|
|
|
|
# Router object on which every endpoint in this module is registered.
router = APIRouter()

# Module-level CRUD helpers, one per ORM model the handlers work with.
eval_crud = CRUDBase(EvalDataset)
task_crud = CRUDBase(Task)
|
|
|
|
|
|
class GenerateEvalRequest(BaseModel):
    """Request body used when creating an evaluation dataset."""

    # Dataset name: required, between 1 and 255 characters.
    name: str = Field(..., min_length=1, max_length=255)

    # Question category: must be exactly one of the alternatives in the
    # pattern; "mixed" is used when the client omits the field.
    question_type: str = Field(
        default="mixed",
        pattern="^(mixed|fact|reasoning|summary)$",
    )

    # Integer in the range 1..500, defaulting to 50.
    # Presumably the number of questions to generate -- TODO confirm.
    count: int = Field(default=50, ge=1, le=500)
|
|
|
|
|
|
class RunEvalRequest(BaseModel):
    """Request body used when starting an evaluation run."""

    # Pydantic v2 reserves the ``model_`` prefix for its own attributes
    # (the default protected namespace in releases before 2.10), so a field
    # called ``model_config_id`` triggers a namespace-conflict warning when
    # the class is created. Clearing the protected namespaces for this model
    # keeps the public field name unchanged while silencing that warning.
    model_config = {"protected_namespaces": ()}

    # Optional identifier of a model configuration; None when omitted.
    model_config_id: Optional[UUID] = None
|
|
|
|
|
|
@router.get("", response_model=ApiResponse)
async def list_eval_datasets(
    project_id: UUID,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of the evaluation datasets that belong to a project.

    Args:
        project_id: Project whose datasets are listed.
        page: 1-based page number (must be >= 1).
        page_size: Number of items per page (1-100).
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The value of ``PaginatedResponse.ok`` built from the validated
        items, the paging parameters and the total returned by
        ``eval_crud.get_multi``.
    """
    # Rows are requested ordered by ``created_at`` in descending order; the
    # offset is derived from the 1-based page number.
    rows, total = await eval_crud.get_multi(
        db,
        skip=page_size * (page - 1),
        limit=page_size,
        filters={"project_id": project_id},
        order_by="created_at",
        descending=True,
    )

    return PaginatedResponse.ok(
        items=list(map(EvalDatasetResponse.model_validate, rows)),
        page=page,
        page_size=page_size,
        total=total,
    )
|
|
|
|
|
|
@router.post("", response_model=ApiResponse)
async def create_eval_dataset(
    project_id: UUID,
    request: GenerateEvalRequest,
    db: AsyncSession = Depends(get_db),
):
    """Store a new evaluation dataset for a project and return its id.

    Args:
        project_id: Project the new dataset is attached to.
        request: Validated request body carrying the name and question type.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The value of ``ApiResponse.ok`` whose data holds the new dataset's
        id rendered as a string.
    """
    # NOTE(review): ``request.count`` is validated but never read here --
    # confirm whether it should be stored or passed to a generation step.
    record = EvalDataset(
        project_id=project_id,
        name=request.name,
        question_type=request.question_type,
    )
    db.add(record)
    await db.commit()
    # Refresh the instance from the database before reading its id.
    await db.refresh(record)

    payload = {"id": str(record.id)}
    return ApiResponse.ok(
        data=payload,
        message="Evaluation dataset created successfully",
    )
|
|
|
|
|
|
# The static "/results" route is declared BEFORE the parameterised
# "/{eval_id}" route on purpose: FastAPI matches routes in declaration order,
# so with the opposite order the literal segment "results" is captured as
# ``eval_id``, rejected as an invalid UUID, and this handler is never reached.
@router.get("/results", response_model=ApiResponse)
async def get_eval_results(
    project_id: UUID,
    task_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Return the task record for an evaluation run.

    Args:
        project_id: Project the task must belong to.
        task_id: Identifier of the task to look up. It is not part of this
            route's path, so FastAPI reads it from the query string.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The value of ``ApiResponse.ok`` wrapping the validated
        ``TaskResponse``.

    Raises:
        NotFoundException: If no task has this id, or the task belongs to a
            different project.
    """
    task = await task_crud.get(db, task_id)
    # A task owned by another project is reported as missing as well.
    # NOTE(review): ``task.task_type`` is not checked, so any task of the
    # project can be fetched through this endpoint -- confirm that is wanted.
    if not task or task.project_id != project_id:
        raise NotFoundException("Task", task_id)

    return ApiResponse.ok(data=TaskResponse.model_validate(task))


@router.get("/{eval_id}", response_model=ApiResponse)
async def get_eval_dataset(
    project_id: UUID,
    eval_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Return a single evaluation dataset by id.

    Args:
        project_id: Project the dataset must belong to.
        eval_id: Identifier of the dataset, taken from the path.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The value of ``ApiResponse.ok`` wrapping the validated
        ``EvalDatasetResponse``.

    Raises:
        NotFoundException: If no dataset has this id, or the dataset belongs
            to a different project.
    """
    dataset = await eval_crud.get(db, eval_id)
    # A dataset owned by another project is reported as missing as well.
    if not dataset or dataset.project_id != project_id:
        raise NotFoundException("Evaluation Dataset", eval_id)

    return ApiResponse.ok(data=EvalDatasetResponse.model_validate(dataset))


@router.post("/{eval_id}/evaluate", response_model=ApiResponse)
async def run_evaluation(
    project_id: UUID,
    eval_id: UUID,
    request: RunEvalRequest,
    db: AsyncSession = Depends(get_db),
):
    """Create a pending evaluation task for an existing dataset.

    Args:
        project_id: Project the dataset must belong to.
        eval_id: Identifier of the dataset, taken from the path.
        request: Validated request body.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The value of ``ApiResponse.ok`` whose data holds the new task's id
        rendered as a string.

    Raises:
        NotFoundException: If no dataset has this id, or the dataset belongs
            to a different project.
    """
    # The dataset is loaded only to verify that it exists in this project.
    dataset = await eval_crud.get(db, eval_id)
    if not dataset or dataset.project_id != project_id:
        raise NotFoundException("Evaluation Dataset", eval_id)

    # NOTE(review): neither ``eval_id`` nor ``request.model_config_id`` is
    # stored on the task, so the record does not say which dataset or model
    # configuration to evaluate -- confirm how the task consumer learns this.
    task = Task(
        project_id=project_id,
        task_type="eval",
        status="pending",
    )
    db.add(task)
    await db.commit()
    # Refresh the instance from the database before reading its id.
    await db.refresh(task)

    return ApiResponse.ok(
        data={"task_id": str(task.id)},
        message="Evaluation task started",
    )
|