Files

403 lines
14 KiB
Python
Raw Permalink Normal View History

2026-03-17 14:36:31 +08:00
"""
Questions API Router
"""
import asyncio
import json
import re
2026-03-17 14:36:31 +08:00
from typing import List, Optional
from uuid import UUID
import httpx
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
2026-03-17 14:36:31 +08:00
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.response import ApiResponse, PaginatedResponse
from app.core.crud import CRUDBase
from app.core.database import AsyncSessionLocal, get_db
from app.core.exceptions import NotFoundException, ValidationException
from app.core.logging import log_failure, log_success
from app.models.models import Chunk, ModelConfig, Question
from app.schemas.question import QuestionCreateSchema, QuestionResponse
2026-03-17 14:36:31 +08:00
# Router that the question endpoints in this module are registered on.
router = APIRouter()

# Shared CRUD helper bound to the Question model (used by list/update/delete).
question_crud = CRUDBase(Question)

# Model type labels recognised by normalize_model_type().
VALID_MODEL_TYPES = {
    "chat",
    "vlm",
    "embedding",
    "rerank",
}

# Default value of GenerateRequest.preset_prompt. The (Chinese) text tells the
# model to act as a QA-dataset builder: produce accurate, natural QA pairs
# grounded in the chunk content, never invent facts, and never ask about
# tables of contents, navigation, headers/footers or noise text.
DEFAULT_PRESET_PROMPT = (
    "你是一名高质量中文问答数据构建助手。"
    "请基于给定 chunk 内容生成准确、自然、可用于训练的数据集问答对。"
    "问题必须清晰具体,答案必须直接来自内容或基于内容做合理概括,"
    "不要编造原文没有的信息,不要输出与目录、导航、页眉页脚、噪声文字相关的问题。"
)
2026-03-17 14:36:31 +08:00
class GenerateRequest(BaseModel):
    """Request model for generating questions"""

    # IDs of the chunks to generate questions from; at least one is required.
    chunk_ids: List[UUID] = Field(..., min_length=1)

    # ID of the ModelConfig row used to call the generation model.
    model_id: UUID

    # Number of QA pairs requested for each chunk (1-10).
    count: int = Field(default=3, ge=1, le=10)

    # When true, chunks flagged by is_dirty_chunk() are skipped.
    dirty_data_filter: bool = True

    # When true, the prompt asks the model to analyse the content briefly
    # before answering (without exposing that analysis in the output).
    thinking_mode: bool = True

    # Leading instruction block of the user prompt.
    preset_prompt: str = Field(
        default=DEFAULT_PRESET_PROMPT,
        min_length=1,
        max_length=4000,
    )
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
    """Resolve the effective model type, falling back to name keywords.

    Args:
        model_type: Stored type label; may be None or an unknown value.
        model_name: Model name used for the keyword fallback; may be None.

    Returns:
        ``model_type`` unchanged when it is a known non-"chat" type. Otherwise
        the first of "rerank", "embedding", "vlm" whose keywords occur in the
        lower-cased name, and "chat" when nothing matches.
    """
    type_is_known = model_type in VALID_MODEL_TYPES

    # An explicit non-chat label is trusted as-is. "chat" (and anything
    # missing/unknown) is re-checked against the model name below.
    if type_is_known and model_type != "chat":
        return model_type

    name = (model_name or "").strip().lower()

    # Order matters: rerank is tested first, then embedding, then vlm.
    keyword_table = (
        ("rerank", ("rerank", "bce-reranker", "gte-rerank")),
        (
            "embedding",
            (
                "embedding",
                "embed",
                "text-embedding",
                "bge-",
                "bge_m3",
                "gte-",
                "m3e",
                "e5-",
                "jina-embeddings",
            ),
        ),
        ("vlm", ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")),
    )
    for label, keywords in keyword_table:
        for keyword in keywords:
            if keyword in name:
                return label

    return model_type if type_is_known else "chat"
def is_dirty_chunk(content: str) -> bool:
    """Heuristically flag chunks that are too low-value to generate QA from.

    A chunk is considered dirty when any of the following holds:

    * it is just a table-of-contents heading;
    * it has fewer than 40 characters after whitespace is collapsed;
    * it has fewer than 24 CJK/ASCII-alphanumeric characters;
    * it has at least 3 non-empty lines and more than 70% of them are
      18 characters or shorter;
    * more than 40% of its non-empty lines contain a dot-leader run or end
      in whitespace followed by digits;
    * more than 45% of its characters are neither alphanumeric nor CJK
      (spaces count towards this ratio).

    Args:
        content: Raw chunk text; ``None`` or empty is treated as dirty.

    Returns:
        True if the chunk should be skipped, False otherwise.
    """
    raw = content or ""
    normalized = re.sub(r"\s+", " ", raw).strip()

    # Bare table-of-contents headings. This check must come before the length
    # guard: every heading here is shorter than 40 characters, so placed
    # after the guard (as it used to be) it could never be reached.
    if normalized.lower() in {"目录", "contents", "table of contents"}:
        return True

    if len(normalized) < 40:
        return True

    meaningful = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9]", "", normalized)
    if len(meaningful) < 24:
        return True

    lines = [line.strip() for line in raw.splitlines() if line.strip()]
    if lines:
        total = len(lines)
        short_lines = sum(1 for line in lines if len(line) <= 18)
        # Dot leaders ("....", "……") or a trailing " 12"-style number.
        dotted_lines = sum(
            1 for line in lines if re.search(r"[·•…\.]{3,}|\s\d+$", line)
        )
        if total >= 3 and short_lines / total > 0.7:
            return True
        if dotted_lines / total > 0.4:
            return True

    non_text = sum(
        1
        for ch in normalized
        if not ch.isalnum() and not ("\u4e00" <= ch <= "\u9fff")
    )
    return non_text / max(len(normalized), 1) > 0.45
