- 更新 Chunks API 端点 - 更新 Datasets API 端点 - 更新 Evaluation API 端点 - 更新 Files API 端点 - 更新 Projects API 端点 - 更新 Questions API 端点 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
134 lines
4.1 KiB
Python
134 lines
4.1 KiB
Python
"""Questions API router: generate, list, update and delete a project's questions."""
|
||
from typing import List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.response import ApiResponse, PaginatedResponse
from app.core.crud import CRUDBase
from app.core.database import get_db
from app.core.exceptions import NotFoundException, ValidationException
from app.models.models import Question, Chunk
from app.schemas.question import QuestionResponse
from app.schemas.question import QuestionCreateSchema
||
|
||
# Generic CRUD helper bound to the Question model.
question_crud = CRUDBase(Question)

# Router that collects the question endpoints declared in this module.
router = APIRouter()
|
||
|
||
|
||
class GenerateRequest(BaseModel):
    """Request body for generating questions from chunks.

    Attributes:
        chunk_ids: IDs of the source chunks; at least one is required.
        count: Questions to create per chunk, between 1 and 50 (default 5).
        question_types: Question type labels to assign
            (default ``["fact", "summary"]``).
    """

    chunk_ids: List[UUID] = Field(..., min_length=1)
    count: int = Field(default=5, ge=1, le=50)
    question_types: List[str] = Field(default=["fact", "summary"])
|
||
|
||
|
||
@router.post("/generate", response_model=ApiResponse)
async def generate_questions(
    project_id: UUID,
    request: GenerateRequest,
    db: AsyncSession = Depends(get_db),
):
    """Generate questions for the requested chunks of a project.

    Only chunks whose ``project_id`` matches are used; requested IDs that do
    not resolve to a chunk of this project are silently skipped. For every
    matched chunk, ``request.count`` questions are created and committed.

    NOTE: placeholder for LLM-based generation - question and answer text is
    templated, not model-generated.

    Args:
        project_id: Project that owns the chunks and the new questions.
        request: Chunk IDs, questions per chunk, and question types to assign.
        db: Async database session.

    Returns:
        ``ApiResponse`` whose ``data["questions"]`` is the number of
        questions created.

    Raises:
        ValidationException: If none of ``request.chunk_ids`` match a chunk
            belonging to ``project_id``.
    """
    # Scope the lookup to the project so chunk IDs from other projects are ignored.
    result = await db.execute(
        select(Chunk).where(
            Chunk.id.in_(request.chunk_ids),
            Chunk.project_id == project_id,
        )
    )
    chunks = result.scalars().all()

    if not chunks:
        raise ValidationException("No valid chunks found", field="chunk_ids")

    # Rotate through all requested types instead of using only the first one;
    # fall back to "fact" when an empty list was sent.
    question_types = request.question_types or ["fact"]

    created_count = 0
    for chunk in chunks:
        for i in range(request.count):
            db.add(
                Question(
                    project_id=project_id,
                    chunk_id=chunk.id,
                    content=f"这是关于「{chunk.name}」的问题 {i+1}?",
                    answer=f"这是问题 {i+1} 的答案。",
                    question_type=question_types[i % len(question_types)],
                    source="generated",
                )
            )
            created_count += 1

    await db.commit()

    return ApiResponse.ok(
        data={"questions": created_count},
        message=f"Successfully generated {created_count} questions",
    )
|
||
|
||
|
||
@router.get("", response_model=ApiResponse)
async def list_questions(
    project_id: UUID,
    chunk_id: Optional[UUID] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """Return one page of a project's questions, ordered by ``created_at`` descending.

    Args:
        project_id: Project whose questions are listed.
        chunk_id: Optional chunk filter; when given, only that chunk's
            questions are returned.
        page: 1-based page number.
        page_size: Number of items per page (1-100).
        db: Async database session.

    Returns:
        ``PaginatedResponse`` carrying the page items and the ``total``
        reported by ``question_crud.get_multi``.
    """
    criteria = {"project_id": project_id}
    if chunk_id:
        criteria["chunk_id"] = chunk_id

    offset = page_size * (page - 1)
    rows, total_count = await question_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters=criteria,
        order_by="created_at",
        descending=True,
    )

    return PaginatedResponse.ok(
        items=list(map(QuestionResponse.model_validate, rows)),
        page=page,
        page_size=page_size,
        total=total_count,
    )
|
||
|
||
|
||
@router.put("/{question_id}", response_model=ApiResponse)
async def update_question(
    project_id: UUID,
    question_id: UUID,
    question: QuestionCreateSchema,
    db: AsyncSession = Depends(get_db)
):
    """Apply ``question`` to an existing question of the project.

    Args:
        project_id: Project the question must belong to.
        question_id: ID of the question to update.
        question: New field values, passed to ``question_crud.update``.
        db: Async database session.

    Returns:
        ``ApiResponse`` containing the updated question.

    Raises:
        NotFoundException: If the question does not exist or belongs to a
            different project.
    """
    # NOTE(review): if QuestionCreateSchema carries project_id/chunk_id, this
    # endpoint may let a client re-parent the question - confirm.
    existing = await question_crud.get(db, question_id)
    if not (existing and existing.project_id == project_id):
        raise NotFoundException("Question", question_id)

    refreshed = await question_crud.update(db, existing, question)
    payload = QuestionResponse.model_validate(refreshed)
    return ApiResponse.ok(data=payload, message="Question updated successfully")
|
||
|
||
|
||
@router.delete("/{question_id}", response_model=ApiResponse)
async def delete_question(
    project_id: UUID,
    question_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Delete a question after confirming it belongs to the project.

    Args:
        project_id: Project the question must belong to.
        question_id: ID of the question to delete.
        db: Async database session.

    Returns:
        ``ApiResponse`` with a success message.

    Raises:
        NotFoundException: If the question does not exist or belongs to a
            different project.
    """
    target = await question_crud.get(db, question_id)
    if not (target and target.project_id == project_id):
        raise NotFoundException("Question", question_id)

    await question_crud.delete(db, question_id)
    return ApiResponse.ok(message="Question deleted successfully")
|