123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
"""
Questions API Router
"""
|
||
from typing import List, Optional
|
||
from uuid import UUID
|
||
from pydantic import BaseModel
|
||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from app.core.database import get_db
|
||
from app.models.models import Question, Chunk
|
||
from app.schemas.base import QuestionCreate, QuestionResponse
|
||
|
||
router = APIRouter()
|
||
|
||
|
||
class GenerateRequest(BaseModel):
    """Request body accepted by the question-generation endpoint.

    Carries the chunks to generate from, the number of questions to create
    for each chunk, and the question types requested by the client.
    """

    # IDs of the source chunks; the generate endpoint rejects an empty list.
    chunk_ids: List[UUID] = []

    # Number of questions created for every matched chunk.
    count: int = 5

    # Requested question types; the generate endpoint reads the first entry.
    question_types: List[str] = ["fact", "summary"]
|
||
|
||
|
||
@router.post("/generate", response_model=dict)
async def generate_questions(
    project_id: UUID,
    request: GenerateRequest,
    db: AsyncSession = Depends(get_db)
):
    """Create placeholder questions for the requested chunks of a project.

    This is a stand-in for LLM-based generation: for every chunk that matches
    both ``request.chunk_ids`` and ``project_id`` it stores ``request.count``
    templated questions.

    Args:
        project_id: Project the chunks must belong to.
        request: Chunk IDs, per-chunk question count and question types.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with the number of questions created (``"questions"``) and a
        human-readable summary (``"message"``).

    Raises:
        HTTPException: 400 when ``chunk_ids`` is empty or ``count`` is less
            than 1; 404 when none of the requested chunks are found in the
            project.
    """
    # TODO: Implement LLM-based question generation
    # This is a placeholder that creates sample questions

    if not request.chunk_ids:
        raise HTTPException(status_code=400, detail="chunk_ids is required")

    # A non-positive count would create nothing yet still report success,
    # so reject it up front like the empty chunk_ids case.
    if request.count < 1:
        raise HTTPException(status_code=400, detail="count must be at least 1")

    # Only chunks that belong to this project are considered; requested IDs
    # that do not match are skipped.
    result = await db.execute(
        select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
    )
    chunks = result.scalars().all()

    if not chunks:
        raise HTTPException(status_code=404, detail="No chunks found")

    # The placeholder uses a single type for every question, so pick it once
    # instead of re-evaluating it on each iteration.
    question_type = request.question_types[0] if request.question_types else "fact"

    # Create sample questions (placeholder)
    created_questions = []
    for chunk in chunks:
        for number in range(1, request.count + 1):
            created_questions.append(
                Question(
                    project_id=project_id,
                    chunk_id=chunk.id,
                    content=f"这是关于「{chunk.name}」的问题 {number}?",
                    answer=f"这是问题 {number} 的答案。",
                    question_type=question_type,
                    source="generated"
                )
            )

    db.add_all(created_questions)
    await db.commit()

    return {
        "questions": len(created_questions),
        "message": f"Successfully generated {len(created_questions)} questions"
    }
|
||
|
||
|
||
@router.get("/", response_model=dict)
async def list_questions(
    project_id: UUID,
    chunk_id: Optional[UUID] = Query(None),
    db: AsyncSession = Depends(get_db)
):
    """Return the questions of a project, optionally limited to one chunk.

    Args:
        project_id: Project whose questions are listed.
        chunk_id: Optional query parameter; when given, only questions
            attached to that chunk are returned.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A dict whose ``"questions"`` key holds the matching questions as
        ``QuestionResponse`` objects.
    """
    # Collect the filter conditions first, then apply them in one go.
    conditions = [Question.project_id == project_id]
    if chunk_id:
        conditions.append(Question.chunk_id == chunk_id)

    rows = (await db.execute(select(Question).where(*conditions))).scalars().all()

    serialized = []
    for row in rows:
        serialized.append(QuestionResponse.model_validate(row))
    return {"questions": serialized}
|
||
|
||
|
||
@router.put("/{question_id}", response_model=QuestionResponse)
async def update_question(
    project_id: UUID,
    question_id: UUID,
    question: QuestionCreate,
    db: AsyncSession = Depends(get_db)
):
    """Update an existing question of a project and return its new state.

    Args:
        project_id: Project the question must belong to.
        question_id: ID of the question to update.
        question: New field values; only fields explicitly set by the client
            are applied.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        The updated question as a ``QuestionResponse``. The route declares
        ``QuestionResponse`` as its response model because that is what is
        returned; declaring ``dict`` would make response validation reject
        the model instance.

    Raises:
        HTTPException: 404 when no question with this ID exists in the project.
    """
    result = await db.execute(
        select(Question).where(Question.id == question_id, Question.project_id == project_id)
    )
    db_question = result.scalar_one_or_none()
    if not db_question:
        raise HTTPException(status_code=404, detail="Question not found")

    # exclude_unset keeps fields the client omitted from being overwritten
    # with schema defaults.
    for key, value in question.model_dump(exclude_unset=True).items():
        setattr(db_question, key, value)

    await db.commit()
    # Reload so the response reflects what was actually persisted.
    await db.refresh(db_question)
    return QuestionResponse.model_validate(db_question)
|
||
|
||
|
||
@router.delete("/{question_id}", response_model=dict)
async def delete_question(project_id: UUID, question_id: UUID, db: AsyncSession = Depends(get_db)):
    """Remove a single question from a project.

    Args:
        project_id: Project the question must belong to.
        question_id: ID of the question to delete.
        db: Async database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with a confirmation ``"message"``.

    Raises:
        HTTPException: 404 when no question with this ID exists in the project.
    """
    # Match on both IDs so a question from another project is never deleted.
    lookup = select(Question).where(Question.id == question_id, Question.project_id == project_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Question not found")

    await db.delete(target)
    await db.commit()
    return {"message": "Question deleted successfully"}
|