前端重构: - 删除旧的大体积 Vue 组件(HomeView, FileManage, TextSplit 等) - 删除旧的 composables(useFormatters, useModels, useProjects) - 新增 core/, page-logic/, pages/, shared/ 模块化目录结构 - 提取 CSS 到 styles/pages/ 目录 - 添加全局样式 variables.css 和 common.css 后端 API 更新: - chunks: 语义分割 API 增强 - files: 文件处理 API 更新 - models: 模型管理 API 更新 - questions: 问答管理 API 更新 - database: 数据库连接优化 - semantic_embedding: 语义嵌入服务优化 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
403 lines
14 KiB
Python
403 lines
14 KiB
Python
"""
|
||
Questions API Router
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import re
|
||
from typing import List, Optional
|
||
from uuid import UUID
|
||
|
||
import httpx
|
||
from fastapi import APIRouter, Depends, Query
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.response import ApiResponse, PaginatedResponse
|
||
from app.core.crud import CRUDBase
|
||
from app.core.database import AsyncSessionLocal, get_db
|
||
from app.core.exceptions import NotFoundException, ValidationException
|
||
from app.core.logging import log_failure, log_success
|
||
from app.models.models import Chunk, ModelConfig, Question
|
||
from app.schemas.question import QuestionCreateSchema, QuestionResponse
|
||
|
||
# Generic CRUD helper bound to the Question ORM model; the list/update/delete
# endpoints below go through it.
question_crud: CRUDBase = CRUDBase(Question)

# Router that all question endpoints in this module are registered on.
router: APIRouter = APIRouter()
|
||
|
||
# Model categories understood by ``normalize_model_type``; any other stored
# value falls back to "chat" there.
VALID_MODEL_TYPES = set(("chat", "vlm", "embedding", "rerank"))

# Default instruction placed at the top of every generation prompt when the
# request does not override ``preset_prompt``.
DEFAULT_PRESET_PROMPT = "".join([
    "你是一名高质量中文问答数据构建助手。",
    "请基于给定 chunk 内容生成准确、自然、可用于训练的数据集问答对。",
    "问题必须清晰具体,答案必须直接来自内容或基于内容做合理概括,",
    "不要编造原文没有的信息,不要输出与目录、导航、页眉页脚、噪声文字相关的问题。",
])
|
||
|
||
|
||
class GenerateRequest(BaseModel):
    """Request body for generating QA pairs from a set of chunks."""

    # Chunks to generate from; at least one id is required.
    chunk_ids: List[UUID] = Field(default=..., min_length=1)
    # Model config to generate with (the endpoint requires a global chat/vlm model).
    model_id: UUID
    # Number of QA pairs requested, and kept at most, per chunk (1-10).
    count: int = Field(default=3, ge=1, le=10)
    # When true, chunks flagged by is_dirty_chunk() are skipped.
    dirty_data_filter: bool = True
    # When true, the prompt asks the model to analyse the chunk before answering.
    thinking_mode: bool = True
    # Instruction text placed at the top of each generation prompt.
    preset_prompt: str = Field(
        default=DEFAULT_PRESET_PROMPT,
        min_length=1,
        max_length=4000,
    )
|
||
|
||
|
||
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
    """Resolve the effective type of a model configuration.

    A stored type that is in ``VALID_MODEL_TYPES`` and is not "chat" is trusted
    as-is. Otherwise (stored type is "chat", missing, or unknown) the model name
    is scanned for well-known keywords, as a fallback for legacy records whose
    stored type may not reflect the real model kind.

    Args:
        model_type: Stored type string; may be ``None`` or unrecognised.
        model_name: Provider model identifier; may be ``None``.

    Returns:
        One of "chat", "vlm", "embedding" or "rerank".
    """
    if model_type in VALID_MODEL_TYPES and model_type != "chat":
        return model_type

    name = (model_name or "").strip().lower()

    # Checked in order and the first hit wins, so e.g. "gte-rerank" resolves to
    # "rerank" before the broader embedding keyword "gte-" is considered.
    keyword_table = (
        ("rerank", ("rerank", "bce-reranker", "gte-rerank")),
        ("embedding", (
            "embedding", "embed", "text-embedding", "bge-", "bge_m3",
            "gte-", "m3e", "e5-", "jina-embeddings",
        )),
        ("vlm", ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")),
    )
    for resolved_type, keywords in keyword_table:
        if any(keyword in name for keyword in keywords):
            return resolved_type

    if model_type in VALID_MODEL_TYPES:
        return model_type
    return "chat"
|
||
|
||
|
||
def is_dirty_chunk(content: str) -> bool:
    """Heuristically decide whether a chunk is too noisy to generate QA pairs from.

    Returns ``True`` as soon as one of these checks fires, in this order:

    * fewer than 40 characters once whitespace runs are collapsed to one space;
    * fewer than 24 CJK ideographs / ASCII letters / ASCII digits;
    * the text is only a "目录" / "contents" / "table of contents" heading
      (in practice already caught by the length check);
    * at least 3 non-empty lines, more than 70% of them 18 characters or shorter;
    * more than 40% of non-empty lines contain a run of 3+ of ``·•….`` or end
      with whitespace followed by digits (table-of-contents style);
    * more than 45% of the collapsed text is neither alphanumeric nor a CJK
      ideograph (spaces count towards this share).

    ``None`` is treated as an empty string and is therefore flagged.
    """
    text = content or ""
    flattened = re.sub(r"\s+", " ", text).strip()
    total_chars = len(flattened)

    if total_chars < 40:
        return True

    meaningful = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9]", "", flattened)
    if len(meaningful) < 24:
        return True

    if flattened.lower() in {"目录", "contents", "table of contents"}:
        return True

    non_empty = [stripped for stripped in map(str.strip, text.splitlines()) if stripped]
    if non_empty:
        line_total = len(non_empty)
        short_count = len([ln for ln in non_empty if len(ln) <= 18])
        dotted_count = len([ln for ln in non_empty if re.search(r"[·•…\.]{3,}|\s\d+$", ln)])
        # Mostly very short lines: menus, lists, stray headings.
        if line_total >= 3 and short_count / line_total > 0.7:
            return True
        # Dot leaders or trailing page numbers: looks like a table of contents.
        if dotted_count / line_total > 0.4:
            return True

    symbol_count = sum(
        not ch.isalnum() and not ("\u4e00" <= ch <= "\u9fff") for ch in flattened
    )
    return symbol_count / max(total_chars, 1) > 0.45
|
||
|
||
|
||
def build_generation_prompt(chunk: Chunk, request: GenerateRequest) -> str:
    """Build the user prompt asking the model for QA pairs about one chunk.

    Layout: ``request.preset_prompt``, then numbered output rules, then the
    chunk name (or a placeholder) and the chunk content.

    The model is asked for a JSON *object* shaped ``{"questions": [...]}``
    rather than a bare array. ``call_generation_model`` sends
    ``response_format={"type": "json_object"}`` for non-MiniMax providers, and
    that mode only allows a top-level object. Without a named key, the model
    would wrap the array under an arbitrary key of its own choosing.
    ``call_generation_model`` unwraps the ``questions`` key.

    Args:
        chunk: Source chunk; ``name`` and ``content`` may be empty/``None``.
        request: Generation options (``preset_prompt``, ``count``,
            ``thinking_mode``).

    Returns:
        The complete prompt text.
    """
    if request.thinking_mode:
        thinking_instruction = (
            "请先对内容做简短分析,识别核心事实、概念、关系与潜在考点,然后再生成问答。"
            "分析过程只用于提高质量,不要在最终输出中暴露你的思维链。"
        )
    else:
        thinking_instruction = "直接基于内容生成高质量问答。"

    return (
        f"{request.preset_prompt}\n\n"
        "输出要求:\n"
        f"1. 生成 {request.count} 组问答。\n"
        '2. 只输出一个 JSON 对象,格式为 {"questions": [...]},不要输出解释、标题、Markdown。\n'
        '3. questions 数组中每个对象结构为 {"question":"...","answer":"...","question_type":"fact|summary|reasoning"}。\n'
        "4. 问题避免重复,答案避免空泛。\n"
        '5. 如果内容不足以生成高质量问答,请返回 {"questions": []}。\n'
        f"6. {thinking_instruction}\n\n"
        f"Chunk 名称:{chunk.name or '未命名分片'}\n"
        # Avoid rendering a missing body as the literal text "None".
        f"Chunk 内容:\n{chunk.content or ''}"
    )
|
||
|
||
|
||
def extract_text_from_response(data: dict) -> str:
    """Pull the assistant text out of an OpenAI-style chat completion payload.

    Reads ``choices[0].message.content``. A string is returned unchanged. A list
    of content parts is reduced to the non-empty ``text`` values of its dict
    entries, joined by newlines. Any other shape (no choices, no message,
    ``None`` content) yields an empty string.
    """
    choices = data.get("choices") or []
    if not choices:
        return ""

    content = (choices[0].get("message") or {}).get("content")
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""

    texts = [block.get("text", "") for block in content if isinstance(block, dict)]
    return "\n".join(filter(None, texts))
|
||
|
||
|
||
def _question_payload_candidates(text: str) -> List[str]:
    """Return substrings of *text* to try as JSON, in order of preference."""
    # The whole reply first, so clean JSON is never mangled by the regexes below
    # (e.g. an answer that itself contains a Markdown code fence).
    candidates = [text]
    fenced_match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.S | re.I)
    if fenced_match:
        candidates.append(fenced_match.group(1).strip())
    array_match = re.search(r"(\[\s*\{.*\}\s*\])", text, flags=re.S)
    if array_match:
        candidates.append(array_match.group(1))
    return candidates


def _question_items_from_payload(payload) -> Optional[list]:
    """Return the list of QA objects inside a decoded JSON payload, or None."""
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        # A single QA object, e.g. what JSON-object mode tends to yield for count=1.
        if "question" in payload:
            return [payload]
        # A wrapper object such as {"questions": [...]}; prefer that key.
        preferred = payload.get("questions")
        if isinstance(preferred, list):
            return preferred
        for value in payload.values():
            if isinstance(value, list):
                return value
    return None


def _clean_field(value) -> str:
    """Stringify and strip a JSON field, treating null as an empty string."""
    return "" if value is None else str(value).strip()


def parse_generated_questions(raw_text: str) -> List[dict]:
    """Parse model output into a list of normalized QA dicts.

    Accepts a JSON array of QA objects, a wrapper object holding such an array
    (``{"questions": [...]}`` or any other list-valued key), or a single QA
    object. The payload may be bare, inside a Markdown code fence (with or
    without a ``json`` tag), or embedded in surrounding prose.

    Entries that are not objects, or whose question/answer is missing, null, or
    blank, are dropped. ``question_type`` defaults to ``"fact"`` when missing,
    null, or blank.

    Args:
        raw_text: Raw model reply; may be empty or ``None``.

    Returns:
        A list of ``{"question", "answer", "question_type"}`` dicts. The list is
        empty when nothing usable could be parsed.
    """
    text = (raw_text or "").strip()
    if not text:
        return []

    items = None
    for candidate in _question_payload_candidates(text):
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        items = _question_items_from_payload(payload)
        if items is not None:
            break
    if items is None:
        return []

    normalized = []
    for item in items:
        if not isinstance(item, dict):
            continue
        question = _clean_field(item.get("question"))
        answer = _clean_field(item.get("answer"))
        question_type = _clean_field(item.get("question_type")) or "fact"
        if not question or not answer:
            continue
        normalized.append({
            "question": question,
            "answer": answer,
            "question_type": question_type,
        })
    return normalized
|
||
|
||
|
||
def _unwrap_json_object_reply(content: str) -> str:
    """Turn a JSON-object model reply into a JSON-array string when possible.

    A ``questions`` list (or else the first list-valued entry) is unwrapped. A
    single QA object (has a ``question`` key) is wrapped in a one-element array.
    Content that does not start with ``{``, does not decode, or has no usable
    list is returned unchanged, so the tolerant downstream parser can deal with
    it instead of the whole batch failing.
    """
    if not content.lstrip().startswith("{"):
        return content
    try:
        obj = json.loads(content)
    except json.JSONDecodeError:
        # E.g. an object followed by prose, or a truncated reply.
        return content
    if not isinstance(obj, dict):
        return content
    if isinstance(obj.get("questions"), list):
        items = obj["questions"]
    elif "question" in obj:
        items = [obj]
    else:
        items = next((value for value in obj.values() if isinstance(value, list)), None)
        if items is None:
            return content
    return json.dumps(items, ensure_ascii=False)


async def call_generation_model(model: ModelConfig, prompt: str) -> str:
    """Send *prompt* to the configured chat model and return its text reply.

    Posts an OpenAI-style ``messages`` payload (fixed system message plus the
    prompt, temperature 0.4) with a Bearer token. The target depends on the
    provider:

    * ``minimax``: ``{api_base}/chat/completions_v2``, without ``response_format``.
    * any other provider: ``{api_base}/chat/completions``, with
      ``response_format={"type": "json_object"}``.

    Because JSON-object mode forces a top-level object, an object reply is
    normalised back into a JSON array string (see ``_unwrap_json_object_reply``).

    Args:
        model: Model configuration (provider, api_base, api_key, model_name).
        prompt: User prompt, typically from ``build_generation_prompt``.

    Returns:
        The reply text, ideally a JSON array of QA objects.

    Raises:
        httpx.HTTPStatusError: The provider answered with a non-2xx status.
        httpx.HTTPError: Transport failure or the 120 s timeout elapsed.
        ValueError: The reply contained no text.
    """
    api_base = (model.api_base or "").rstrip("/")
    headers = {
        "Authorization": f"Bearer {model.api_key or ''}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model.model_name,
        "messages": [
            {
                "role": "system",
                "content": "你是问答数据构建助手。严格按 JSON 输出,不要输出额外说明。",
            },
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.4,
        "response_format": {"type": "json_object"},
    }

    if model.provider == "minimax":
        # NOTE(review): MiniMax documents its native v2 endpoint as
        # `/text/chatcompletion_v2`; confirm this path against the api_base
        # values actually configured for MiniMax models.
        url = f"{api_base}/chat/completions_v2"
        body = {k: v for k, v in payload.items() if k != "response_format"}
    else:
        url = f"{api_base}/chat/completions"
        body = payload

    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(url, headers=headers, json=body)
        response.raise_for_status()
        data = response.json()

    content = extract_text_from_response(data)
    if not content:
        raise ValueError("Model returned empty content")
    return _unwrap_json_object_reply(content)
|
||
|
||
|
||
async def process_generate_async(project_id: UUID, request: GenerateRequest) -> None:
    """Generate and store QA pairs for the requested chunks (background job).

    Opens its own session from ``AsyncSessionLocal``. It re-loads the global
    model config (``project_id`` IS NULL), re-checks that the model is chat/vlm,
    and loads the requested chunks that belong to *project_id*. Then, per chunk:

    * if ``request.dirty_data_filter`` is set and ``is_dirty_chunk`` flags the
      content, the chunk is skipped;
    * otherwise the model is asked for QA pairs. At most ``request.count`` of
      them are staged as ``Question`` rows with ``source="generated"``. A chunk
      that yields no usable pair counts as skipped.

    A failing model call for one chunk (HTTP error, timeout, empty reply) is
    logged and counted, and processing continues. Previously such an error
    aborted the loop, and the questions staged for earlier chunks were never
    committed. All staged rows are committed once at the end.

    Nothing is raised to the caller. Outcomes are reported via ``log_success``
    and ``log_failure``.
    """
    async with AsyncSessionLocal() as db:
        try:
            model_result = await db.execute(
                select(ModelConfig).where(
                    ModelConfig.id == request.model_id,
                    ModelConfig.project_id.is_(None),
                )
            )
            model = model_result.scalar_one_or_none()
            if not model:
                # Validated by the endpoint, but it may have been removed since;
                # report it rather than exiting silently.
                log_failure(
                    "问答批量生成失败",
                    project_id=str(project_id),
                    model_id=str(request.model_id),
                    error="Selected model not found",
                )
                return

            model_type = normalize_model_type(model.model_type, model.model_name)
            if model_type not in {"chat", "vlm"}:
                raise ValidationException("Selected model must be chat/vlm type", field="model_id")

            chunk_result = await db.execute(
                select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
            )
            chunks = chunk_result.scalars().all()
            if not chunks:
                log_failure(
                    "问答批量生成失败",
                    project_id=str(project_id),
                    model_id=str(model.id),
                    error="No valid chunks found",
                )
                return

            created_count = 0
            skipped_count = 0
            failed_count = 0
            for chunk in chunks:
                if request.dirty_data_filter and is_dirty_chunk(chunk.content):
                    skipped_count += 1
                    continue

                try:
                    raw_text = await call_generation_model(
                        model, build_generation_prompt(chunk, request)
                    )
                except Exception as chunk_error:
                    # Isolate per-chunk failures so the rest of the batch survives.
                    failed_count += 1
                    log_failure(
                        "问答分片生成失败",
                        project_id=str(project_id),
                        model_id=str(model.id),
                        chunk_id=str(chunk.id),
                        error=f"{type(chunk_error).__name__}: {chunk_error}",
                    )
                    continue

                qa_pairs = parse_generated_questions(raw_text)[: request.count]
                if not qa_pairs:
                    skipped_count += 1
                    continue

                for item in qa_pairs:
                    db.add(Question(
                        project_id=project_id,
                        chunk_id=chunk.id,
                        content=item["question"],
                        answer=item["answer"],
                        question_type=item["question_type"],
                        source="generated",
                    ))
                created_count += len(qa_pairs)

            await db.commit()

            log_success(
                "问答批量生成完成",
                project_id=str(project_id),
                model_id=str(model.id),
                chunk_count=len(chunks),
                created_questions=created_count,
                skipped_chunks=skipped_count,
                failed_chunks=failed_count,
            )
        except Exception as e:
            # Background task: there is no caller to propagate to, so log it.
            # The type name is included because some errors (e.g. timeouts)
            # stringify to an empty message.
            log_failure(
                "问答批量生成失败",
                project_id=str(project_id),
                model_id=str(request.model_id),
                error=f"{type(e).__name__}: {e}",
            )
|
||
|
||
|
||
# Strong references to in-flight generation tasks. The event loop keeps only
# weak references to tasks, so a task nobody references can be garbage-collected
# before it finishes. Each task removes itself from this set once it is done.
_background_generation_tasks: set[asyncio.Task] = set()


@router.post("/generate", response_model=ApiResponse)
async def generate_questions(
    project_id: UUID,
    request: GenerateRequest,
    db: AsyncSession = Depends(get_db),
):
    """Validate a generation request and start QA generation in the background.

    Checks the following before scheduling anything:

    * the model exists as a global config (``project_id`` IS NULL);
    * it resolves to a chat/vlm type;
    * it has an API key;
    * at least one requested chunk belongs to *project_id*.

    Then ``process_generate_async`` is scheduled with only those valid chunk
    ids, and the endpoint returns immediately.

    Returns:
        ``ApiResponse`` with ``{"chunk_count": <valid chunks>, "status": "processing"}``.

    Raises:
        ValidationException: Model missing, wrong type, or without API key; or
            no valid chunks.
    """
    model_result = await db.execute(
        select(ModelConfig).where(
            ModelConfig.id == request.model_id,
            ModelConfig.project_id.is_(None),
        )
    )
    model = model_result.scalar_one_or_none()
    if not model:
        raise ValidationException("Selected model not found", field="model_id")

    model_type = normalize_model_type(model.model_type, model.model_name)
    if model_type not in {"chat", "vlm"}:
        raise ValidationException("Selected model must be chat/vlm type", field="model_id")
    if not model.api_key:
        raise ValidationException("Selected model is missing API Key", field="model_id")

    chunk_result = await db.execute(
        select(Chunk.id).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
    )
    valid_chunk_ids = [row[0] for row in chunk_result.all()]
    if not valid_chunk_ids:
        raise ValidationException("No valid chunks found", field="chunk_ids")

    request_payload = request.model_copy(update={"chunk_ids": valid_chunk_ids})
    task = asyncio.create_task(process_generate_async(project_id, request_payload))
    _background_generation_tasks.add(task)
    task.add_done_callback(_background_generation_tasks.discard)

    return ApiResponse.ok(
        data={"chunk_count": len(valid_chunk_ids), "status": "processing"},
        message="Question generation started in background",
    )
|
||
|
||
|
||
@router.get("", response_model=ApiResponse)
async def list_questions(
    project_id: UUID,
    chunk_id: Optional[UUID] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of a project's questions, newest first.

    Results are ordered by ``created_at`` descending. When ``chunk_id`` is
    given, only questions linked to that chunk are returned. ``page`` is
    1-based.
    """
    filters = {"project_id": project_id, **({"chunk_id": chunk_id} if chunk_id else {})}
    offset = page_size * (page - 1)

    rows, total = await question_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters=filters,
        order_by="created_at",
        descending=True,
    )

    return PaginatedResponse.ok(
        items=list(map(QuestionResponse.model_validate, rows)),
        page=page,
        page_size=page_size,
        total=total,
    )