def build_generation_prompt(chunk: Chunk, request: GenerateRequest) -> str:
    """Assemble the user prompt sent to the model for one chunk.

    The prompt consists of the request's preset prompt, a numbered list of
    output requirements (pair count, JSON-array-only output, object shape,
    no duplicates, empty array when content is insufficient, and the
    thinking-mode instruction), followed by the chunk name and content.

    Args:
        chunk: Chunk whose ``name`` and ``content`` are embedded in the prompt.
        request: Generation settings; ``preset_prompt``, ``count`` and
            ``thinking_mode`` are read.

    Returns:
        The complete prompt text.
    """
    if request.thinking_mode:
        thinking_instruction = (
            "请先对内容做简短分析,识别核心事实、概念、关系与潜在考点,然后再生成问答。"
            "分析过程只用于提高质量,不要在最终输出中暴露你的思维链。"
        )
    else:
        thinking_instruction = "直接基于内容生成高质量问答。"

    segments = [
        f"{request.preset_prompt}\n\n",
        "输出要求:\n",
        f"1. 生成 {request.count} 组问答。\n",
        "2. 只输出 JSON 数组不要输出解释、标题、Markdown。\n",
        '3. 每个对象结构为 {"question":"...","answer":"...","question_type":"fact|summary|reasoning"}。\n',
        "4. 问题避免重复,答案避免空泛。\n",
        "5. 如果内容不足以生成高质量问答,请返回空数组 []。\n",
        f"6. {thinking_instruction}\n\n",
        # A placeholder name is used when the chunk has no name.
        f"Chunk 名称:{chunk.name or '未命名分片'}\n",
        f"Chunk 内容:\n{chunk.content}",
    ]
    return "".join(segments)
def extract_text_from_response(data: dict) -> str:
    """Pull the assistant text out of a chat-completions style response.

    Only the first entry of ``data["choices"]`` is inspected. Its
    ``message.content`` may be a plain string, or a list of parts in which
    case the non-empty ``text`` values of the dict parts are joined with
    newlines.

    Args:
        data: Decoded JSON response body.

    Returns:
        The extracted text, or "" when no usable content is present.
    """
    choices = data.get("choices") or []
    if not choices:
        return ""

    message = choices[0].get("message") or {}
    content = message.get("content")

    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""

    texts = []
    for part in content:
        # Non-dict parts carry no "text" field and are ignored.
        if not isinstance(part, dict):
            continue
        text = part.get("text", "")
        if text:
            texts.append(text)
    return "\n".join(texts)
