Files
YG-Datasets/backend/app/api/v1/questions/__init__.py

123 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Questions API Router
"""
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.models.models import Question, Chunk
from app.schemas.base import QuestionCreate, QuestionResponse
# Router for the question endpoints below. None of the route paths declare
# ``project_id``, so it is presumably supplied by a path prefix where this router
# is included (e.g. ``/projects/{project_id}/questions``) — verify at the mount site.
router = APIRouter()
class GenerateRequest(BaseModel):
    """Body of the question-generation request.

    Attributes are documented inline below. An empty ``chunk_ids`` list is
    accepted by the model itself; the generation endpoint is what rejects it.
    """

    # IDs of the chunks to generate questions from.
    chunk_ids: List[UUID] = []
    # How many questions to create for each selected chunk.
    count: int = 5
    # Requested question types, in priority order.
    # Pydantic copies list defaults per instance, so these literals are not
    # shared between requests.
    question_types: List[str] = ["fact", "summary"]
@router.post("/generate", response_model=dict)
async def generate_questions(
    project_id: UUID,
    request: GenerateRequest,
    db: AsyncSession = Depends(get_db),
):
    """Create placeholder questions for the requested chunks of a project.

    For every chunk listed in ``request.chunk_ids`` that belongs to
    ``project_id``, ``request.count`` placeholder questions are added to the
    session, and everything is committed in a single transaction at the end.

    Args:
        project_id: Project that the chunks (and the new questions) belong to.
        request: Generation parameters: chunk IDs, per-chunk count and types.
        db: Async database session.

    Returns:
        A dict with the number of created questions and a status message.

    Raises:
        HTTPException: 400 if ``chunk_ids`` is empty or ``count`` is not a
            positive integer; 404 if none of the IDs match a chunk of this project.
    """
    # TODO: Implement LLM-based question generation.
    # This is a placeholder that creates sample questions.
    if not request.chunk_ids:
        raise HTTPException(status_code=400, detail="chunk_ids is required")
    # Without this check a zero/negative count inserts nothing yet still
    # reports "Successfully generated 0 questions".
    # NOTE(review): there is no upper bound either; a very large count creates
    # len(chunks) * count rows in one transaction — consider capping it.
    if request.count < 1:
        raise HTTPException(status_code=400, detail="count must be a positive integer")

    # Only chunks that belong to this project are selected; unknown IDs or IDs
    # from other projects are silently skipped.
    result = await db.execute(
        select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
    )
    chunks = result.scalars().all()
    if not chunks:
        raise HTTPException(status_code=404, detail="No chunks found")

    # The placeholder gives every question the same type: the first requested
    # one, falling back to "fact". It does not depend on the loop, so compute it once.
    question_type = request.question_types[0] if request.question_types else "fact"

    created_questions = []
    for chunk in chunks:
        for number in range(1, request.count + 1):
            # Placeholder text (Chinese): content = "This is question {n} about
            # «chunk name»", answer = "This is the answer to question {n}."
            # The 「」 corner brackets are intentional CJK quotation marks.
            question = Question(
                project_id=project_id,
                chunk_id=chunk.id,
                content=f"这是关于「{chunk.name}」的问题 {number}",
                answer=f"这是问题 {number} 的答案。",
                question_type=question_type,
                source="generated",
            )
            db.add(question)
            created_questions.append(question)
    await db.commit()

    return {
        "questions": len(created_questions),
        "message": f"Successfully generated {len(created_questions)} questions",
    }
@router.get("/", response_model=dict)
async def list_questions(
    project_id: UUID,
    chunk_id: Optional[UUID] = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Return the questions of a project, optionally narrowed to one chunk.

    Args:
        project_id: Project whose questions are listed.
        chunk_id: Optional query parameter; when present, only questions
            attached to this chunk are returned.
        db: Async database session.

    Returns:
        ``{"questions": [...]}`` with each item validated as a ``QuestionResponse``.
        No ORDER BY is applied, so item order is whatever the database returns.
    """
    conditions = [Question.project_id == project_id]
    if chunk_id is not None:
        conditions.append(Question.chunk_id == chunk_id)

    rows = (await db.execute(select(Question).where(*conditions))).scalars().all()
    return {"questions": list(map(QuestionResponse.model_validate, rows))}
@router.put("/{question_id}", response_model=dict)
async def update_question(
    project_id: UUID,
    question_id: UUID,
    question: QuestionCreate,
    db: AsyncSession = Depends(get_db),
):
    """Overwrite an existing question with the fields sent in the request body.

    Args:
        project_id: Project the question must belong to.
        question_id: ID of the question to update.
        question: Incoming field values; only fields explicitly set by the
            client are applied.
        db: Async database session.

    Returns:
        The refreshed question as a ``QuestionResponse``. Unlike the list
        endpoint, it is returned directly rather than wrapped in a dict.

    Raises:
        HTTPException: 404 if no question with this ID exists in the project.
    """
    lookup = select(Question).where(
        Question.id == question_id,
        Question.project_id == project_id,
    )
    existing = (await db.execute(lookup)).scalar_one_or_none()
    if existing is None:
        raise HTTPException(status_code=404, detail="Question not found")

    # NOTE(review): every set field of QuestionCreate is copied verbatim onto
    # the ORM row; if that schema exposes ownership fields (e.g. project_id),
    # a client could reassign the question — confirm against app.schemas.base.
    changes = question.model_dump(exclude_unset=True)
    for field_name, new_value in changes.items():
        setattr(existing, field_name, new_value)

    await db.commit()
    # Reload from the database so the response reflects the committed state.
    await db.refresh(existing)
    return QuestionResponse.model_validate(existing)
@router.delete("/{question_id}", response_model=dict)
async def delete_question(
    project_id: UUID,
    question_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Remove a single question from a project.

    Args:
        project_id: Project the question must belong to.
        question_id: ID of the question to delete.
        db: Async database session.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 404 if no question with this ID exists in the project.
    """
    stmt = select(Question).where(
        Question.id == question_id,
        Question.project_id == project_id,
    )
    target = (await db.execute(stmt)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Question not found")

    await db.delete(target)
    await db.commit()
    return {"message": "Question deleted successfully"}