|
||
|
||
|
||
@router.put("/{question_id}", response_model=ApiResponse)
async def update_question(
    project_id: UUID,
    question_id: UUID,
    question: QuestionCreateSchema,
    db: AsyncSession = Depends(get_db),
):
    """Replace a question's fields with those in the request body.

    A question that does not exist, or that belongs to another project, is
    reported as not found.

    NOTE(review): the body is a ``QuestionCreateSchema``. If that schema
    carries ``project_id``/``chunk_id``, this update could move the question to
    another project or chunk. Confirm against the schema definition.

    Raises:
        NotFoundException: Question missing or not part of *project_id*.
    """
    stored = await question_crud.get(db, question_id)
    if not stored or stored.project_id != project_id:
        raise NotFoundException("Question", question_id)

    saved = await question_crud.update(db, stored, question)
    body = QuestionResponse.model_validate(saved)
    return ApiResponse.ok(data=body, message="Question updated successfully")
|
||
|
||
|
||
@router.delete("/{question_id}", response_model=ApiResponse)
async def delete_question(
    project_id: UUID,
    question_id: UUID,
    db: AsyncSession = Depends(get_db),
):
    """Delete a question that belongs to *project_id*.

    Raises:
        NotFoundException: Question missing or owned by a different project.
    """
    target = await question_crud.get(db, question_id)
    owned_by_project = bool(target) and target.project_id == project_id
    if not owned_by_project:
        raise NotFoundException("Question", question_id)

    await question_crud.delete(db, question_id)
    return ApiResponse.ok(message="Question deleted successfully")
|