refactor: 前端架构重构 - 提取 CSS 和逻辑到独立模块
前端重构: - 删除旧的大体积 Vue 组件(HomeView, FileManage, TextSplit 等) - 删除旧的 composables(useFormatters, useModels, useProjects) - 新增 core/, page-logic/, pages/, shared/ 模块化目录结构 - 提取 CSS 到 styles/pages/ 目录 - 添加全局样式 variables.css 和 common.css 后端 API 更新: - chunks: 语义分割 API 增强 - files: 文件处理 API 更新 - models: 模型管理 API 更新 - questions: 问答管理 API 更新 - database: 数据库连接优化 - semantic_embedding: 语义嵌入服务优化 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.api.response import ApiResponse, PaginatedResponse
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, AsyncSessionLocal
|
||||
from app.core.exceptions import NotFoundException
|
||||
from app.core.crud import CRUDBase
|
||||
from app.core.logging import log_success, log_failure
|
||||
@@ -80,6 +80,106 @@ async def process_file_by_type(file: File) -> str:
|
||||
return content
|
||||
|
||||
|
||||
async def process_split_async(
|
||||
project_id: UUID,
|
||||
request: SplitRequest,
|
||||
):
|
||||
"""Run chunk splitting in background."""
|
||||
async with AsyncSessionLocal() as db:
|
||||
file = None
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(File).where(File.id == request.file_id, File.project_id == project_id)
|
||||
)
|
||||
file = result.scalar_one_or_none()
|
||||
if not file:
|
||||
return
|
||||
|
||||
text = await process_file_by_type(file)
|
||||
|
||||
kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
|
||||
if request.method == "custom" and request.separator:
|
||||
kwargs["separator"] = request.separator
|
||||
|
||||
if request.method == "semantic_embedding":
|
||||
kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
|
||||
kwargs["embedding_api_key"] = request.embedding_api_key
|
||||
kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
|
||||
kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
|
||||
kwargs["similarity_threshold"] = request.similarity_threshold
|
||||
kwargs["min_chunk_size"] = request.min_chunk_size
|
||||
|
||||
splitter = get_splitter(request.method, **kwargs)
|
||||
split_results = splitter.split(text)
|
||||
|
||||
await db.execute(
|
||||
Chunk.__table__.delete().where(
|
||||
Chunk.project_id == project_id,
|
||||
Chunk.file_id == file.id
|
||||
)
|
||||
)
|
||||
|
||||
chunks = []
|
||||
for chunk_data in split_results:
|
||||
db_chunk = Chunk(
|
||||
project_id=project_id,
|
||||
file_id=file.id,
|
||||
name=chunk_data.get("name", f"Chunk {chunk_data['index'] + 1}"),
|
||||
content=chunk_data["content"],
|
||||
word_count=chunk_data.get("word_count", len(chunk_data["content"].split()))
|
||||
)
|
||||
db.add(db_chunk)
|
||||
chunks.append(db_chunk)
|
||||
|
||||
await db.commit()
|
||||
|
||||
ready_dir = get_project_ready_dir(str(project_id))
|
||||
|
||||
# 删除旧的 markdown 文件(可能有两种命名格式)
|
||||
old_md_files = list(ready_dir.glob(f"{file.id}*.md"))
|
||||
for old_file in old_md_files:
|
||||
try:
|
||||
old_file.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
md_filename = f"{file.id}.md"
|
||||
md_path = ready_dir / md_filename
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: md_path.write_text(text, encoding='utf-8')
|
||||
)
|
||||
|
||||
file.file_path = str(md_path)
|
||||
file.status = "completed"
|
||||
await db.commit()
|
||||
|
||||
log_success(
|
||||
"文件分割完成",
|
||||
project_id=str(project_id),
|
||||
file_id=str(file.id),
|
||||
filename=file.filename,
|
||||
method=request.method,
|
||||
chunk_count=len(chunks),
|
||||
text_length=len(text),
|
||||
ready_path=str(md_path)
|
||||
)
|
||||
except Exception as e:
|
||||
if file:
|
||||
file.status = "failed"
|
||||
await db.commit()
|
||||
|
||||
log_failure(
|
||||
"文件分割失败",
|
||||
project_id=str(project_id),
|
||||
file_id=str(request.file_id),
|
||||
method=request.method,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/split", response_model=ApiResponse)
|
||||
async def split_text(
|
||||
project_id: UUID,
|
||||
@@ -88,7 +188,6 @@ async def split_text(
|
||||
):
|
||||
"""Split text into chunks"""
|
||||
try:
|
||||
# Get file
|
||||
result = await db.execute(
|
||||
select(File).where(File.id == request.file_id, File.project_id == project_id)
|
||||
)
|
||||
@@ -107,81 +206,27 @@ async def split_text(
|
||||
overlap=request.overlap
|
||||
)
|
||||
|
||||
# Process file
|
||||
text = await process_file_by_type(file)
|
||||
|
||||
# Update file status
|
||||
file.status = "processing"
|
||||
await db.commit()
|
||||
|
||||
# Split text
|
||||
kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
|
||||
if request.method == "custom" and request.separator:
|
||||
kwargs["separator"] = request.separator
|
||||
|
||||
# 如果使用 semantic_embedding 方法,传递 embedding 参数
|
||||
if request.method == "semantic_embedding":
|
||||
kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
|
||||
kwargs["embedding_api_key"] = request.embedding_api_key
|
||||
kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
|
||||
kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
|
||||
kwargs["similarity_threshold"] = request.similarity_threshold
|
||||
kwargs["min_chunk_size"] = request.min_chunk_size
|
||||
|
||||
splitter = get_splitter(request.method, **kwargs)
|
||||
split_results = splitter.split(text)
|
||||
|
||||
# Save chunks
|
||||
chunks = []
|
||||
for chunk_data in split_results:
|
||||
db_chunk = Chunk(
|
||||
asyncio.create_task(
|
||||
process_split_async(
|
||||
project_id=project_id,
|
||||
file_id=file.id,
|
||||
name=chunk_data.get("name", f"Chunk {chunk_data['index'] + 1}"),
|
||||
content=chunk_data["content"],
|
||||
word_count=chunk_data.get("word_count", len(chunk_data["content"].split()))
|
||||
request=request,
|
||||
)
|
||||
db.add(db_chunk)
|
||||
chunks.append(db_chunk)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Save processed markdown to ready directory
|
||||
ready_dir = get_project_ready_dir(str(project_id))
|
||||
md_filename = f"{file.id}_{file.filename}.md"
|
||||
md_path = ready_dir / md_filename
|
||||
|
||||
# Write markdown content to file
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: md_path.write_text(text, encoding='utf-8')
|
||||
)
|
||||
|
||||
# Update file path to ready location
|
||||
file.file_path = str(md_path)
|
||||
file.status = "completed"
|
||||
await db.commit()
|
||||
|
||||
# 记录成功日志
|
||||
log_success(
|
||||
"文件处理完成",
|
||||
project_id=str(project_id),
|
||||
file_id=str(file.id),
|
||||
filename=file.filename,
|
||||
chunk_count=len(chunks),
|
||||
text_length=len(text),
|
||||
ready_path=str(md_path)
|
||||
)
|
||||
|
||||
return ApiResponse.ok(
|
||||
data={"chunks": len(chunks)},
|
||||
message=f"Successfully split into {len(chunks)} chunks"
|
||||
data={"file_id": str(file.id), "status": file.status},
|
||||
message="Split task started, processing in background"
|
||||
)
|
||||
except Exception as e:
|
||||
# 记录失败日志
|
||||
if 'file' in locals() and file:
|
||||
file.status = "failed"
|
||||
await db.commit()
|
||||
|
||||
log_failure(
|
||||
"文件处理失败",
|
||||
"分割任务启动失败",
|
||||
project_id=str(project_id),
|
||||
file_id=str(request.file_id),
|
||||
error=str(e)
|
||||
|
||||
@@ -9,6 +9,7 @@ from uuid import UUID, uuid4
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, Query
|
||||
from fastapi.responses import FileResponse, PlainTextResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.api.response import ApiResponse, PaginatedResponse
|
||||
from app.core.config import get_settings
|
||||
@@ -17,6 +18,7 @@ from app.core.exceptions import ValidationException, NotFoundException
|
||||
from app.core.crud import CRUDBase
|
||||
from app.core.logging import log_success, log_failure
|
||||
from app.models.models import File as FileModel
|
||||
from app.models.models import Chunk, Question
|
||||
from app.schemas.file import FileResponse, FileCreateSchema
|
||||
from markitdown import MarkItDown
|
||||
|
||||
@@ -329,11 +331,27 @@ async def delete_file(
|
||||
file_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Delete file"""
|
||||
"""Delete file and all related data (markdown, chunks, questions)"""
|
||||
file = await file_crud.get(db, file_id)
|
||||
if not file or file.project_id != project_id:
|
||||
raise NotFoundException("File", file_id)
|
||||
|
||||
# Delete related chunks and their questions (explicit deletion for safety)
|
||||
chunks_result = await db.execute(
|
||||
select(Chunk).where(Chunk.file_id == file_id)
|
||||
)
|
||||
chunks = chunks_result.scalars().all()
|
||||
for chunk in chunks:
|
||||
# Delete questions related to this chunk
|
||||
questions_result = await db.execute(
|
||||
select(Question).where(Question.chunk_id == chunk.id)
|
||||
)
|
||||
questions = questions_result.scalars().all()
|
||||
for question in questions:
|
||||
await db.delete(question)
|
||||
# Delete chunk
|
||||
await db.delete(chunk)
|
||||
|
||||
# Delete file from raw directory
|
||||
if file.file_path and os.path.exists(file.file_path):
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
@@ -342,16 +360,27 @@ async def delete_file(
|
||||
file.file_path
|
||||
)
|
||||
|
||||
# Delete file from ready directory (processed markdown)
|
||||
ready_path = Path("/data/code/YG-Datasets/data") / str(project_id) / "ready" / f"{file_id}.md"
|
||||
if ready_path.exists():
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
os.remove,
|
||||
str(ready_path)
|
||||
)
|
||||
# Delete file from ready directory (processed markdown) - try both naming conventions
|
||||
ready_dir = Path("/data/code/YG-Datasets/data") / str(project_id) / "ready"
|
||||
if ready_dir.exists():
|
||||
# Try file_id.md (from upload process)
|
||||
ready_path = ready_dir / f"{file_id}.md"
|
||||
if ready_path.exists():
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
os.remove,
|
||||
str(ready_path)
|
||||
)
|
||||
# Try file_id_filename.md (from split process)
|
||||
for md_file in ready_dir.glob(f"{file_id}_*.md"):
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
os.remove,
|
||||
str(md_file)
|
||||
)
|
||||
|
||||
await file_crud.delete(db, file_id)
|
||||
await db.commit()
|
||||
return ApiResponse.ok(message="File deleted successfully")
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,38 @@ from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
VALID_MODEL_TYPES = {"chat", "vlm", "embedding", "rerank"}
|
||||
|
||||
|
||||
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
|
||||
"""Normalize model type, with keyword fallback for legacy records."""
|
||||
if model_type in VALID_MODEL_TYPES and model_type != "chat":
|
||||
return model_type
|
||||
|
||||
normalized_name = (model_name or "").strip().lower()
|
||||
|
||||
rerank_keywords = ("rerank", "bce-reranker", "gte-rerank")
|
||||
embedding_keywords = (
|
||||
"embedding",
|
||||
"embed",
|
||||
"text-embedding",
|
||||
"bge-",
|
||||
"bge_m3",
|
||||
"gte-",
|
||||
"m3e",
|
||||
"e5-",
|
||||
"jina-embeddings",
|
||||
)
|
||||
vlm_keywords = ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")
|
||||
|
||||
if any(keyword in normalized_name for keyword in rerank_keywords):
|
||||
return "rerank"
|
||||
if any(keyword in normalized_name for keyword in embedding_keywords):
|
||||
return "embedding"
|
||||
if any(keyword in normalized_name for keyword in vlm_keywords):
|
||||
return "vlm"
|
||||
return model_type if model_type in VALID_MODEL_TYPES else "chat"
|
||||
|
||||
|
||||
async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"""Test model connection by calling the API"""
|
||||
@@ -23,6 +55,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
api_base = model.api_base or ""
|
||||
provider = model.provider
|
||||
model_name = model.model_name
|
||||
model_type = normalize_model_type(model.model_type, model_name)
|
||||
api_key = model.api_key
|
||||
|
||||
headers = {
|
||||
@@ -32,7 +65,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
if provider == "openai":
|
||||
if model_type in {"chat", "vlm"} and provider in {"openai", "ali"}:
|
||||
# OpenAI compatible API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions",
|
||||
@@ -43,7 +76,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"max_tokens": 5
|
||||
}
|
||||
)
|
||||
elif provider == "minimax":
|
||||
elif model_type in {"chat", "vlm"} and provider == "minimax":
|
||||
# MiniMax API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions_v2",
|
||||
@@ -56,7 +89,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}
|
||||
)
|
||||
elif provider == "glm":
|
||||
elif model_type in {"chat", "vlm"} and provider == "glm":
|
||||
# GLM API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions",
|
||||
@@ -66,8 +99,21 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}
|
||||
)
|
||||
elif model_type == "embedding" and provider in {"openai", "ali", "glm"}:
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/embeddings",
|
||||
headers=headers,
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": "test"
|
||||
}
|
||||
)
|
||||
elif model_type == "embedding" and provider == "minimax":
|
||||
return {"success": False, "message": "MiniMax embedding 自动测试暂未接入,请手动确认端点与模型"}
|
||||
elif model_type == "rerank":
|
||||
return {"success": False, "message": "Rerank 自动测试暂未接入,请先保存配置并在实际流程中验证"}
|
||||
else:
|
||||
return {"success": False, "message": f"Unsupported provider: {provider}"}
|
||||
return {"success": False, "message": f"Unsupported provider/type: {provider}/{model_type}"}
|
||||
|
||||
if response.status_code == 200:
|
||||
return {"success": True, "message": "Connection successful"}
|
||||
@@ -114,6 +160,7 @@ async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)):
|
||||
|
||||
db_model = ModelConfig(
|
||||
provider=model.provider,
|
||||
model_type=model.model_type,
|
||||
model_name=model.model_name,
|
||||
api_key=model.api_key,
|
||||
api_base=model.api_base,
|
||||
@@ -248,6 +295,7 @@ async def test_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
test_result = await test_model_connection(model)
|
||||
|
||||
# Save connection status to database
|
||||
model.model_type = normalize_model_type(model.model_type, model.model_name)
|
||||
model.connection_status = "connected" if test_result["success"] else "disconnected"
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
|
||||
@@ -1,31 +1,303 @@
|
||||
"""
|
||||
Questions API Router
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.response import ApiResponse, PaginatedResponse
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundException, ValidationException
|
||||
from app.core.crud import CRUDBase
|
||||
from app.models.models import Question, Chunk
|
||||
from app.schemas.question import QuestionResponse
|
||||
from app.schemas.question import QuestionCreateSchema
|
||||
from app.core.database import AsyncSessionLocal, get_db
|
||||
from app.core.exceptions import NotFoundException, ValidationException
|
||||
from app.core.logging import log_failure, log_success
|
||||
from app.models.models import Chunk, ModelConfig, Question
|
||||
from app.schemas.question import QuestionCreateSchema, QuestionResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize CRUD
|
||||
question_crud = CRUDBase(Question)
|
||||
|
||||
VALID_MODEL_TYPES = {"chat", "vlm", "embedding", "rerank"}
|
||||
|
||||
DEFAULT_PRESET_PROMPT = (
|
||||
"你是一名高质量中文问答数据构建助手。"
|
||||
"请基于给定 chunk 内容生成准确、自然、可用于训练的数据集问答对。"
|
||||
"问题必须清晰具体,答案必须直接来自内容或基于内容做合理概括,"
|
||||
"不要编造原文没有的信息,不要输出与目录、导航、页眉页脚、噪声文字相关的问题。"
|
||||
)
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
"""Request model for generating questions"""
|
||||
chunk_ids: List[UUID] = Field(..., min_length=1)
|
||||
count: int = Field(5, ge=1, le=50)
|
||||
question_types: List[str] = ["fact", "summary"]
|
||||
model_id: UUID
|
||||
count: int = Field(3, ge=1, le=10)
|
||||
dirty_data_filter: bool = True
|
||||
thinking_mode: bool = True
|
||||
preset_prompt: str = Field(default=DEFAULT_PRESET_PROMPT, min_length=1, max_length=4000)
|
||||
|
||||
|
||||
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
|
||||
"""Normalize model type, with keyword fallback for legacy records."""
|
||||
if model_type in VALID_MODEL_TYPES and model_type != "chat":
|
||||
return model_type
|
||||
|
||||
normalized_name = (model_name or "").strip().lower()
|
||||
rerank_keywords = ("rerank", "bce-reranker", "gte-rerank")
|
||||
embedding_keywords = (
|
||||
"embedding",
|
||||
"embed",
|
||||
"text-embedding",
|
||||
"bge-",
|
||||
"bge_m3",
|
||||
"gte-",
|
||||
"m3e",
|
||||
"e5-",
|
||||
"jina-embeddings",
|
||||
)
|
||||
vlm_keywords = ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")
|
||||
|
||||
if any(keyword in normalized_name for keyword in rerank_keywords):
|
||||
return "rerank"
|
||||
if any(keyword in normalized_name for keyword in embedding_keywords):
|
||||
return "embedding"
|
||||
if any(keyword in normalized_name for keyword in vlm_keywords):
|
||||
return "vlm"
|
||||
return model_type if model_type in VALID_MODEL_TYPES else "chat"
|
||||
|
||||
|
||||
def is_dirty_chunk(content: str) -> bool:
|
||||
"""Heuristic dirty-data filter for low-value chunks."""
|
||||
normalized = re.sub(r"\s+", " ", (content or "")).strip()
|
||||
if len(normalized) < 40:
|
||||
return True
|
||||
|
||||
if len(re.sub(r"[^\u4e00-\u9fffA-Za-z0-9]", "", normalized)) < 24:
|
||||
return True
|
||||
|
||||
lowered = normalized.lower()
|
||||
if lowered in {"目录", "contents", "table of contents"}:
|
||||
return True
|
||||
|
||||
lines = [line.strip() for line in (content or "").splitlines() if line.strip()]
|
||||
if lines:
|
||||
short_lines = sum(1 for line in lines if len(line) <= 18)
|
||||
dotted_lines = sum(1 for line in lines if re.search(r"[·•…\.]{3,}|\s\d+$", line))
|
||||
if short_lines / len(lines) > 0.7 and len(lines) >= 3:
|
||||
return True
|
||||
if dotted_lines / len(lines) > 0.4:
|
||||
return True
|
||||
|
||||
punctuation_ratio = sum(1 for ch in normalized if not ch.isalnum() and not ("\u4e00" <= ch <= "\u9fff")) / max(len(normalized), 1)
|
||||
if punctuation_ratio > 0.45:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def build_generation_prompt(chunk: Chunk, request: GenerateRequest) -> str:
|
||||
"""Build user prompt for QA generation."""
|
||||
thinking_instruction = (
|
||||
"请先对内容做简短分析,识别核心事实、概念、关系与潜在考点,然后再生成问答。"
|
||||
"分析过程只用于提高质量,不要在最终输出中暴露你的思维链。"
|
||||
if request.thinking_mode
|
||||
else "直接基于内容生成高质量问答。"
|
||||
)
|
||||
|
||||
return (
|
||||
f"{request.preset_prompt}\n\n"
|
||||
"输出要求:\n"
|
||||
f"1. 生成 {request.count} 组问答。\n"
|
||||
"2. 只输出 JSON 数组,不要输出解释、标题、Markdown。\n"
|
||||
'3. 每个对象结构为 {"question":"...","answer":"...","question_type":"fact|summary|reasoning"}。\n'
|
||||
"4. 问题避免重复,答案避免空泛。\n"
|
||||
"5. 如果内容不足以生成高质量问答,请返回空数组 []。\n"
|
||||
f"6. {thinking_instruction}\n\n"
|
||||
f"Chunk 名称:{chunk.name or '未命名分片'}\n"
|
||||
f"Chunk 内容:\n{chunk.content}"
|
||||
)
|
||||
|
||||
|
||||
def extract_text_from_response(data: dict) -> str:
|
||||
"""Extract response text from provider response."""
|
||||
choices = data.get("choices") or []
|
||||
if choices:
|
||||
message = choices[0].get("message") or {}
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [item.get("text", "") for item in content if isinstance(item, dict)]
|
||||
return "\n".join(part for part in parts if part)
|
||||
return ""
|
||||
|
||||
|
||||
def parse_generated_questions(raw_text: str) -> List[dict]:
|
||||
"""Parse JSON array from model output."""
|
||||
text = (raw_text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
fenced_match = re.search(r"```json\s*(.*?)\s*```", text, flags=re.S)
|
||||
if fenced_match:
|
||||
text = fenced_match.group(1).strip()
|
||||
|
||||
if not text.startswith("["):
|
||||
array_match = re.search(r"(\[\s*\{.*\}\s*\])", text, flags=re.S)
|
||||
if array_match:
|
||||
text = array_match.group(1)
|
||||
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
return []
|
||||
|
||||
normalized = []
|
||||
for item in parsed:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
question = str(item.get("question", "")).strip()
|
||||
answer = str(item.get("answer", "")).strip()
|
||||
question_type = str(item.get("question_type", "fact")).strip() or "fact"
|
||||
if not question or not answer:
|
||||
continue
|
||||
normalized.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"question_type": question_type
|
||||
})
|
||||
return normalized
|
||||
|
||||
|
||||
async def call_generation_model(model: ModelConfig, prompt: str) -> str:
|
||||
"""Call configured chat model for question generation."""
|
||||
provider = model.provider
|
||||
api_base = (model.api_base or "").rstrip("/")
|
||||
api_key = model.api_key or ""
|
||||
model_name = model.model_name
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是问答数据构建助手。严格按 JSON 输出,不要输出额外说明。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.4,
|
||||
"response_format": {"type": "json_object"}
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
if provider == "minimax":
|
||||
response = await client.post(
|
||||
f"{api_base}/chat/completions_v2",
|
||||
headers=headers,
|
||||
json={k: v for k, v in payload.items() if k != "response_format"}
|
||||
)
|
||||
else:
|
||||
response = await client.post(
|
||||
f"{api_base}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
content = extract_text_from_response(data)
|
||||
if not content:
|
||||
raise ValueError("Model returned empty content")
|
||||
|
||||
if content.lstrip().startswith("{"):
|
||||
obj = json.loads(content)
|
||||
if isinstance(obj, dict) and isinstance(obj.get("questions"), list):
|
||||
return json.dumps(obj["questions"], ensure_ascii=False)
|
||||
return content
|
||||
|
||||
|
||||
async def process_generate_async(project_id: UUID, request: GenerateRequest):
|
||||
"""Generate QA pairs in background."""
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
model_result = await db.execute(
|
||||
select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None) # noqa: E711
|
||||
)
|
||||
model = model_result.scalar_one_or_none()
|
||||
if not model:
|
||||
return
|
||||
|
||||
model_type = normalize_model_type(model.model_type, model.model_name)
|
||||
if model_type not in {"chat", "vlm"}:
|
||||
raise ValidationException("Selected model must be chat/vlm type", field="model_id")
|
||||
|
||||
chunk_result = await db.execute(
|
||||
select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
|
||||
)
|
||||
chunks = chunk_result.scalars().all()
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
created_count = 0
|
||||
skipped_count = 0
|
||||
for chunk in chunks:
|
||||
if request.dirty_data_filter and is_dirty_chunk(chunk.content):
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
prompt = build_generation_prompt(chunk, request)
|
||||
raw_text = await call_generation_model(model, prompt)
|
||||
qa_pairs = parse_generated_questions(raw_text)[:request.count]
|
||||
if not qa_pairs:
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
for item in qa_pairs:
|
||||
db.add(Question(
|
||||
project_id=project_id,
|
||||
chunk_id=chunk.id,
|
||||
content=item["question"],
|
||||
answer=item["answer"],
|
||||
question_type=item["question_type"],
|
||||
source="generated"
|
||||
))
|
||||
created_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
log_success(
|
||||
"问答批量生成完成",
|
||||
project_id=str(project_id),
|
||||
model_id=str(model.id),
|
||||
chunk_count=len(chunks),
|
||||
created_questions=created_count,
|
||||
skipped_chunks=skipped_count
|
||||
)
|
||||
except Exception as e:
|
||||
log_failure(
|
||||
"问答批量生成失败",
|
||||
project_id=str(project_id),
|
||||
model_id=str(request.model_id),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ApiResponse)
|
||||
@@ -34,36 +306,33 @@ async def generate_questions(
|
||||
request: GenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Generate questions from chunks using LLM"""
|
||||
# Get chunks
|
||||
result = await db.execute(
|
||||
select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
|
||||
"""Generate questions from chunks using LLM in background."""
|
||||
model_result = await db.execute(
|
||||
select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None) # noqa: E711
|
||||
)
|
||||
chunks = result.scalars().all()
|
||||
model = model_result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise ValidationException("Selected model not found", field="model_id")
|
||||
|
||||
if not chunks:
|
||||
model_type = normalize_model_type(model.model_type, model.model_name)
|
||||
if model_type not in {"chat", "vlm"}:
|
||||
raise ValidationException("Selected model must be chat/vlm type", field="model_id")
|
||||
if not model.api_key:
|
||||
raise ValidationException("Selected model is missing API Key", field="model_id")
|
||||
|
||||
chunk_result = await db.execute(
|
||||
select(Chunk.id).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
|
||||
)
|
||||
valid_chunk_ids = [row[0] for row in chunk_result.all()]
|
||||
if not valid_chunk_ids:
|
||||
raise ValidationException("No valid chunks found", field="chunk_ids")
|
||||
|
||||
# Create sample questions (placeholder for LLM-based generation)
|
||||
created_questions = []
|
||||
for chunk in chunks:
|
||||
for i in range(request.count):
|
||||
question = Question(
|
||||
project_id=project_id,
|
||||
chunk_id=chunk.id,
|
||||
content=f"这是关于「{chunk.name}」的问题 {i+1}?",
|
||||
answer=f"这是问题 {i+1} 的答案。",
|
||||
question_type=request.question_types[0] if request.question_types else "fact",
|
||||
source="generated"
|
||||
)
|
||||
db.add(question)
|
||||
created_questions.append(question)
|
||||
|
||||
await db.commit()
|
||||
request_payload = request.model_copy(update={"chunk_ids": valid_chunk_ids})
|
||||
asyncio.create_task(process_generate_async(project_id, request_payload))
|
||||
|
||||
return ApiResponse.ok(
|
||||
data={"questions": len(created_questions)},
|
||||
message=f"Successfully generated {len(created_questions)} questions"
|
||||
data={"chunk_count": len(valid_chunk_ids), "status": "processing"},
|
||||
message="Question generation started in background"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user