"""
Questions API Router

Endpoints for generating question/answer pairs from document chunks with a
configured LLM (run as a background task), and for listing, updating and
deleting the stored questions of a project.
"""
import asyncio
import json
import re
from typing import List, Optional
from uuid import UUID

import httpx
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.response import ApiResponse, PaginatedResponse
from app.core.crud import CRUDBase
from app.core.database import AsyncSessionLocal, get_db
from app.core.exceptions import NotFoundException, ValidationException
from app.core.logging import log_failure, log_success
from app.models.models import Chunk, ModelConfig, Question
from app.schemas.question import QuestionCreateSchema, QuestionResponse

router = APIRouter()

# Initialize CRUD
question_crud = CRUDBase(Question)

VALID_MODEL_TYPES = {"chat", "vlm", "embedding", "rerank"}

# Default instructions placed at the top of every generation prompt; a request
# may override them through ``GenerateRequest.preset_prompt``.
DEFAULT_PRESET_PROMPT = (
    "你是一名高质量中文问答数据构建助手。"
    "请基于给定 chunk 内容生成准确、自然、可用于训练的数据集问答对。"
    "问题必须清晰具体,答案必须直接来自内容或基于内容做合理概括,"
    "不要编造原文没有的信息,不要输出与目录、导航、页眉页脚、噪声文字相关的问题。"
)

# Strong references to in-flight generation jobs. The event loop only keeps
# weak references to tasks, so a task created with ``asyncio.create_task`` and
# stored nowhere else can be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()

# Reasoning-style models may emit <think>...</think> before the real answer;
# it has to be removed before searching the text for JSON.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.S | re.I)
# A Markdown code fence with any (or no) language tag, e.g. ```json ... ```.
_FENCED_BLOCK_RE = re.compile(r"```[A-Za-z]*\s*(.*?)\s*```", flags=re.S)
# The outermost ``[ {...} ]`` span when the array is embedded in prose.
_JSON_ARRAY_RE = re.compile(r"(\[\s*\{.*\}\s*\])", flags=re.S)


class GenerateRequest(BaseModel):
    """Request model for generating questions"""

    # Chunks to generate from; only those belonging to the project are used.
    chunk_ids: List[UUID] = Field(..., min_length=1)
    # Model configuration to call; looked up among configs with no project_id.
    model_id: UUID
    # Maximum number of QA pairs kept per chunk.
    count: int = Field(3, ge=1, le=10)
    # When True, chunks flagged by ``is_dirty_chunk`` are skipped.
    dirty_data_filter: bool = True
    # Adds an "analyse first, don't reveal reasoning" instruction to the prompt.
    thinking_mode: bool = True
    # Leading instructions of the prompt (defaults to DEFAULT_PRESET_PROMPT).
    preset_prompt: str = Field(default=DEFAULT_PRESET_PROMPT, min_length=1, max_length=4000)


def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
    """Normalize model type, with keyword fallback for legacy records.

    A stored type that is valid and not ``"chat"`` is returned unchanged.
    Otherwise (``"chat"``, ``None`` or an unknown value) the type is inferred
    from keywords in the lower-cased model name, checked in the order
    rerank -> embedding -> vlm, falling back to ``"chat"``.

    Returns:
        One of ``VALID_MODEL_TYPES``.
    """
    if model_type in VALID_MODEL_TYPES and model_type != "chat":
        return model_type

    normalized_name = (model_name or "").strip().lower()
    rerank_keywords = ("rerank", "bce-reranker", "gte-rerank")
    embedding_keywords = (
        "embedding",
        "embed",
        "text-embedding",
        "bge-",
        "bge_m3",
        "gte-",
        "m3e",
        "e5-",
        "jina-embeddings",
    )
    vlm_keywords = ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")

    # Rerank is checked first so names like "gte-rerank" are not caught by
    # the broader "gte-" embedding keyword.
    if any(keyword in normalized_name for keyword in rerank_keywords):
        return "rerank"
    if any(keyword in normalized_name for keyword in embedding_keywords):
        return "embedding"
    if any(keyword in normalized_name for keyword in vlm_keywords):
        return "vlm"
    return model_type if model_type in VALID_MODEL_TYPES else "chat"


