refactor: 前端架构重构 - 提取 CSS 和逻辑到独立模块
前端重构: - 删除旧的大体积 Vue 组件(HomeView, FileManage, TextSplit 等) - 删除旧的 composables(useFormatters, useModels, useProjects) - 新增 core/, page-logic/, pages/, shared/ 模块化目录结构 - 提取 CSS 到 styles/pages/ 目录 - 添加全局样式 variables.css 和 common.css 后端 API 更新: - chunks: 语义分割 API 增强 - files: 文件处理 API 更新 - models: 模型管理 API 更新 - questions: 问答管理 API 更新 - database: 数据库连接优化 - semantic_embedding: 语义嵌入服务优化 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.api.response import ApiResponse, PaginatedResponse
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, AsyncSessionLocal
|
||||
from app.core.exceptions import NotFoundException
|
||||
from app.core.crud import CRUDBase
|
||||
from app.core.logging import log_success, log_failure
|
||||
@@ -80,6 +80,106 @@ async def process_file_by_type(file: File) -> str:
|
||||
return content
|
||||
|
||||
|
||||
async def process_split_async(
|
||||
project_id: UUID,
|
||||
request: SplitRequest,
|
||||
):
|
||||
"""Run chunk splitting in background."""
|
||||
async with AsyncSessionLocal() as db:
|
||||
file = None
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(File).where(File.id == request.file_id, File.project_id == project_id)
|
||||
)
|
||||
file = result.scalar_one_or_none()
|
||||
if not file:
|
||||
return
|
||||
|
||||
text = await process_file_by_type(file)
|
||||
|
||||
kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
|
||||
if request.method == "custom" and request.separator:
|
||||
kwargs["separator"] = request.separator
|
||||
|
||||
if request.method == "semantic_embedding":
|
||||
kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
|
||||
kwargs["embedding_api_key"] = request.embedding_api_key
|
||||
kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
|
||||
kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
|
||||
kwargs["similarity_threshold"] = request.similarity_threshold
|
||||
kwargs["min_chunk_size"] = request.min_chunk_size
|
||||
|
||||
splitter = get_splitter(request.method, **kwargs)
|
||||
split_results = splitter.split(text)
|
||||
|
||||
await db.execute(
|
||||
Chunk.__table__.delete().where(
|
||||
Chunk.project_id == project_id,
|
||||
Chunk.file_id == file.id
|
||||
)
|
||||
)
|
||||
|
||||
chunks = []
|
||||
for chunk_data in split_results:
|
||||
db_chunk = Chunk(
|
||||
project_id=project_id,
|
||||
file_id=file.id,
|
||||
name=chunk_data.get("name", f"Chunk {chunk_data['index'] + 1}"),
|
||||
content=chunk_data["content"],
|
||||
word_count=chunk_data.get("word_count", len(chunk_data["content"].split()))
|
||||
)
|
||||
db.add(db_chunk)
|
||||
chunks.append(db_chunk)
|
||||
|
||||
await db.commit()
|
||||
|
||||
ready_dir = get_project_ready_dir(str(project_id))
|
||||
|
||||
# 删除旧的 markdown 文件(可能有两种命名格式)
|
||||
old_md_files = list(ready_dir.glob(f"{file.id}*.md"))
|
||||
for old_file in old_md_files:
|
||||
try:
|
||||
old_file.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
md_filename = f"{file.id}.md"
|
||||
md_path = ready_dir / md_filename
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: md_path.write_text(text, encoding='utf-8')
|
||||
)
|
||||
|
||||
file.file_path = str(md_path)
|
||||
file.status = "completed"
|
||||
await db.commit()
|
||||
|
||||
log_success(
|
||||
"文件分割完成",
|
||||
project_id=str(project_id),
|
||||
file_id=str(file.id),
|
||||
filename=file.filename,
|
||||
method=request.method,
|
||||
chunk_count=len(chunks),
|
||||
text_length=len(text),
|
||||
ready_path=str(md_path)
|
||||
)
|
||||
except Exception as e:
|
||||
if file:
|
||||
file.status = "failed"
|
||||
await db.commit()
|
||||
|
||||
log_failure(
|
||||
"文件分割失败",
|
||||
project_id=str(project_id),
|
||||
file_id=str(request.file_id),
|
||||
method=request.method,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/split", response_model=ApiResponse)
|
||||
async def split_text(
|
||||
project_id: UUID,
|
||||
@@ -88,7 +188,6 @@ async def split_text(
|
||||
):
|
||||
"""Split text into chunks"""
|
||||
try:
|
||||
# Get file
|
||||
result = await db.execute(
|
||||
select(File).where(File.id == request.file_id, File.project_id == project_id)
|
||||
)
|
||||
@@ -107,81 +206,27 @@ async def split_text(
|
||||
overlap=request.overlap
|
||||
)
|
||||
|
||||
# Process file
|
||||
text = await process_file_by_type(file)
|
||||
|
||||
# Update file status
|
||||
file.status = "processing"
|
||||
await db.commit()
|
||||
|
||||
# Split text
|
||||
kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
|
||||
if request.method == "custom" and request.separator:
|
||||
kwargs["separator"] = request.separator
|
||||
|
||||
# 如果使用 semantic_embedding 方法,传递 embedding 参数
|
||||
if request.method == "semantic_embedding":
|
||||
kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
|
||||
kwargs["embedding_api_key"] = request.embedding_api_key
|
||||
kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
|
||||
kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
|
||||
kwargs["similarity_threshold"] = request.similarity_threshold
|
||||
kwargs["min_chunk_size"] = request.min_chunk_size
|
||||
|
||||
splitter = get_splitter(request.method, **kwargs)
|
||||
split_results = splitter.split(text)
|
||||
|
||||
# Save chunks
|
||||
chunks = []
|
||||
for chunk_data in split_results:
|
||||
db_chunk = Chunk(
|
||||
asyncio.create_task(
|
||||
process_split_async(
|
||||
project_id=project_id,
|
||||
file_id=file.id,
|
||||
name=chunk_data.get("name", f"Chunk {chunk_data['index'] + 1}"),
|
||||
content=chunk_data["content"],
|
||||
word_count=chunk_data.get("word_count", len(chunk_data["content"].split()))
|
||||
request=request,
|
||||
)
|
||||
db.add(db_chunk)
|
||||
chunks.append(db_chunk)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Save processed markdown to ready directory
|
||||
ready_dir = get_project_ready_dir(str(project_id))
|
||||
md_filename = f"{file.id}_{file.filename}.md"
|
||||
md_path = ready_dir / md_filename
|
||||
|
||||
# Write markdown content to file
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: md_path.write_text(text, encoding='utf-8')
|
||||
)
|
||||
|
||||
# Update file path to ready location
|
||||
file.file_path = str(md_path)
|
||||
file.status = "completed"
|
||||
await db.commit()
|
||||
|
||||
# 记录成功日志
|
||||
log_success(
|
||||
"文件处理完成",
|
||||
project_id=str(project_id),
|
||||
file_id=str(file.id),
|
||||
filename=file.filename,
|
||||
chunk_count=len(chunks),
|
||||
text_length=len(text),
|
||||
ready_path=str(md_path)
|
||||
)
|
||||
|
||||
return ApiResponse.ok(
|
||||
data={"chunks": len(chunks)},
|
||||
message=f"Successfully split into {len(chunks)} chunks"
|
||||
data={"file_id": str(file.id), "status": file.status},
|
||||
message="Split task started, processing in background"
|
||||
)
|
||||
except Exception as e:
|
||||
# 记录失败日志
|
||||
if 'file' in locals() and file:
|
||||
file.status = "failed"
|
||||
await db.commit()
|
||||
|
||||
log_failure(
|
||||
"文件处理失败",
|
||||
"分割任务启动失败",
|
||||
project_id=str(project_id),
|
||||
file_id=str(request.file_id),
|
||||
error=str(e)
|
||||
|
||||
@@ -9,6 +9,7 @@ from uuid import UUID, uuid4
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, Query
|
||||
from fastapi.responses import FileResponse, PlainTextResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.api.response import ApiResponse, PaginatedResponse
|
||||
from app.core.config import get_settings
|
||||
@@ -17,6 +18,7 @@ from app.core.exceptions import ValidationException, NotFoundException
|
||||
from app.core.crud import CRUDBase
|
||||
from app.core.logging import log_success, log_failure
|
||||
from app.models.models import File as FileModel
|
||||
from app.models.models import Chunk, Question
|
||||
from app.schemas.file import FileResponse, FileCreateSchema
|
||||
from markitdown import MarkItDown
|
||||
|
||||
@@ -329,11 +331,27 @@ async def delete_file(
|
||||
file_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Delete file"""
|
||||
"""Delete file and all related data (markdown, chunks, questions)"""
|
||||
file = await file_crud.get(db, file_id)
|
||||
if not file or file.project_id != project_id:
|
||||
raise NotFoundException("File", file_id)
|
||||
|
||||
# Delete related chunks and their questions (explicit deletion for safety)
|
||||
chunks_result = await db.execute(
|
||||
select(Chunk).where(Chunk.file_id == file_id)
|
||||
)
|
||||
chunks = chunks_result.scalars().all()
|
||||
for chunk in chunks:
|
||||
# Delete questions related to this chunk
|
||||
questions_result = await db.execute(
|
||||
select(Question).where(Question.chunk_id == chunk.id)
|
||||
)
|
||||
questions = questions_result.scalars().all()
|
||||
for question in questions:
|
||||
await db.delete(question)
|
||||
# Delete chunk
|
||||
await db.delete(chunk)
|
||||
|
||||
# Delete file from raw directory
|
||||
if file.file_path and os.path.exists(file.file_path):
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
@@ -342,16 +360,27 @@ async def delete_file(
|
||||
file.file_path
|
||||
)
|
||||
|
||||
# Delete file from ready directory (processed markdown)
|
||||
ready_path = Path("/data/code/YG-Datasets/data") / str(project_id) / "ready" / f"{file_id}.md"
|
||||
if ready_path.exists():
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
os.remove,
|
||||
str(ready_path)
|
||||
)
|
||||
# Delete file from ready directory (processed markdown) - try both naming conventions
|
||||
ready_dir = Path("/data/code/YG-Datasets/data") / str(project_id) / "ready"
|
||||
if ready_dir.exists():
|
||||
# Try file_id.md (from upload process)
|
||||
ready_path = ready_dir / f"{file_id}.md"
|
||||
if ready_path.exists():
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
os.remove,
|
||||
str(ready_path)
|
||||
)
|
||||
# Try file_id_filename.md (from split process)
|
||||
for md_file in ready_dir.glob(f"{file_id}_*.md"):
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
os.remove,
|
||||
str(md_file)
|
||||
)
|
||||
|
||||
await file_crud.delete(db, file_id)
|
||||
await db.commit()
|
||||
return ApiResponse.ok(message="File deleted successfully")
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,38 @@ from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
VALID_MODEL_TYPES = {"chat", "vlm", "embedding", "rerank"}
|
||||
|
||||
|
||||
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
|
||||
"""Normalize model type, with keyword fallback for legacy records."""
|
||||
if model_type in VALID_MODEL_TYPES and model_type != "chat":
|
||||
return model_type
|
||||
|
||||
normalized_name = (model_name or "").strip().lower()
|
||||
|
||||
rerank_keywords = ("rerank", "bce-reranker", "gte-rerank")
|
||||
embedding_keywords = (
|
||||
"embedding",
|
||||
"embed",
|
||||
"text-embedding",
|
||||
"bge-",
|
||||
"bge_m3",
|
||||
"gte-",
|
||||
"m3e",
|
||||
"e5-",
|
||||
"jina-embeddings",
|
||||
)
|
||||
vlm_keywords = ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")
|
||||
|
||||
if any(keyword in normalized_name for keyword in rerank_keywords):
|
||||
return "rerank"
|
||||
if any(keyword in normalized_name for keyword in embedding_keywords):
|
||||
return "embedding"
|
||||
if any(keyword in normalized_name for keyword in vlm_keywords):
|
||||
return "vlm"
|
||||
return model_type if model_type in VALID_MODEL_TYPES else "chat"
|
||||
|
||||
|
||||
async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"""Test model connection by calling the API"""
|
||||
@@ -23,6 +55,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
api_base = model.api_base or ""
|
||||
provider = model.provider
|
||||
model_name = model.model_name
|
||||
model_type = normalize_model_type(model.model_type, model_name)
|
||||
api_key = model.api_key
|
||||
|
||||
headers = {
|
||||
@@ -32,7 +65,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
if provider == "openai":
|
||||
if model_type in {"chat", "vlm"} and provider in {"openai", "ali"}:
|
||||
# OpenAI compatible API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions",
|
||||
@@ -43,7 +76,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"max_tokens": 5
|
||||
}
|
||||
)
|
||||
elif provider == "minimax":
|
||||
elif model_type in {"chat", "vlm"} and provider == "minimax":
|
||||
# MiniMax API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions_v2",
|
||||
@@ -56,7 +89,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}
|
||||
)
|
||||
elif provider == "glm":
|
||||
elif model_type in {"chat", "vlm"} and provider == "glm":
|
||||
# GLM API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions",
|
||||
@@ -66,8 +99,21 @@ async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}
|
||||
)
|
||||
elif model_type == "embedding" and provider in {"openai", "ali", "glm"}:
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/embeddings",
|
||||
headers=headers,
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": "test"
|
||||
}
|
||||
)
|
||||
elif model_type == "embedding" and provider == "minimax":
|
||||
return {"success": False, "message": "MiniMax embedding 自动测试暂未接入,请手动确认端点与模型"}
|
||||
elif model_type == "rerank":
|
||||
return {"success": False, "message": "Rerank 自动测试暂未接入,请先保存配置并在实际流程中验证"}
|
||||
else:
|
||||
return {"success": False, "message": f"Unsupported provider: {provider}"}
|
||||
return {"success": False, "message": f"Unsupported provider/type: {provider}/{model_type}"}
|
||||
|
||||
if response.status_code == 200:
|
||||
return {"success": True, "message": "Connection successful"}
|
||||
@@ -114,6 +160,7 @@ async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)):
|
||||
|
||||
db_model = ModelConfig(
|
||||
provider=model.provider,
|
||||
model_type=model.model_type,
|
||||
model_name=model.model_name,
|
||||
api_key=model.api_key,
|
||||
api_base=model.api_base,
|
||||
@@ -248,6 +295,7 @@ async def test_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
test_result = await test_model_connection(model)
|
||||
|
||||
# Save connection status to database
|
||||
model.model_type = normalize_model_type(model.model_type, model.model_name)
|
||||
model.connection_status = "connected" if test_result["success"] else "disconnected"
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
|
||||
@@ -1,31 +1,303 @@
|
||||
"""
|
||||
Questions API Router
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.response import ApiResponse, PaginatedResponse
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundException, ValidationException
|
||||
from app.core.crud import CRUDBase
|
||||
from app.models.models import Question, Chunk
|
||||
from app.schemas.question import QuestionResponse
|
||||
from app.schemas.question import QuestionCreateSchema
|
||||
from app.core.database import AsyncSessionLocal, get_db
|
||||
from app.core.exceptions import NotFoundException, ValidationException
|
||||
from app.core.logging import log_failure, log_success
|
||||
from app.models.models import Chunk, ModelConfig, Question
|
||||
from app.schemas.question import QuestionCreateSchema, QuestionResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize CRUD
|
||||
question_crud = CRUDBase(Question)
|
||||
|
||||
VALID_MODEL_TYPES = {"chat", "vlm", "embedding", "rerank"}
|
||||
|
||||
DEFAULT_PRESET_PROMPT = (
|
||||
"你是一名高质量中文问答数据构建助手。"
|
||||
"请基于给定 chunk 内容生成准确、自然、可用于训练的数据集问答对。"
|
||||
"问题必须清晰具体,答案必须直接来自内容或基于内容做合理概括,"
|
||||
"不要编造原文没有的信息,不要输出与目录、导航、页眉页脚、噪声文字相关的问题。"
|
||||
)
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
"""Request model for generating questions"""
|
||||
chunk_ids: List[UUID] = Field(..., min_length=1)
|
||||
count: int = Field(5, ge=1, le=50)
|
||||
question_types: List[str] = ["fact", "summary"]
|
||||
model_id: UUID
|
||||
count: int = Field(3, ge=1, le=10)
|
||||
dirty_data_filter: bool = True
|
||||
thinking_mode: bool = True
|
||||
preset_prompt: str = Field(default=DEFAULT_PRESET_PROMPT, min_length=1, max_length=4000)
|
||||
|
||||
|
||||
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
|
||||
"""Normalize model type, with keyword fallback for legacy records."""
|
||||
if model_type in VALID_MODEL_TYPES and model_type != "chat":
|
||||
return model_type
|
||||
|
||||
normalized_name = (model_name or "").strip().lower()
|
||||
rerank_keywords = ("rerank", "bce-reranker", "gte-rerank")
|
||||
embedding_keywords = (
|
||||
"embedding",
|
||||
"embed",
|
||||
"text-embedding",
|
||||
"bge-",
|
||||
"bge_m3",
|
||||
"gte-",
|
||||
"m3e",
|
||||
"e5-",
|
||||
"jina-embeddings",
|
||||
)
|
||||
vlm_keywords = ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")
|
||||
|
||||
if any(keyword in normalized_name for keyword in rerank_keywords):
|
||||
return "rerank"
|
||||
if any(keyword in normalized_name for keyword in embedding_keywords):
|
||||
return "embedding"
|
||||
if any(keyword in normalized_name for keyword in vlm_keywords):
|
||||
return "vlm"
|
||||
return model_type if model_type in VALID_MODEL_TYPES else "chat"
|
||||
|
||||
|
||||
def is_dirty_chunk(content: str) -> bool:
|
||||
"""Heuristic dirty-data filter for low-value chunks."""
|
||||
normalized = re.sub(r"\s+", " ", (content or "")).strip()
|
||||
if len(normalized) < 40:
|
||||
return True
|
||||
|
||||
if len(re.sub(r"[^\u4e00-\u9fffA-Za-z0-9]", "", normalized)) < 24:
|
||||
return True
|
||||
|
||||
lowered = normalized.lower()
|
||||
if lowered in {"目录", "contents", "table of contents"}:
|
||||
return True
|
||||
|
||||
lines = [line.strip() for line in (content or "").splitlines() if line.strip()]
|
||||
if lines:
|
||||
short_lines = sum(1 for line in lines if len(line) <= 18)
|
||||
dotted_lines = sum(1 for line in lines if re.search(r"[·•…\.]{3,}|\s\d+$", line))
|
||||
if short_lines / len(lines) > 0.7 and len(lines) >= 3:
|
||||
return True
|
||||
if dotted_lines / len(lines) > 0.4:
|
||||
return True
|
||||
|
||||
punctuation_ratio = sum(1 for ch in normalized if not ch.isalnum() and not ("\u4e00" <= ch <= "\u9fff")) / max(len(normalized), 1)
|
||||
if punctuation_ratio > 0.45:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def build_generation_prompt(chunk: Chunk, request: GenerateRequest) -> str:
|
||||
"""Build user prompt for QA generation."""
|
||||
thinking_instruction = (
|
||||
"请先对内容做简短分析,识别核心事实、概念、关系与潜在考点,然后再生成问答。"
|
||||
"分析过程只用于提高质量,不要在最终输出中暴露你的思维链。"
|
||||
if request.thinking_mode
|
||||
else "直接基于内容生成高质量问答。"
|
||||
)
|
||||
|
||||
return (
|
||||
f"{request.preset_prompt}\n\n"
|
||||
"输出要求:\n"
|
||||
f"1. 生成 {request.count} 组问答。\n"
|
||||
"2. 只输出 JSON 数组,不要输出解释、标题、Markdown。\n"
|
||||
'3. 每个对象结构为 {"question":"...","answer":"...","question_type":"fact|summary|reasoning"}。\n'
|
||||
"4. 问题避免重复,答案避免空泛。\n"
|
||||
"5. 如果内容不足以生成高质量问答,请返回空数组 []。\n"
|
||||
f"6. {thinking_instruction}\n\n"
|
||||
f"Chunk 名称:{chunk.name or '未命名分片'}\n"
|
||||
f"Chunk 内容:\n{chunk.content}"
|
||||
)
|
||||
|
||||
|
||||
def extract_text_from_response(data: dict) -> str:
|
||||
"""Extract response text from provider response."""
|
||||
choices = data.get("choices") or []
|
||||
if choices:
|
||||
message = choices[0].get("message") or {}
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [item.get("text", "") for item in content if isinstance(item, dict)]
|
||||
return "\n".join(part for part in parts if part)
|
||||
return ""
|
||||
|
||||
|
||||
def parse_generated_questions(raw_text: str) -> List[dict]:
|
||||
"""Parse JSON array from model output."""
|
||||
text = (raw_text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
fenced_match = re.search(r"```json\s*(.*?)\s*```", text, flags=re.S)
|
||||
if fenced_match:
|
||||
text = fenced_match.group(1).strip()
|
||||
|
||||
if not text.startswith("["):
|
||||
array_match = re.search(r"(\[\s*\{.*\}\s*\])", text, flags=re.S)
|
||||
if array_match:
|
||||
text = array_match.group(1)
|
||||
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
return []
|
||||
|
||||
normalized = []
|
||||
for item in parsed:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
question = str(item.get("question", "")).strip()
|
||||
answer = str(item.get("answer", "")).strip()
|
||||
question_type = str(item.get("question_type", "fact")).strip() or "fact"
|
||||
if not question or not answer:
|
||||
continue
|
||||
normalized.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"question_type": question_type
|
||||
})
|
||||
return normalized
|
||||
|
||||
|
||||
async def call_generation_model(model: ModelConfig, prompt: str) -> str:
|
||||
"""Call configured chat model for question generation."""
|
||||
provider = model.provider
|
||||
api_base = (model.api_base or "").rstrip("/")
|
||||
api_key = model.api_key or ""
|
||||
model_name = model.model_name
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是问答数据构建助手。严格按 JSON 输出,不要输出额外说明。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.4,
|
||||
"response_format": {"type": "json_object"}
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
if provider == "minimax":
|
||||
response = await client.post(
|
||||
f"{api_base}/chat/completions_v2",
|
||||
headers=headers,
|
||||
json={k: v for k, v in payload.items() if k != "response_format"}
|
||||
)
|
||||
else:
|
||||
response = await client.post(
|
||||
f"{api_base}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
content = extract_text_from_response(data)
|
||||
if not content:
|
||||
raise ValueError("Model returned empty content")
|
||||
|
||||
if content.lstrip().startswith("{"):
|
||||
obj = json.loads(content)
|
||||
if isinstance(obj, dict) and isinstance(obj.get("questions"), list):
|
||||
return json.dumps(obj["questions"], ensure_ascii=False)
|
||||
return content
|
||||
|
||||
|
||||
async def process_generate_async(project_id: UUID, request: GenerateRequest):
|
||||
"""Generate QA pairs in background."""
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
model_result = await db.execute(
|
||||
select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None) # noqa: E711
|
||||
)
|
||||
model = model_result.scalar_one_or_none()
|
||||
if not model:
|
||||
return
|
||||
|
||||
model_type = normalize_model_type(model.model_type, model.model_name)
|
||||
if model_type not in {"chat", "vlm"}:
|
||||
raise ValidationException("Selected model must be chat/vlm type", field="model_id")
|
||||
|
||||
chunk_result = await db.execute(
|
||||
select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
|
||||
)
|
||||
chunks = chunk_result.scalars().all()
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
created_count = 0
|
||||
skipped_count = 0
|
||||
for chunk in chunks:
|
||||
if request.dirty_data_filter and is_dirty_chunk(chunk.content):
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
prompt = build_generation_prompt(chunk, request)
|
||||
raw_text = await call_generation_model(model, prompt)
|
||||
qa_pairs = parse_generated_questions(raw_text)[:request.count]
|
||||
if not qa_pairs:
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
for item in qa_pairs:
|
||||
db.add(Question(
|
||||
project_id=project_id,
|
||||
chunk_id=chunk.id,
|
||||
content=item["question"],
|
||||
answer=item["answer"],
|
||||
question_type=item["question_type"],
|
||||
source="generated"
|
||||
))
|
||||
created_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
log_success(
|
||||
"问答批量生成完成",
|
||||
project_id=str(project_id),
|
||||
model_id=str(model.id),
|
||||
chunk_count=len(chunks),
|
||||
created_questions=created_count,
|
||||
skipped_chunks=skipped_count
|
||||
)
|
||||
except Exception as e:
|
||||
log_failure(
|
||||
"问答批量生成失败",
|
||||
project_id=str(project_id),
|
||||
model_id=str(request.model_id),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ApiResponse)
|
||||
@@ -34,36 +306,33 @@ async def generate_questions(
|
||||
request: GenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Generate questions from chunks using LLM"""
|
||||
# Get chunks
|
||||
result = await db.execute(
|
||||
select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
|
||||
"""Generate questions from chunks using LLM in background."""
|
||||
model_result = await db.execute(
|
||||
select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None) # noqa: E711
|
||||
)
|
||||
chunks = result.scalars().all()
|
||||
model = model_result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise ValidationException("Selected model not found", field="model_id")
|
||||
|
||||
if not chunks:
|
||||
model_type = normalize_model_type(model.model_type, model.model_name)
|
||||
if model_type not in {"chat", "vlm"}:
|
||||
raise ValidationException("Selected model must be chat/vlm type", field="model_id")
|
||||
if not model.api_key:
|
||||
raise ValidationException("Selected model is missing API Key", field="model_id")
|
||||
|
||||
chunk_result = await db.execute(
|
||||
select(Chunk.id).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
|
||||
)
|
||||
valid_chunk_ids = [row[0] for row in chunk_result.all()]
|
||||
if not valid_chunk_ids:
|
||||
raise ValidationException("No valid chunks found", field="chunk_ids")
|
||||
|
||||
# Create sample questions (placeholder for LLM-based generation)
|
||||
created_questions = []
|
||||
for chunk in chunks:
|
||||
for i in range(request.count):
|
||||
question = Question(
|
||||
project_id=project_id,
|
||||
chunk_id=chunk.id,
|
||||
content=f"这是关于「{chunk.name}」的问题 {i+1}?",
|
||||
answer=f"这是问题 {i+1} 的答案。",
|
||||
question_type=request.question_types[0] if request.question_types else "fact",
|
||||
source="generated"
|
||||
)
|
||||
db.add(question)
|
||||
created_questions.append(question)
|
||||
|
||||
await db.commit()
|
||||
request_payload = request.model_copy(update={"chunk_ids": valid_chunk_ids})
|
||||
asyncio.create_task(process_generate_async(project_id, request_payload))
|
||||
|
||||
return ApiResponse.ok(
|
||||
data={"questions": len(created_questions)},
|
||||
message=f"Successfully generated {len(created_questions)} questions"
|
||||
data={"chunk_count": len(valid_chunk_ids), "status": "processing"},
|
||||
message="Question generation started in background"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy import create_engine, event, inspect, text
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.core.config import get_settings
|
||||
@@ -65,9 +65,28 @@ async def init_db():
|
||||
logger.info("Initializing database...")
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await conn.run_sync(_ensure_legacy_columns)
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
|
||||
def _ensure_legacy_columns(sync_conn):
|
||||
"""Patch legacy tables with newly introduced columns."""
|
||||
inspector = inspect(sync_conn)
|
||||
if "model_configs" not in inspector.get_table_names():
|
||||
return
|
||||
|
||||
columns = {column["name"] for column in inspector.get_columns("model_configs")}
|
||||
if "model_type" in columns:
|
||||
return
|
||||
|
||||
logger.info("Adding missing model_type column to model_configs table")
|
||||
dialect = sync_conn.dialect.name
|
||||
if dialect == "postgresql":
|
||||
sync_conn.execute(text("ALTER TABLE model_configs ADD COLUMN model_type VARCHAR(50) NOT NULL DEFAULT 'chat'"))
|
||||
else:
|
||||
sync_conn.execute(text("ALTER TABLE model_configs ADD COLUMN model_type VARCHAR(50) NOT NULL DEFAULT 'chat'"))
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""Close database connections"""
|
||||
logger.info("Closing database connections...")
|
||||
|
||||
@@ -137,7 +137,8 @@ class ModelConfig(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = "model_configs"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=True)
|
||||
provider = Column(String(50), nullable=False) # minimax, glm, openai
|
||||
provider = Column(String(50), nullable=False) # minimax, glm, openai, ali
|
||||
model_type = Column(String(50), nullable=False, default="chat") # chat, vlm, embedding, rerank
|
||||
model_name = Column(String(100))
|
||||
api_key = Column(String(500))
|
||||
api_base = Column(String(500))
|
||||
|
||||
@@ -9,7 +9,8 @@ from uuid import UUID
|
||||
|
||||
class ModelBase(BaseModel):
|
||||
"""Base model schema"""
|
||||
provider: str = Field(..., description="Model provider: minimax, glm, openai")
|
||||
provider: str = Field(..., description="Model provider: minimax, glm, openai, ali")
|
||||
model_type: str = Field(default="chat", description="Model type: chat, vlm, embedding, rerank")
|
||||
model_name: str = Field(..., description="Model name")
|
||||
api_key: Optional[str] = Field(None, description="API key")
|
||||
api_base: Optional[str] = Field(None, description="API base URL")
|
||||
@@ -24,6 +25,7 @@ class ModelCreate(ModelBase):
|
||||
class ModelUpdate(BaseModel):
|
||||
"""Model update schema"""
|
||||
provider: Optional[str] = None
|
||||
model_type: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
|
||||
@@ -8,6 +8,7 @@ import httpx
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
@@ -109,32 +110,28 @@ class EmbeddingSplitter:
|
||||
|
||||
def _tokenize_sentences(self, text: str) -> List[str]:
|
||||
"""将文本切分为句子"""
|
||||
# 中英文句末符号
|
||||
# 先按换行分割,保持段落结构
|
||||
paragraphs = re.split(r'\n+', text)
|
||||
|
||||
paragraphs = re.split(r'\n\s*\n+', text)
|
||||
sentences = []
|
||||
for para in paragraphs:
|
||||
if not para.strip():
|
||||
para = para.strip()
|
||||
if not para:
|
||||
continue
|
||||
|
||||
# 按句子符号分割
|
||||
# 中文:。!?;
|
||||
# 英文:. ! ? ;
|
||||
parts = re.split(r'([。!?;\n]|(?<=[.!?])\s+)', para)
|
||||
parts = re.split(r'(?<=[。!?;.!?])\s+|(?<=[。!?;])', para)
|
||||
buffer = []
|
||||
|
||||
# 重新组合句子
|
||||
current_sentence = ""
|
||||
for part in parts:
|
||||
if part in '。!?;.\n':
|
||||
if current_sentence.strip():
|
||||
sentences.append(current_sentence.strip())
|
||||
current_sentence = ""
|
||||
elif part and part.strip():
|
||||
current_sentence += part
|
||||
# 处理最后一个句子
|
||||
if current_sentence.strip():
|
||||
sentences.append(current_sentence.strip())
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
|
||||
# 过短的片段先暂存,尽量与后一句合并,避免 embedding 粒度过碎
|
||||
if len(part) < 8 and buffer:
|
||||
buffer[-1] = f"{buffer[-1]} {part}".strip()
|
||||
else:
|
||||
buffer.append(part)
|
||||
|
||||
sentences.extend(buffer)
|
||||
|
||||
return sentences
|
||||
|
||||
@@ -162,51 +159,48 @@ class EmbeddingSplitter:
|
||||
if not similarities:
|
||||
return []
|
||||
|
||||
window = self.window_size
|
||||
window = max(1, self.window_size)
|
||||
smoothed = []
|
||||
|
||||
for i in range(len(similarities)):
|
||||
start = max(0, i - window + 1)
|
||||
end = i + 1
|
||||
start = max(0, i - window)
|
||||
end = min(len(similarities), i + window + 1)
|
||||
window_vals = similarities[start:end]
|
||||
smoothed.append(sum(window_vals) / len(window_vals))
|
||||
|
||||
return smoothed
|
||||
|
||||
def _detect_boundaries(self, similarities: List[float]) -> List[int]:
|
||||
def _detect_boundaries(self, similarities: List[float], sentence_lengths: List[int]) -> List[int]:
|
||||
"""检测分割点(相似度显著下降的位置)"""
|
||||
if not similarities:
|
||||
return [0]
|
||||
|
||||
# 平滑
|
||||
smoothed = self._smooth_similarities(similarities)
|
||||
|
||||
# 计算深度分数(类似 TextTiling)
|
||||
depth_scores = []
|
||||
for i in range(1, len(smoothed) - 1):
|
||||
# 当前位置的深度 = 当前位置的值 - 平均值
|
||||
# 但更准确的是:左侧平均 - 右侧平均
|
||||
left_avg = sum(smoothed[max(0, i - self.window_size):i]) / self.window_size
|
||||
right_avg = sum(smoothed[i:min(len(smoothed), i + self.window_size)]) / self.window_size
|
||||
depth = left_avg - right_avg
|
||||
depth_scores.append(depth)
|
||||
|
||||
# 如果没有足够的点,直接返回
|
||||
if not depth_scores:
|
||||
if len(smoothed) <= 1:
|
||||
return [0]
|
||||
|
||||
# 阈值判断
|
||||
mean_depth = np.mean(depth_scores)
|
||||
std_depth = np.std(depth_scores)
|
||||
|
||||
# 找分割点:depth 显著高于均值的位置
|
||||
threshold = mean_depth + 0.5 * std_depth
|
||||
mean_sim = float(np.mean(smoothed))
|
||||
std_sim = float(np.std(smoothed))
|
||||
dynamic_threshold = max(0.0, min(0.95, mean_sim - 0.5 * std_sim))
|
||||
effective_threshold = max(self.similarity_threshold, dynamic_threshold)
|
||||
|
||||
boundaries = [0] # 起始点
|
||||
for i, depth in enumerate(depth_scores):
|
||||
if depth > threshold and depth > self.similarity_threshold:
|
||||
boundaries.append(i + 1) # 对应相似度的下一个位置
|
||||
boundaries.append(len(self._tokenize_sentences.__name__)) # 结束点
|
||||
accumulated_chars = 0
|
||||
|
||||
for i, sim in enumerate(smoothed):
|
||||
accumulated_chars += sentence_lengths[i]
|
||||
|
||||
left_sim = smoothed[i - 1] if i > 0 else 1.0
|
||||
right_sim = smoothed[i + 1] if i < len(smoothed) - 1 else 1.0
|
||||
is_local_min = sim <= left_sim and sim <= right_sim
|
||||
has_enough_context = accumulated_chars >= self.min_chunk_size
|
||||
oversize_guard = accumulated_chars >= self.chunk_size
|
||||
|
||||
if (is_local_min and has_enough_context and sim <= effective_threshold) or oversize_guard:
|
||||
boundaries.append(i + 1)
|
||||
accumulated_chars = 0
|
||||
|
||||
boundaries.append(len(sentence_lengths))
|
||||
|
||||
return sorted(list(set(boundaries)))
|
||||
|
||||
@@ -225,7 +219,12 @@ class EmbeddingSplitter:
|
||||
for i in range(len(boundaries) - 1):
|
||||
start = boundaries[i]
|
||||
end = boundaries[i + 1]
|
||||
chunk_text = ' '.join(sentences[start:end])
|
||||
if start >= end:
|
||||
continue
|
||||
|
||||
chunk_text = ' '.join(sentences[start:end]).strip()
|
||||
if not chunk_text:
|
||||
continue
|
||||
|
||||
# 如果 chunk 过大,递归分割
|
||||
if len(chunk_text) > self.chunk_size * 1.5:
|
||||
@@ -278,14 +277,22 @@ class EmbeddingSplitter:
|
||||
merged = [chunks[0]]
|
||||
|
||||
for chunk in chunks[1:]:
|
||||
# 如果前一个 chunk 太小,合并
|
||||
if merged[-1]["char_count"] < self.min_chunk_size:
|
||||
merged[-1]["content"] += " " + chunk["content"]
|
||||
merged[-1]["word_count"] += chunk["word_count"]
|
||||
merged[-1]["char_count"] += chunk["char_count"]
|
||||
previous = merged[-1]
|
||||
should_merge = (
|
||||
previous["char_count"] < self.min_chunk_size or
|
||||
chunk["char_count"] < self.min_chunk_size
|
||||
)
|
||||
|
||||
if should_merge and previous["char_count"] + chunk["char_count"] <= self.chunk_size * 1.5:
|
||||
previous["content"] += " " + chunk["content"]
|
||||
previous["word_count"] += chunk["word_count"]
|
||||
previous["char_count"] += chunk["char_count"]
|
||||
else:
|
||||
merged.append(chunk)
|
||||
|
||||
for index, chunk in enumerate(merged):
|
||||
chunk["index"] = index
|
||||
|
||||
return merged
|
||||
|
||||
async def split_with_embedding(self, text: str) -> List[Dict]:
|
||||
@@ -295,8 +302,8 @@ class EmbeddingSplitter:
|
||||
if not sentences:
|
||||
return []
|
||||
|
||||
# 过滤过短的句子
|
||||
sentences = [s for s in sentences if len(s) >= 10]
|
||||
# 过滤纯噪音片段,但保留正常短句
|
||||
sentences = [s for s in sentences if len(s.strip()) >= 4]
|
||||
|
||||
if not sentences:
|
||||
return []
|
||||
@@ -312,17 +319,22 @@ class EmbeddingSplitter:
|
||||
|
||||
# 3. 调用 Embedding API
|
||||
try:
|
||||
if self.embedding_provider is None:
|
||||
raise ValueError("embedding provider is not configured")
|
||||
embeddings = await self.embedding_provider.get_embeddings(sentences)
|
||||
except Exception as e:
|
||||
# 如果 embedding 失败,降级到规则分割
|
||||
print(f"Embedding failed, falling back to rule-based: {e}")
|
||||
return self._fallback_split(text)
|
||||
|
||||
if len(embeddings) != len(sentences):
|
||||
return self._fallback_split(text)
|
||||
|
||||
# 4. 计算相似度
|
||||
similarities = self._compute_similarities(embeddings)
|
||||
|
||||
# 5. 检测分割点
|
||||
boundaries = self._detect_boundaries(similarities)
|
||||
boundaries = self._detect_boundaries(similarities, [len(sentence) for sentence in sentences])
|
||||
|
||||
# 6. 组装 chunks
|
||||
chunks = self._assemble_chunks(sentences, boundaries)
|
||||
@@ -387,7 +399,7 @@ class SemanticEmbeddingSplitter(EmbeddingSplitter):
|
||||
|
||||
def create_embedding_provider(provider: str, api_key: str, base_url: str, model: str = None) -> EmbeddingProvider:
|
||||
"""创建 Embedding 提供商"""
|
||||
if provider in ["openai", "compatible"]:
|
||||
if provider in ["openai", "compatible", "ali", "glm"]:
|
||||
return OpenAIEmbedding(api_key, base_url, model or "text-embedding-3-small")
|
||||
elif provider == "minimax":
|
||||
return MiniMaxEmbedding(api_key, base_url)
|
||||
|
||||
2410
backend/logs/2026-03-18/app.log.2026-03-18
Normal file
2410
backend/logs/2026-03-18/app.log.2026-03-18
Normal file
File diff suppressed because it is too large
Load Diff
1
backend/uploads/.gitkeep
Normal file
1
backend/uploads/.gitkeep
Normal file
@@ -0,0 +1 @@
|
||||
# This file ensures the uploads directory is tracked in git
|
||||
3657
backend/uv.lock
generated
Normal file
3657
backend/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
BIN
backend/ygdataset.db
Normal file
BIN
backend/ygdataset.db
Normal file
Binary file not shown.
Reference in New Issue
Block a user