def parse_generated_questions(raw_text: str) -> List[dict]:
    """Parse the model output into a list of normalised QA dicts.

    The text may be a bare JSON array, a JSON array wrapped in a
    ```json fenced block, or an array of objects surrounded by other text.
    Items that are not objects, or that lack a non-empty question or answer,
    are dropped.

    Args:
        raw_text: Raw text returned by the model; ``None``/empty yields [].

    Returns:
        A list of ``{"question", "answer", "question_type"}`` dicts with
        stripped string values. ``question_type`` defaults to "fact".
        Returns [] when the text cannot be parsed as a JSON array.
    """
    text = (raw_text or "").strip()
    if not text:
        return []

    # Prefer the payload of a ```json fenced block when one is present.
    fenced_match = re.search(r"```json\s*(.*?)\s*```", text, flags=re.S)
    if fenced_match:
        text = fenced_match.group(1).strip()

    # Otherwise try to cut an array of objects out of surrounding chatter.
    if not text.startswith("["):
        array_match = re.search(r"(\[\s*\{.*\}\s*\])", text, flags=re.S)
        if array_match:
            text = array_match.group(1)

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return []
    if not isinstance(parsed, list):
        return []

    def _clean(value: object) -> str:
        # JSON null must count as "missing": str(None) would otherwise
        # produce the literal text "None" and slip through the empty check.
        return "" if value is None else str(value).strip()

    normalized = []
    for item in parsed:
        if not isinstance(item, dict):
            continue
        question = _clean(item.get("question"))
        answer = _clean(item.get("answer"))
        if not question or not answer:
            continue
        normalized.append({
            "question": question,
            "answer": answer,
            "question_type": _clean(item.get("question_type")) or "fact",
        })
    return normalized