def is_dirty_chunk(content: str) -> bool:
    """Heuristic dirty-data filter for low-value chunks.

    Returns True when the chunk looks like noise rather than prose:

    * fewer than 40 characters after collapsing whitespace, or fewer than 24
      CJK-ideograph / ASCII-alphanumeric characters;
    * at least 3 non-blank lines of which more than 70% are <= 18 characters;
    * more than 40% of non-blank lines look like table-of-contents entries
      (a run of 3+ dots/bullets/ellipses, or a trailing page number);
    * more than 45% of the characters (spaces included) are neither
      alphanumeric nor CJK ideographs.
    """
    normalized = re.sub(r"\s+", " ", (content or "")).strip()
    if len(normalized) < 40:
        return True
    if len(re.sub(r"[^\u4e00-\u9fffA-Za-z0-9]", "", normalized)) < 24:
        return True

    lowered = normalized.lower()
    # NOTE: currently unreachable - every entry is shorter than the 40-char
    # minimum above; kept in case that threshold is lowered.
    if lowered in {"目录", "contents", "table of contents"}:
        return True

    lines = [line.strip() for line in (content or "").splitlines() if line.strip()]
    if lines:
        short_lines = sum(1 for line in lines if len(line) <= 18)
        dotted_lines = sum(1 for line in lines if re.search(r"[·•…\.]{3,}|\s\d+$", line))
        if short_lines / len(lines) > 0.7 and len(lines) >= 3:
            return True
        if dotted_lines / len(lines) > 0.4:
            return True

    punctuation_ratio = sum(
        1 for ch in normalized if not ch.isalnum() and not ("\u4e00" <= ch <= "\u9fff")
    ) / max(len(normalized), 1)
    if punctuation_ratio > 0.45:
        return True
    return False


def build_generation_prompt(chunk: Chunk, request: GenerateRequest) -> str:
    """Build user prompt for QA generation.

    The prompt is the request's preset prompt, a numbered list of output
    requirements (``request.count`` pairs, JSON array only, fixed object
    shape, ``[]`` when the content is insufficient, optional thinking
    instruction), followed by the chunk name and content.
    """
    thinking_instruction = (
        "请先对内容做简短分析,识别核心事实、概念、关系与潜在考点,然后再生成问答。"
        "分析过程只用于提高质量,不要在最终输出中暴露你的思维链。"
        if request.thinking_mode
        else "直接基于内容生成高质量问答。"
    )
    return (
        f"{request.preset_prompt}\n\n"
        "输出要求:\n"
        f"1. 生成 {request.count} 组问答。\n"
        "2. 只输出 JSON 数组,不要输出解释、标题、Markdown。\n"
        '3. 每个对象结构为 {"question":"...","answer":"...","question_type":"fact|summary|reasoning"}。\n'
        "4. 问题避免重复,答案避免空泛。\n"
        "5. 如果内容不足以生成高质量问答,请返回空数组 []。\n"
        f"6. {thinking_instruction}\n\n"
        f"Chunk 名称:{chunk.name or '未命名分片'}\n"
        f"Chunk 内容:\n{chunk.content}"
    )


def extract_text_from_response(data: dict) -> str:
    """Extract response text from provider response.

    Reads ``choices[0]["message"]["content"]`` of a chat-completion payload.
    String content is returned unchanged. List content is flattened by
    joining the non-empty ``text`` fields of its dict items with newlines.
    Any other or malformed shape yields ``""`` instead of raising.
    """
    if not isinstance(data, dict):
        return ""
    choices = data.get("choices") or []
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    message = choices[0].get("message") or {}
    if not isinstance(message, dict):
        return ""

    content = message.get("content")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [item.get("text", "") for item in content if isinstance(item, dict)]
        return "\n".join(part for part in parts if part)
    return ""


def _as_clean_str(value) -> str:
    """Return ``value`` as a stripped string, mapping ``None`` (JSON null) to ""."""
    return "" if value is None else str(value).strip()


def parse_generated_questions(raw_text: str) -> List[dict]:
    """Parse JSON array from model output.

    Tolerates the usual ways models wrap their JSON:

    * ``<think>...</think>`` reasoning blocks are removed;
    * a Markdown code fence (with or without a language tag) is unwrapped;
    * surrounding prose is skipped by extracting the ``[ {...} ]`` span;
    * a ``{"questions": [...]}`` object or a single ``{"question": ...}``
      object is accepted in place of an array.

    Returns:
        A list of ``{"question", "answer", "question_type"}`` dicts with
        stripped string values. Items without a question or an answer
        (including JSON null) are dropped, and a missing or empty
        ``question_type`` defaults to ``"fact"``. Unparsable input yields [].
    """
    text = _THINK_BLOCK_RE.sub("", raw_text or "").strip()
    if not text:
        return []

    fenced_match = _FENCED_BLOCK_RE.search(text)
    if fenced_match:
        text = fenced_match.group(1).strip()

    if not text.startswith("["):
        array_match = _JSON_ARRAY_RE.search(text)
        if array_match:
            text = array_match.group(1)

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return []

    if isinstance(parsed, dict):
        if isinstance(parsed.get("questions"), list):
            parsed = parsed["questions"]
        elif "question" in parsed:
            parsed = [parsed]
    if not isinstance(parsed, list):
        return []

    normalized = []
    for item in parsed:
        if not isinstance(item, dict):
            continue
        question = _as_clean_str(item.get("question"))
        answer = _as_clean_str(item.get("answer"))
        question_type = _as_clean_str(item.get("question_type")) or "fact"
        if not question or not answer:
            continue
        normalized.append({
            "question": question,
            "answer": answer,
            "question_type": question_type,
        })
    return normalized


