Files
YG-Datasets/backend/app/api/v1/questions/__init__.py
Developer 6aa271c4f7 refactor: 前端架构重构 - 提取 CSS 和逻辑到独立模块
前端重构:
- 删除旧的大体积 Vue 组件(HomeView, FileManage, TextSplit 等)
- 删除旧的 composables(useFormatters, useModels, useProjects)
- 新增 core/, page-logic/, pages/, shared/ 模块化目录结构
- 提取 CSS 到 styles/pages/ 目录
- 添加全局样式 variables.css 和 common.css

后端 API 更新:
- chunks: 语义分割 API 增强
- files: 文件处理 API 更新
- models: 模型管理 API 更新
- questions: 问答管理 API 更新
- database: 数据库连接优化
- semantic_embedding: 语义嵌入服务优化

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 14:23:34 +08:00

403 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Questions API Router
"""
import asyncio
import json
import re
from typing import List, Optional
from uuid import UUID
import httpx
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.response import ApiResponse, PaginatedResponse
from app.core.crud import CRUDBase
from app.core.database import AsyncSessionLocal, get_db
from app.core.exceptions import NotFoundException, ValidationException
from app.core.logging import log_failure, log_success
from app.models.models import Chunk, ModelConfig, Question
from app.schemas.question import QuestionCreateSchema, QuestionResponse
# CRUD helper bound to the Question model; used by the list/update/delete
# handlers in this module.
question_crud = CRUDBase(Question)

# Router for the question endpoints. Every handler takes ``project_id``, so it
# is presumably mounted under a project-scoped prefix by the parent package.
router = APIRouter()
# Model categories understood by normalize_model_type(); anything outside this
# set is treated as "chat".
VALID_MODEL_TYPES = set(("chat", "vlm", "embedding", "rerank"))

# Default instruction placed at the top of every generation prompt when the
# caller does not provide its own ``preset_prompt``.
DEFAULT_PRESET_PROMPT = "".join(
    [
        "你是一名高质量中文问答数据构建助手。",
        "请基于给定 chunk 内容生成准确、自然、可用于训练的数据集问答对。",
        "问题必须清晰具体,答案必须直接来自内容或基于内容做合理概括,",
        "不要编造原文没有的信息,不要输出与目录、导航、页眉页脚、噪声文字相关的问题。",
    ]
)
class GenerateRequest(BaseModel):
    """Body of the POST /generate request that starts background QA generation."""

    # Chunks to generate from; at least one id is required.
    chunk_ids: List[UUID] = Field(default=..., min_length=1)
    # Model configuration to call; the handlers require a chat/vlm model.
    model_id: UUID
    # Maximum number of QA pairs kept per chunk (1-10).
    count: int = Field(default=3, ge=1, le=10)
    # When set, chunks flagged by is_dirty_chunk() are skipped.
    dirty_data_filter: bool = Field(default=True)
    # When set, the prompt asks the model to analyse the chunk before answering.
    thinking_mode: bool = Field(default=True)
    # Instruction text placed at the top of every generation prompt.
    preset_prompt: str = Field(
        default=DEFAULT_PRESET_PROMPT,
        min_length=1,
        max_length=4000,
    )
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
    """Map a stored model type onto one of VALID_MODEL_TYPES.

    A stored type that is valid and not "chat" is returned as-is. Otherwise
    (type is "chat", missing, or unknown) the lower-cased model name is
    searched for keyword substrings so that legacy records can still be
    recognised as rerank, embedding or vision models. With no keyword hit the
    result is "chat".
    """
    if model_type != "chat" and model_type in VALID_MODEL_TYPES:
        return model_type

    name = (model_name or "").strip().lower()
    # Families are tried in order and the first hit wins, so e.g.
    # "gte-rerank" resolves to rerank before the embedding hint "gte-".
    keyword_families = (
        ("rerank", ("rerank", "bce-reranker", "gte-rerank")),
        (
            "embedding",
            (
                "embedding",
                "embed",
                "text-embedding",
                "bge-",
                "bge_m3",
                "gte-",
                "m3e",
                "e5-",
                "jina-embeddings",
            ),
        ),
        # NOTE(review): "vl" is a bare substring and matches any name that
        # merely contains those two letters - confirm this is intended.
        ("vlm", ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")),
    )
    for family, keywords in keyword_families:
        if any(keyword in name for keyword in keywords):
            return family

    # Only "chat" or an invalid/missing type can reach this point.
    if model_type in VALID_MODEL_TYPES:
        return model_type
    return "chat"
def is_dirty_chunk(content: str) -> bool:
    """Decide whether a chunk is too noisy to be worth generating QA from.

    Returns True as soon as any heuristic fires:

    * fewer than 40 characters after collapsing whitespace;
    * fewer than 24 CJK / ASCII letter-or-digit characters;
    * the whole text is a bare "目录" / "contents" heading;
    * with 3+ non-blank lines, more than 70% of them are <= 18 characters;
    * more than 40% of non-blank lines contain a run of 3+ dots/ellipses or
      end in a whitespace-separated number (table-of-contents style);
    * more than 45% of the collapsed text is neither alphanumeric nor CJK.
    """
    raw = content or ""
    flat = re.sub(r"\s+", " ", raw).strip()

    if len(flat) < 40:
        return True
    if len(re.sub(r"[^\u4e00-\u9fffA-Za-z0-9]", "", flat)) < 24:
        return True
    # NOTE(review): effectively unreachable - every entry is shorter than the
    # 40-character minimum checked above.
    if flat.lower() in {"目录", "contents", "table of contents"}:
        return True

    rows = [row.strip() for row in raw.splitlines() if row.strip()]
    if rows:
        row_total = len(rows)
        short_rows = len([row for row in rows if len(row) <= 18])
        toc_rows = len([row for row in rows if re.search(r"[·•…\.]{3,}|\s\d+$", row)])
        if row_total >= 3 and short_rows / row_total > 0.7:
            return True
        if toc_rows / row_total > 0.4:
            return True

    # NOTE(review): the single spaces left in ``flat`` are counted as symbols.
    symbols = sum(1 for ch in flat if not (ch.isalnum() or "\u4e00" <= ch <= "\u9fff"))
    return symbols / max(len(flat), 1) > 0.45
def build_generation_prompt(chunk: Chunk, request: GenerateRequest) -> str:
    """Assemble the user message sent to the model for a single chunk.

    Layout: the request's preset prompt, then six numbered output rules
    (pair count, JSON-array-only output, object schema, no duplicates, empty
    array when the content is insufficient, and an analyse-first or
    answer-directly hint chosen by ``request.thinking_mode``), then the
    chunk's name (or a placeholder) and its raw content.
    """
    if request.thinking_mode:
        reasoning_hint = (
            "请先对内容做简短分析,识别核心事实、概念、关系与潜在考点,然后再生成问答。"
            "分析过程只用于提高质量,不要在最终输出中暴露你的思维链。"
        )
    else:
        reasoning_hint = "直接基于内容生成高质量问答。"

    rules = [
        f"1. 生成 {request.count} 组问答。",
        "2. 只输出 JSON 数组不要输出解释、标题、Markdown。",
        '3. 每个对象结构为 {"question":"...","answer":"...","question_type":"fact|summary|reasoning"}。',
        "4. 问题避免重复,答案避免空泛。",
        "5. 如果内容不足以生成高质量问答,请返回空数组 []。",
        f"6. {reasoning_hint}",
    ]
    chunk_label = chunk.name or "未命名分片"

    header = f"{request.preset_prompt}\n\n输出要求:\n"
    footer = f"Chunk 名称:{chunk_label}\nChunk 内容:\n{chunk.content}"
    return header + "\n".join(rules) + "\n\n" + footer
def extract_text_from_response(data: dict) -> str:
    """Pull the assistant text out of an OpenAI-style chat completion payload.

    Only the first entry of ``choices`` is inspected:

    * a string ``message.content`` is returned unchanged;
    * a list ``message.content`` (multi-part message) is reduced to the
      non-empty ``text`` values of its dict parts, joined with newlines;
    * anything else, including a missing or empty ``choices``, gives "".
    """
    choices = data.get("choices") or []
    if not choices:
        return ""

    content = (choices[0].get("message") or {}).get("content")
    if isinstance(content, list):
        texts = (part.get("text", "") for part in content if isinstance(part, dict))
        return "\n".join(filter(None, texts))
    if isinstance(content, str):
        return content
    return ""
def _clean_text(value) -> str:
    """Return *value* as a stripped string, treating None as empty.

    Plain ``str(value)`` would turn a JSON null into the literal text "None".
    """
    return "" if value is None else str(value).strip()


def _extract_qa_candidates(parsed) -> list:
    """Return the list of candidate QA objects carried by a decoded JSON value.

    Accepts a bare list, a wrapper object such as ``{"questions": [...]}``
    (the first list-valued field is used when there is no "questions" key),
    or a single ``{"question": ..., "answer": ...}`` object. Any other value
    yields an empty list.
    """
    if isinstance(parsed, list):
        return parsed
    if not isinstance(parsed, dict):
        return []
    if "question" in parsed:
        return [parsed]
    wrapped = parsed.get("questions")
    if isinstance(wrapped, list):
        return wrapped
    return next((value for value in parsed.values() if isinstance(value, list)), [])


def parse_generated_questions(raw_text: str) -> List[dict]:
    """Parse QA pairs out of raw model output.

    The text may be bare JSON, or JSON inside a ``` / ```json fence. It may
    also contain a JSON array embedded in surrounding prose. Arrays, wrapper
    objects and single QA objects are all accepted. Items that are not
    objects, or that lack a non-empty question or answer, are dropped.

    Returns:
        A list of ``{"question", "answer", "question_type"}`` dicts with
        stripped string values; ``question_type`` defaults to "fact".
        Unparseable input yields an empty list.
    """
    text = (raw_text or "").strip()
    if not text:
        return []

    fenced_match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.S | re.I)
    if fenced_match:
        text = fenced_match.group(1).strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost "[{...}]" span, e.g. when the model adds
        # prose before or after the array.
        array_match = re.search(r"(\[\s*\{.*\}\s*\])", text, flags=re.S)
        if not array_match:
            return []
        try:
            parsed = json.loads(array_match.group(1))
        except json.JSONDecodeError:
            return []

    normalized = []
    for item in _extract_qa_candidates(parsed):
        if not isinstance(item, dict):
            continue
        question = _clean_text(item.get("question"))
        answer = _clean_text(item.get("answer"))
        if not question or not answer:
            continue
        normalized.append({
            "question": question,
            "answer": answer,
            "question_type": _clean_text(item.get("question_type")) or "fact",
        })
    return normalized
def _unwrap_json_object_output(content: str) -> str:
    """Rewrite a JSON-object reply as JSON-array text when possible.

    The request uses ``response_format={"type": "json_object"}``, so a model
    may wrap the pairs as ``{"questions": [...]}`` or return a single
    ``{"question": ..., "answer": ...}`` object, even though the prompt asks
    for an array. Both forms are re-serialised as a JSON array string.

    Anything else is returned unchanged and left to the caller's parser. This
    includes text that starts with "{" but is not valid JSON on its own (for
    example JSON followed by prose). Such text used to raise JSONDecodeError
    here and abort the whole generation batch.
    """
    if not content.lstrip().startswith("{"):
        return content
    try:
        obj = json.loads(content)
    except json.JSONDecodeError:
        return content
    if isinstance(obj, dict):
        questions = obj.get("questions")
        if isinstance(questions, list):
            return json.dumps(questions, ensure_ascii=False)
        if "question" in obj:
            return json.dumps([obj], ensure_ascii=False)
    return content


async def call_generation_model(model: ModelConfig, prompt: str) -> str:
    """Send *prompt* to the configured chat model and return its text output.

    Posts an OpenAI-style chat completion request to
    ``{api_base}/chat/completions``. For the "minimax" provider it posts to
    ``{api_base}/chat/completions_v2`` and omits ``response_format``.
    JSON-object replies are normalised to JSON-array text where possible
    (see ``_unwrap_json_object_output``).

    Raises:
        httpx.HTTPStatusError: the provider answered with a non-2xx status.
        ValueError: the response contained no text content.
        httpx transport/timeout errors are not caught and propagate.
    """
    api_base = (model.api_base or "").rstrip("/")
    headers = {
        "Authorization": f"Bearer {model.api_key or ''}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model.model_name,
        "messages": [
            {
                "role": "system",
                "content": "你是问答数据构建助手。严格按 JSON 输出,不要输出额外说明。",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        "temperature": 0.4,
        "response_format": {"type": "json_object"},
    }

    if model.provider == "minimax":
        url = f"{api_base}/chat/completions_v2"
        body = {k: v for k, v in payload.items() if k != "response_format"}
    else:
        url = f"{api_base}/chat/completions"
        body = payload

    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(url, headers=headers, json=body)
        response.raise_for_status()
        data = response.json()

    content = extract_text_from_response(data)
    if not content:
        raise ValueError("Model returned empty content")
    return _unwrap_json_object_output(content)
async def process_generate_async(project_id: UUID, request: GenerateRequest):
    """Generate QA pairs for ``request.chunk_ids`` and store them as questions.

    Intended to run as a background task. It opens its own session from
    ``AsyncSessionLocal`` rather than using a request-scoped one.

    For each chunk it:

    * optionally skips dirty content (``request.dirty_data_filter``);
    * calls the model;
    * keeps up to ``request.count`` parsed pairs;
    * stages them as ``Question(source="generated")`` rows.

    A failed model call for one chunk is logged and counted, and does not
    abort the batch. All staged rows are committed once at the end. Nothing
    is raised to the caller; outcomes are reported via ``log_success`` and
    ``log_failure``.
    """
    async with AsyncSessionLocal() as db:
        try:
            model_result = await db.execute(
                select(ModelConfig).where(
                    ModelConfig.id == request.model_id,
                    ModelConfig.project_id == None,  # noqa: E711
                )
            )
            model = model_result.scalar_one_or_none()
            if not model:
                log_failure(
                    "问答批量生成失败",
                    project_id=str(project_id),
                    model_id=str(request.model_id),
                    error="Selected model not found",
                )
                return
            model_type = normalize_model_type(model.model_type, model.model_name)
            if model_type not in {"chat", "vlm"}:
                raise ValidationException("Selected model must be chat/vlm type", field="model_id")
            # Read before the commit below so the final log line does not have
            # to touch a possibly-expired ORM instance (depends on the
            # session's expire_on_commit setting - not visible here).
            model_id = str(model.id)

            chunk_result = await db.execute(
                select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
            )
            chunks = chunk_result.scalars().all()
            if not chunks:
                log_failure(
                    "问答批量生成失败",
                    project_id=str(project_id),
                    model_id=model_id,
                    error="No valid chunks found",
                )
                return

            created_count = 0
            skipped_count = 0
            failed_count = 0
            for chunk in chunks:
                if request.dirty_data_filter and is_dirty_chunk(chunk.content):
                    skipped_count += 1
                    continue
                prompt = build_generation_prompt(chunk, request)
                try:
                    raw_text = await call_generation_model(model, prompt)
                except Exception as chunk_error:
                    # One failed or timed-out model call must not discard the
                    # pairs already staged for other chunks: log and move on.
                    failed_count += 1
                    log_failure(
                        "问答分片生成失败",
                        project_id=str(project_id),
                        model_id=model_id,
                        chunk_id=str(chunk.id),
                        error=str(chunk_error),
                    )
                    continue
                qa_pairs = parse_generated_questions(raw_text)[: request.count]
                if not qa_pairs:
                    skipped_count += 1
                    continue
                db.add_all([
                    Question(
                        project_id=project_id,
                        chunk_id=chunk.id,
                        content=item["question"],
                        answer=item["answer"],
                        question_type=item["question_type"],
                        source="generated",
                    )
                    for item in qa_pairs
                ])
                created_count += len(qa_pairs)

            await db.commit()
            log_success(
                "问答批量生成完成",
                project_id=str(project_id),
                model_id=model_id,
                chunk_count=len(chunks),
                created_questions=created_count,
                skipped_chunks=skipped_count,
                failed_chunks=failed_count,
            )
        except Exception as e:
            # Background-task boundary: there is no caller to propagate to.
            log_failure(
                "问答批量生成失败",
                project_id=str(project_id),
                model_id=str(request.model_id),
                error=str(e),
            )
@router.post("/generate", response_model=ApiResponse)
async def generate_questions(
project_id: UUID,
request: GenerateRequest,
db: AsyncSession = Depends(get_db)
):
"""Generate questions from chunks using LLM in background."""
model_result = await db.execute(
select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None) # noqa: E711
)
model = model_result.scalar_one_or_none()
if not model:
raise ValidationException("Selected model not found", field="model_id")
model_type = normalize_model_type(model.model_type, model.model_name)
if model_type not in {"chat", "vlm"}:
raise ValidationException("Selected model must be chat/vlm type", field="model_id")
if not model.api_key:
raise ValidationException("Selected model is missing API Key", field="model_id")
chunk_result = await db.execute(
select(Chunk.id).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
)
valid_chunk_ids = [row[0] for row in chunk_result.all()]
if not valid_chunk_ids:
raise ValidationException("No valid chunks found", field="chunk_ids")
request_payload = request.model_copy(update={"chunk_ids": valid_chunk_ids})
asyncio.create_task(process_generate_async(project_id, request_payload))
return ApiResponse.ok(
data={"chunk_count": len(valid_chunk_ids), "status": "processing"},
message="Question generation started in background"
)
@router.get("", response_model=ApiResponse)
async def list_questions(
project_id: UUID,
chunk_id: Optional[UUID] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
db: AsyncSession = Depends(get_db)
):
"""List questions for a project"""
filters = {"project_id": project_id}
if chunk_id:
filters["chunk_id"] = chunk_id
skip = (page - 1) * page_size
questions, total = await question_crud.get_multi(
db,
skip=skip,
limit=page_size,
filters=filters,
order_by="created_at",
descending=True
)
question_responses = [QuestionResponse.model_validate(q) for q in questions]
return PaginatedResponse.ok(
items=question_responses,
page=page,
page_size=page_size,
total=total
)
@router.put("/{question_id}", response_model=ApiResponse)
async def update_question(
    project_id: UUID,
    question_id: UUID,
    question: QuestionCreateSchema,
    db: AsyncSession = Depends(get_db)
):
    """Overwrite an existing question of this project with ``question``.

    Raises:
        NotFoundException: the question does not exist or belongs to a
            different project.
    """
    existing = await question_crud.get(db, question_id)
    is_owned = bool(existing) and existing.project_id == project_id
    if not is_owned:
        raise NotFoundException("Question", question_id)

    saved = await question_crud.update(db, existing, question)
    response_body = QuestionResponse.model_validate(saved)
    return ApiResponse.ok(data=response_body, message="Question updated successfully")
@router.delete("/{question_id}", response_model=ApiResponse)
async def delete_question(
    project_id: UUID,
    question_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Remove a single question belonging to this project.

    Raises:
        NotFoundException: the question does not exist or belongs to a
            different project.
    """
    target = await question_crud.get(db, question_id)
    is_owned = bool(target) and target.project_id == project_id
    if not is_owned:
        raise NotFoundException("Question", question_id)

    await question_crud.delete(db, question_id)
    return ApiResponse.ok(message="Question deleted successfully")