async def call_generation_model(model: ModelConfig, prompt: str) -> str:
    """Call the configured chat model and return its raw text output.

    The request is an OpenAI-style chat-completions POST against
    ``model.api_base``. For the "minimax" provider the ``/chat/completions_v2``
    path is used and ``response_format`` is left out of the payload.

    If the model answers with a JSON object whose ``questions`` key holds a
    list, that list is re-serialised and returned so the caller receives a
    JSON array. Any other content is returned unchanged.

    Args:
        model: Model configuration; ``provider``, ``api_base``, ``api_key``
            and ``model_name`` are read.
        prompt: User prompt text.

    Returns:
        The model's text content (possibly unwrapped as described above).

    Raises:
        httpx.HTTPError: On transport errors or a non-success HTTP status.
        ValueError: If the response contains no text content.
    """
    api_base = (model.api_base or "").rstrip("/")
    headers = {
        "Authorization": f"Bearer {model.api_key or ''}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": model.model_name,
        "messages": [
            {
                "role": "system",
                "content": "你是问答数据构建助手。严格按 JSON 输出,不要输出额外说明。"
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        "temperature": 0.4,
        "response_format": {"type": "json_object"}
    }

    if model.provider == "minimax":
        url = f"{api_base}/chat/completions_v2"
        body = {k: v for k, v in payload.items() if k != "response_format"}
    else:
        url = f"{api_base}/chat/completions"
        body = payload

    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(url, headers=headers, json=body)
        response.raise_for_status()
        data = response.json()

    content = extract_text_from_response(data)
    if not content:
        raise ValueError("Model returned empty content")

    if content.lstrip().startswith("{"):
        try:
            obj = json.loads(content)
        except json.JSONDecodeError:
            # Output that merely starts with "{" (truncated, or followed by
            # extra text) is not fatal: hand it back unchanged so the more
            # tolerant parse_generated_questions() can still try to use it.
            return content
        if isinstance(obj, dict) and isinstance(obj.get("questions"), list):
            return json.dumps(obj["questions"], ensure_ascii=False)
    return content
async def process_generate_async(project_id: UUID, request: GenerateRequest):
    """Generate QA pairs for the requested chunks as a background job.

    Opens its own database session, loads the model configuration and the
    chunks belonging to ``project_id``, calls the model once per chunk and
    stores the resulting questions with a single commit at the end.

    A chunk is counted as skipped when it is filtered out as dirty, when the
    model call for it fails, or when no QA pair could be parsed from the
    output. A failed model call is logged and does not abort the batch, so
    pairs generated for the other chunks are still saved.

    Args:
        project_id: Project whose chunks are processed.
        request: Generation settings (chunk IDs, model ID, count, flags).

    Returns:
        None. Outcomes are reported through log_success / log_failure; no
        exception propagates to the caller.
    """
    async with AsyncSessionLocal() as db:
        try:
            # `== None` is intentional: SQLAlchemy renders it as IS NULL.
            model_result = await db.execute(
                select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None)  # noqa: E711
            )
            model = model_result.scalar_one_or_none()
            if not model:
                log_failure(
                    "问答批量生成失败",
                    project_id=str(project_id),
                    model_id=str(request.model_id),
                    error="Selected model not found"
                )
                return

            model_type = normalize_model_type(model.model_type, model.model_name)
            if model_type not in {"chat", "vlm"}:
                raise ValidationException("Selected model must be chat/vlm type", field="model_id")

            chunk_result = await db.execute(
                select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
            )
            chunks = chunk_result.scalars().all()
            if not chunks:
                log_failure(
                    "问答批量生成失败",
                    project_id=str(project_id),
                    model_id=str(request.model_id),
                    error="No valid chunks found"
                )
                return

            created_count = 0
            skipped_count = 0
            for chunk in chunks:
                if request.dirty_data_filter and is_dirty_chunk(chunk.content):
                    skipped_count += 1
                    continue

                chunk_id = chunk.id
                prompt = build_generation_prompt(chunk, request)
                try:
                    raw_text = await call_generation_model(model, prompt)
                except (httpx.HTTPError, ValueError) as e:
                    # One bad request/response must not discard the pairs
                    # already generated for other chunks (nothing has been
                    # committed yet); log it and move on.
                    log_failure(
                        "问答批量生成失败",
                        project_id=str(project_id),
                        model_id=str(request.model_id),
                        error=f"chunk {chunk_id}: {e}"
                    )
                    skipped_count += 1
                    continue

                qa_pairs = parse_generated_questions(raw_text)[:request.count]
                if not qa_pairs:
                    skipped_count += 1
                    continue

                for item in qa_pairs:
                    db.add(Question(
                        project_id=project_id,
                        chunk_id=chunk_id,
                        content=item["question"],
                        answer=item["answer"],
                        question_type=item["question_type"],
                        source="generated"
                    ))
                    created_count += 1

            await db.commit()
            log_success(
                "问答批量生成完成",
                project_id=str(project_id),
                model_id=str(model.id),
                chunk_count=len(chunks),
                created_questions=created_count,
                skipped_chunks=skipped_count
            )
        except Exception as e:
            # Top-level boundary of a fire-and-forget task: there is no
            # caller to propagate to, so record the failure instead.
            log_failure(
                "问答批量生成失败",
                project_id=str(project_id),
                model_id=str(request.model_id),
                error=str(e)
            )
2026-03-17 14:36:31 +08:00
# Strong references to in-flight generation tasks. The event loop keeps only
# weak references to tasks, so a task whose handle is dropped may be
# garbage-collected before it finishes.
_background_tasks: set = set()


@router.post("/generate", response_model=ApiResponse)
async def generate_questions(
    project_id: UUID,
    request: GenerateRequest,
    db: AsyncSession = Depends(get_db)
):
    """Validate a generation request and start it as a background task.

    Checks that the selected model exists, resolves to a chat/vlm type and
    has an API key, and that at least one requested chunk belongs to the
    project. Generation itself runs in process_generate_async(); this
    endpoint returns immediately.

    Args:
        project_id: Project that owns the chunks.
        request: Generation settings.
        db: Database session.

    Returns:
        ApiResponse whose data holds the number of valid chunks and the
        status "processing".

    Raises:
        ValidationException: If the model is missing, has the wrong type or
            no API key, or if none of the chunk IDs is valid.
    """
    # `== None` is intentional: SQLAlchemy renders it as IS NULL.
    model_result = await db.execute(
        select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None)  # noqa: E711
    )
    model = model_result.scalar_one_or_none()
    if not model:
        raise ValidationException("Selected model not found", field="model_id")

    model_type = normalize_model_type(model.model_type, model.model_name)
    if model_type not in {"chat", "vlm"}:
        raise ValidationException("Selected model must be chat/vlm type", field="model_id")
    if not model.api_key:
        raise ValidationException("Selected model is missing API Key", field="model_id")

    chunk_result = await db.execute(
        select(Chunk.id).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
    )
    valid_chunk_ids = [row[0] for row in chunk_result.all()]
    if not valid_chunk_ids:
        raise ValidationException("No valid chunks found", field="chunk_ids")

    # Only chunk IDs that actually belong to this project are handed on.
    request_payload = request.model_copy(update={"chunk_ids": valid_chunk_ids})

    task = asyncio.create_task(process_generate_async(project_id, request_payload))
    # Hold the task until it completes, then drop the reference.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return ApiResponse.ok(
        data={"chunk_count": len(valid_chunk_ids), "status": "processing"},
        message="Question generation started in background"
    )