async def call_generation_model(model: ModelConfig, prompt: str) -> str:
    """Call configured chat model for question generation.

    Posts a chat-completion request (fixed system message plus ``prompt``,
    temperature 0.4, JSON-object response format) to ``model.api_base``:
    ``/chat/completions_v2`` without ``response_format`` when
    ``model.provider == "minimax"``, ``/chat/completions`` otherwise.

    Returns:
        The assistant text. If that text is a JSON object holding a
        ``questions`` list, the list is returned re-serialized as a JSON array.

    Raises:
        httpx.HTTPError: transport error, timeout (120 s) or non-2xx status.
        ValueError: the response body is not JSON, or the model returned no
            text content.
    """
    api_base = (model.api_base or "").rstrip("/")
    headers = {
        "Authorization": f"Bearer {model.api_key or ''}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model.model_name,
        "messages": [
            {
                "role": "system",
                "content": "你是问答数据构建助手。严格按 JSON 输出,不要输出额外说明。",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        "temperature": 0.4,
        "response_format": {"type": "json_object"},
    }

    if model.provider == "minimax":
        url = f"{api_base}/chat/completions_v2"
        # The minimax endpoint is called without the response_format field.
        body = {key: value for key, value in payload.items() if key != "response_format"}
    else:
        url = f"{api_base}/chat/completions"
        body = payload

    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(url, headers=headers, json=body)
        response.raise_for_status()
        data = response.json()

    content = extract_text_from_response(data)
    if not content:
        raise ValueError("Model returned empty content")

    # JSON-object mode can wrap the requested array as {"questions": [...]}.
    if content.lstrip().startswith("{"):
        try:
            obj = json.loads(content)
        except json.JSONDecodeError:
            # Malformed or truncated object: hand the raw text to the lenient
            # parser instead of failing the whole batch here.
            return content
        if isinstance(obj, dict) and isinstance(obj.get("questions"), list):
            return json.dumps(obj["questions"], ensure_ascii=False)
    return content


async def process_generate_async(project_id: UUID, request: GenerateRequest):
    """Generate QA pairs in background.

    Opens its own DB session (it runs after the request has returned), loads
    the model config and the project's requested chunks, and asks the model
    for QA pairs chunk by chunk. At most ``request.count`` parsed pairs per
    chunk are added as ``Question`` rows with ``source="generated"``, and
    everything is committed once at the end.

    A failed model call for one chunk (HTTP error, timeout, non-JSON body,
    empty content) is logged and counted, and the remaining chunks are still
    processed, so one bad call no longer discards the questions already
    generated. Any other error aborts the job and is reported through
    ``log_failure``. Nothing is re-raised, because no caller awaits this task.
    """
    async with AsyncSessionLocal() as db:
        try:
            model_result = await db.execute(
                select(ModelConfig).where(
                    ModelConfig.id == request.model_id,
                    ModelConfig.project_id == None,  # noqa: E711
                )
            )
            model = model_result.scalar_one_or_none()
            if not model:
                # Raise instead of returning silently so the job failure is logged.
                raise ValidationException("Selected model not found", field="model_id")

            model_type = normalize_model_type(model.model_type, model.model_name)
            if model_type not in {"chat", "vlm"}:
                raise ValidationException("Selected model must be chat/vlm type", field="model_id")

            chunk_result = await db.execute(
                select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
            )
            chunks = chunk_result.scalars().all()
            if not chunks:
                raise ValidationException("No valid chunks found", field="chunk_ids")

            created_count = 0
            skipped_count = 0
            failed_count = 0
            for chunk in chunks:
                if request.dirty_data_filter and is_dirty_chunk(chunk.content):
                    skipped_count += 1
                    continue

                prompt = build_generation_prompt(chunk, request)
                try:
                    raw_text = await call_generation_model(model, prompt)
                except (httpx.HTTPError, ValueError) as chunk_error:
                    # ValueError also covers json.JSONDecodeError from response.json().
                    failed_count += 1
                    log_failure(
                        "问答分片生成失败",
                        project_id=str(project_id),
                        model_id=str(model.id),
                        chunk_id=str(chunk.id),
                        error=str(chunk_error),
                    )
                    continue

                qa_pairs = parse_generated_questions(raw_text)[:request.count]
                if not qa_pairs:
                    skipped_count += 1
                    continue

                for item in qa_pairs:
                    db.add(Question(
                        project_id=project_id,
                        chunk_id=chunk.id,
                        content=item["question"],
                        answer=item["answer"],
                        question_type=item["question_type"],
                        source="generated",
                    ))
                    created_count += 1

            await db.commit()
            log_success(
                "问答批量生成完成",
                project_id=str(project_id),
                model_id=str(model.id),
                chunk_count=len(chunks),
                created_questions=created_count,
                skipped_chunks=skipped_count,
                failed_chunks=failed_count,
            )
        except Exception as e:
            # Top-level boundary of a fire-and-forget task: log instead of raising.
            log_failure(
                "问答批量生成失败",
                project_id=str(project_id),
                model_id=str(request.model_id),
                error=str(e),
            )


@router.post("/generate", response_model=ApiResponse)
async def generate_questions(
    project_id: UUID,
    request: GenerateRequest,
    db: AsyncSession = Depends(get_db),
):
    """Generate questions from chunks using LLM in background.

    Validates synchronously that the model exists among configs without a
    project, is chat/vlm and has an API key, and that at least one requested
    chunk belongs to the project. Generation for those chunks is then
    scheduled as a background task, and the endpoint returns immediately.

    Raises:
        ValidationException: invalid model selection or no valid chunks.
    """
    model_result = await db.execute(
        select(ModelConfig).where(
            ModelConfig.id == request.model_id,
            ModelConfig.project_id == None,  # noqa: E711
        )
    )
    model = model_result.scalar_one_or_none()
    if not model:
        raise ValidationException("Selected model not found", field="model_id")

    model_type = normalize_model_type(model.model_type, model.model_name)
    if model_type not in {"chat", "vlm"}:
        raise ValidationException("Selected model must be chat/vlm type", field="model_id")
    if not model.api_key:
        raise ValidationException("Selected model is missing API Key", field="model_id")

    chunk_result = await db.execute(
        select(Chunk.id).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
    )
    valid_chunk_ids = [row[0] for row in chunk_result.all()]
    if not valid_chunk_ids:
        raise ValidationException("No valid chunks found", field="chunk_ids")

    # Only the chunk ids that passed the project check are handed to the job.
    request_payload = request.model_copy(update={"chunk_ids": valid_chunk_ids})
    task = asyncio.create_task(process_generate_async(project_id, request_payload))
    # Hold a strong reference until the job finishes so it cannot be
    # garbage-collected mid-run; the callback drops it afterwards.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return ApiResponse.ok(
        data={"chunk_count": len(valid_chunk_ids), "status": "processing"},
        message="Question generation started in background",
    )


@router.get("", response_model=ApiResponse)
async def list_questions(
    project_id: UUID,
    chunk_id: Optional[UUID] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
):
    """List questions for a project.

    Returns a page of the project's questions, newest ``created_at`` first,
    optionally restricted to a single chunk.
    """
    filters = {"project_id": project_id}
    if chunk_id is not None:
        filters["chunk_id"] = chunk_id

    skip = (page - 1) * page_size
    questions, total = await question_crud.get_multi(
        db,
        skip=skip,
        limit=page_size,
        filters=filters,
        order_by="created_at",
        descending=True,
    )

    question_responses = [QuestionResponse.model_validate(q) for q in questions]
    return PaginatedResponse.ok(
        items=question_responses,
        page=page,
        page_size=page_size,
        total=total,
    )


@router.put("/{question_id}", response_model=ApiResponse)
async def update_question(
    project_id: UUID,
    question_id: UUID,
    question: QuestionCreateSchema,
    db: AsyncSession = Depends(get_db),
):
    """Update question.

    Raises:
        NotFoundException: the question does not exist or belongs to a
            different project.
    """
    db_question = await question_crud.get(db, question_id)
    if not db_question or db_question.project_id != project_id:
        raise NotFoundException("Question", question_id)

    updated_question = await question_crud.update(db, db_question, question)
    return ApiResponse.ok(
        data=QuestionResponse.model_validate(updated_question),
        message="Question updated successfully",
    )


@router.delete("/{question_id}", response_model=ApiResponse)
async def delete_question(
    project_id: UUID,
    question_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Delete question.

    Raises:
        NotFoundException: the question does not exist or belongs to a
            different project.
    """
    question = await question_crud.get(db, question_id)
    if not question or question.project_id != project_id:
        raise NotFoundException("Question", question_id)

    await question_crud.delete(db, question_id)
    return ApiResponse.ok(message="Question deleted successfully")