2026-03-17 14:36:31 +08:00
@router.get("", response_model=ApiResponse)
async def list_questions(
    project_id: UUID,
    chunk_id: Optional[UUID] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """Return one page of a project's questions.

    Results are requested ordered by ``created_at`` descending and can
    optionally be restricted to a single chunk.

    Args:
        project_id: Project whose questions are listed.
        chunk_id: Optional chunk filter.
        page: 1-based page number.
        page_size: Items per page (1-100).
        db: Database session.

    Returns:
        PaginatedResponse with the validated question items, paging
        parameters and total count.
    """
    criteria = {"project_id": project_id}
    if chunk_id:
        criteria["chunk_id"] = chunk_id

    offset = page_size * (page - 1)
    rows, total = await question_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters=criteria,
        order_by="created_at",
        descending=True
    )

    return PaginatedResponse.ok(
        items=[QuestionResponse.model_validate(row) for row in rows],
        page=page,
        page_size=page_size,
        total=total
    )
2026-03-17 14:36:31 +08:00
@router.put("/{question_id}", response_model=ApiResponse)
async def update_question(
    project_id: UUID,
    question_id: UUID,
    question: QuestionCreateSchema,
    db: AsyncSession = Depends(get_db)
):
    """Update a question that belongs to the given project.

    Args:
        project_id: Project the question must belong to.
        question_id: ID of the question to update.
        question: New field values.
        db: Database session.

    Returns:
        ApiResponse carrying the updated question.

    Raises:
        NotFoundException: If the question does not exist or belongs to a
            different project.
    """
    existing = await question_crud.get(db, question_id)

    # A question from another project is reported as not found.
    if not (existing and existing.project_id == project_id):
        raise NotFoundException("Question", question_id)

    saved = await question_crud.update(db, existing, question)
    payload = QuestionResponse.model_validate(saved)
    return ApiResponse.ok(
        data=payload,
        message="Question updated successfully"
    )
@router.delete("/{question_id}", response_model=ApiResponse)
async def delete_question(
    project_id: UUID,
    question_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Delete a question that belongs to the given project.

    Args:
        project_id: Project the question must belong to.
        question_id: ID of the question to delete.
        db: Database session.

    Returns:
        ApiResponse with a success message.

    Raises:
        NotFoundException: If the question does not exist or belongs to a
            different project.
    """
    target = await question_crud.get(db, question_id)

    # A question from another project is reported as not found.
    if not (target and target.project_id == project_id):
        raise NotFoundException("Question", question_id)

    await question_crud.delete(db, question_id)
    return ApiResponse.ok(message="Question deleted successfully")