refactor: 前端架构重构 - 提取 CSS 和逻辑到独立模块

前端重构:
- 删除旧的大体积 Vue 组件(HomeView, FileManage, TextSplit 等)
- 删除旧的 composables(useFormatters, useModels, useProjects)
- 新增 core/, page-logic/, pages/, shared/ 模块化目录结构
- 提取 CSS 到 styles/pages/ 目录
- 添加全局样式 variables.css 和 common.css

后端 API 更新:
- chunks: 语义分割 API 增强
- files: 文件处理 API 更新
- models: 模型管理 API 更新
- questions: 问答管理 API 更新
- database: 数据库连接优化
- semantic_embedding: 语义嵌入服务优化

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Developer
2026-03-19 14:23:34 +08:00
parent a280b4f014
commit 6aa271c4f7
75 changed files with 22636 additions and 6519 deletions

View File

@@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.api.response import ApiResponse, PaginatedResponse
from app.core.database import get_db
from app.core.database import get_db, AsyncSessionLocal
from app.core.exceptions import NotFoundException
from app.core.crud import CRUDBase
from app.core.logging import log_success, log_failure
@@ -80,6 +80,106 @@ async def process_file_by_type(file: File) -> str:
return content
async def process_split_async(
project_id: UUID,
request: SplitRequest,
):
"""Run chunk splitting in background."""
async with AsyncSessionLocal() as db:
file = None
try:
result = await db.execute(
select(File).where(File.id == request.file_id, File.project_id == project_id)
)
file = result.scalar_one_or_none()
if not file:
return
text = await process_file_by_type(file)
kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
if request.method == "custom" and request.separator:
kwargs["separator"] = request.separator
if request.method == "semantic_embedding":
kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
kwargs["embedding_api_key"] = request.embedding_api_key
kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
kwargs["similarity_threshold"] = request.similarity_threshold
kwargs["min_chunk_size"] = request.min_chunk_size
splitter = get_splitter(request.method, **kwargs)
split_results = splitter.split(text)
await db.execute(
Chunk.__table__.delete().where(
Chunk.project_id == project_id,
Chunk.file_id == file.id
)
)
chunks = []
for chunk_data in split_results:
db_chunk = Chunk(
project_id=project_id,
file_id=file.id,
name=chunk_data.get("name", f"Chunk {chunk_data['index'] + 1}"),
content=chunk_data["content"],
word_count=chunk_data.get("word_count", len(chunk_data["content"].split()))
)
db.add(db_chunk)
chunks.append(db_chunk)
await db.commit()
ready_dir = get_project_ready_dir(str(project_id))
# 删除旧的 markdown 文件(可能有两种命名格式)
old_md_files = list(ready_dir.glob(f"{file.id}*.md"))
for old_file in old_md_files:
try:
old_file.unlink()
except Exception:
pass
md_filename = f"{file.id}.md"
md_path = ready_dir / md_filename
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
lambda: md_path.write_text(text, encoding='utf-8')
)
file.file_path = str(md_path)
file.status = "completed"
await db.commit()
log_success(
"文件分割完成",
project_id=str(project_id),
file_id=str(file.id),
filename=file.filename,
method=request.method,
chunk_count=len(chunks),
text_length=len(text),
ready_path=str(md_path)
)
except Exception as e:
if file:
file.status = "failed"
await db.commit()
log_failure(
"文件分割失败",
project_id=str(project_id),
file_id=str(request.file_id),
method=request.method,
error=str(e)
)
@router.post("/split", response_model=ApiResponse)
async def split_text(
project_id: UUID,
@@ -88,7 +188,6 @@ async def split_text(
):
"""Split text into chunks"""
try:
# Get file
result = await db.execute(
select(File).where(File.id == request.file_id, File.project_id == project_id)
)
@@ -107,81 +206,27 @@ async def split_text(
overlap=request.overlap
)
# Process file
text = await process_file_by_type(file)
# Update file status
file.status = "processing"
await db.commit()
# Split text
kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
if request.method == "custom" and request.separator:
kwargs["separator"] = request.separator
# 如果使用 semantic_embedding 方法,传递 embedding 参数
if request.method == "semantic_embedding":
kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
kwargs["embedding_api_key"] = request.embedding_api_key
kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
kwargs["similarity_threshold"] = request.similarity_threshold
kwargs["min_chunk_size"] = request.min_chunk_size
splitter = get_splitter(request.method, **kwargs)
split_results = splitter.split(text)
# Save chunks
chunks = []
for chunk_data in split_results:
db_chunk = Chunk(
asyncio.create_task(
process_split_async(
project_id=project_id,
file_id=file.id,
name=chunk_data.get("name", f"Chunk {chunk_data['index'] + 1}"),
content=chunk_data["content"],
word_count=chunk_data.get("word_count", len(chunk_data["content"].split()))
request=request,
)
db.add(db_chunk)
chunks.append(db_chunk)
await db.commit()
# Save processed markdown to ready directory
ready_dir = get_project_ready_dir(str(project_id))
md_filename = f"{file.id}_{file.filename}.md"
md_path = ready_dir / md_filename
# Write markdown content to file
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
lambda: md_path.write_text(text, encoding='utf-8')
)
# Update file path to ready location
file.file_path = str(md_path)
file.status = "completed"
await db.commit()
# 记录成功日志
log_success(
"文件处理完成",
project_id=str(project_id),
file_id=str(file.id),
filename=file.filename,
chunk_count=len(chunks),
text_length=len(text),
ready_path=str(md_path)
)
return ApiResponse.ok(
data={"chunks": len(chunks)},
message=f"Successfully split into {len(chunks)} chunks"
data={"file_id": str(file.id), "status": file.status},
message="Split task started, processing in background"
)
except Exception as e:
# 记录失败日志
if 'file' in locals() and file:
file.status = "failed"
await db.commit()
log_failure(
"文件处理失败",
"分割任务启动失败",
project_id=str(project_id),
file_id=str(request.file_id),
error=str(e)

View File

@@ -9,6 +9,7 @@ from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, UploadFile, File, Query
from fastapi.responses import FileResponse, PlainTextResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.api.response import ApiResponse, PaginatedResponse
from app.core.config import get_settings
@@ -17,6 +18,7 @@ from app.core.exceptions import ValidationException, NotFoundException
from app.core.crud import CRUDBase
from app.core.logging import log_success, log_failure
from app.models.models import File as FileModel
from app.models.models import Chunk, Question
from app.schemas.file import FileResponse, FileCreateSchema
from markitdown import MarkItDown
@@ -329,11 +331,27 @@ async def delete_file(
file_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""Delete file"""
"""Delete file and all related data (markdown, chunks, questions)"""
file = await file_crud.get(db, file_id)
if not file or file.project_id != project_id:
raise NotFoundException("File", file_id)
# Delete related chunks and their questions (explicit deletion for safety)
chunks_result = await db.execute(
select(Chunk).where(Chunk.file_id == file_id)
)
chunks = chunks_result.scalars().all()
for chunk in chunks:
# Delete questions related to this chunk
questions_result = await db.execute(
select(Question).where(Question.chunk_id == chunk.id)
)
questions = questions_result.scalars().all()
for question in questions:
await db.delete(question)
# Delete chunk
await db.delete(chunk)
# Delete file from raw directory
if file.file_path and os.path.exists(file.file_path):
await asyncio.get_event_loop().run_in_executor(
@@ -342,16 +360,27 @@ async def delete_file(
file.file_path
)
# Delete file from ready directory (processed markdown)
ready_path = Path("/data/code/YG-Datasets/data") / str(project_id) / "ready" / f"{file_id}.md"
if ready_path.exists():
await asyncio.get_event_loop().run_in_executor(
None,
os.remove,
str(ready_path)
)
# Delete file from ready directory (processed markdown) - try both naming conventions
ready_dir = Path("/data/code/YG-Datasets/data") / str(project_id) / "ready"
if ready_dir.exists():
# Try file_id.md (from upload process)
ready_path = ready_dir / f"{file_id}.md"
if ready_path.exists():
await asyncio.get_event_loop().run_in_executor(
None,
os.remove,
str(ready_path)
)
# Try file_id_filename.md (from split process)
for md_file in ready_dir.glob(f"{file_id}_*.md"):
await asyncio.get_event_loop().run_in_executor(
None,
os.remove,
str(md_file)
)
await file_crud.delete(db, file_id)
await db.commit()
return ApiResponse.ok(message="File deleted successfully")

View File

@@ -14,6 +14,38 @@ from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse
router = APIRouter()
VALID_MODEL_TYPES = {"chat", "vlm", "embedding", "rerank"}
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
"""Normalize model type, with keyword fallback for legacy records."""
if model_type in VALID_MODEL_TYPES and model_type != "chat":
return model_type
normalized_name = (model_name or "").strip().lower()
rerank_keywords = ("rerank", "bce-reranker", "gte-rerank")
embedding_keywords = (
"embedding",
"embed",
"text-embedding",
"bge-",
"bge_m3",
"gte-",
"m3e",
"e5-",
"jina-embeddings",
)
vlm_keywords = ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")
if any(keyword in normalized_name for keyword in rerank_keywords):
return "rerank"
if any(keyword in normalized_name for keyword in embedding_keywords):
return "embedding"
if any(keyword in normalized_name for keyword in vlm_keywords):
return "vlm"
return model_type if model_type in VALID_MODEL_TYPES else "chat"
async def test_model_connection(model: ModelConfig) -> dict:
"""Test model connection by calling the API"""
@@ -23,6 +55,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
api_base = model.api_base or ""
provider = model.provider
model_name = model.model_name
model_type = normalize_model_type(model.model_type, model_name)
api_key = model.api_key
headers = {
@@ -32,7 +65,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
try:
async with httpx.AsyncClient(timeout=10.0) as client:
if provider == "openai":
if model_type in {"chat", "vlm"} and provider in {"openai", "ali"}:
# OpenAI compatible API test
response = await client.post(
f"{api_base.rstrip('/')}/chat/completions",
@@ -43,7 +76,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
"max_tokens": 5
}
)
elif provider == "minimax":
elif model_type in {"chat", "vlm"} and provider == "minimax":
# MiniMax API test
response = await client.post(
f"{api_base.rstrip('/')}/chat/completions_v2",
@@ -56,7 +89,7 @@ async def test_model_connection(model: ModelConfig) -> dict:
"messages": [{"role": "user", "content": "Hi"}]
}
)
elif provider == "glm":
elif model_type in {"chat", "vlm"} and provider == "glm":
# GLM API test
response = await client.post(
f"{api_base.rstrip('/')}/chat/completions",
@@ -66,8 +99,21 @@ async def test_model_connection(model: ModelConfig) -> dict:
"messages": [{"role": "user", "content": "Hi"}]
}
)
elif model_type == "embedding" and provider in {"openai", "ali", "glm"}:
response = await client.post(
f"{api_base.rstrip('/')}/embeddings",
headers=headers,
json={
"model": model_name,
"input": "test"
}
)
elif model_type == "embedding" and provider == "minimax":
return {"success": False, "message": "MiniMax embedding 自动测试暂未接入,请手动确认端点与模型"}
elif model_type == "rerank":
return {"success": False, "message": "Rerank 自动测试暂未接入,请先保存配置并在实际流程中验证"}
else:
return {"success": False, "message": f"Unsupported provider: {provider}"}
return {"success": False, "message": f"Unsupported provider/type: {provider}/{model_type}"}
if response.status_code == 200:
return {"success": True, "message": "Connection successful"}
@@ -114,6 +160,7 @@ async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)):
db_model = ModelConfig(
provider=model.provider,
model_type=model.model_type,
model_name=model.model_name,
api_key=model.api_key,
api_base=model.api_base,
@@ -248,6 +295,7 @@ async def test_model(model_id: str, db: AsyncSession = Depends(get_db)):
test_result = await test_model_connection(model)
# Save connection status to database
model.model_type = normalize_model_type(model.model_type, model.model_name)
model.connection_status = "connected" if test_result["success"] else "disconnected"
await db.commit()
await db.refresh(model)

View File

@@ -1,31 +1,303 @@
"""
Questions API Router
"""
import asyncio
import json
import re
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
import httpx
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.response import ApiResponse, PaginatedResponse
from app.core.database import get_db
from app.core.exceptions import NotFoundException, ValidationException
from app.core.crud import CRUDBase
from app.models.models import Question, Chunk
from app.schemas.question import QuestionResponse
from app.schemas.question import QuestionCreateSchema
from app.core.database import AsyncSessionLocal, get_db
from app.core.exceptions import NotFoundException, ValidationException
from app.core.logging import log_failure, log_success
from app.models.models import Chunk, ModelConfig, Question
from app.schemas.question import QuestionCreateSchema, QuestionResponse
router = APIRouter()
# Initialize CRUD
question_crud = CRUDBase(Question)
VALID_MODEL_TYPES = {"chat", "vlm", "embedding", "rerank"}
DEFAULT_PRESET_PROMPT = (
"你是一名高质量中文问答数据构建助手。"
"请基于给定 chunk 内容生成准确、自然、可用于训练的数据集问答对。"
"问题必须清晰具体,答案必须直接来自内容或基于内容做合理概括,"
"不要编造原文没有的信息,不要输出与目录、导航、页眉页脚、噪声文字相关的问题。"
)
class GenerateRequest(BaseModel):
"""Request model for generating questions"""
chunk_ids: List[UUID] = Field(..., min_length=1)
count: int = Field(5, ge=1, le=50)
question_types: List[str] = ["fact", "summary"]
model_id: UUID
count: int = Field(3, ge=1, le=10)
dirty_data_filter: bool = True
thinking_mode: bool = True
preset_prompt: str = Field(default=DEFAULT_PRESET_PROMPT, min_length=1, max_length=4000)
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
"""Normalize model type, with keyword fallback for legacy records."""
if model_type in VALID_MODEL_TYPES and model_type != "chat":
return model_type
normalized_name = (model_name or "").strip().lower()
rerank_keywords = ("rerank", "bce-reranker", "gte-rerank")
embedding_keywords = (
"embedding",
"embed",
"text-embedding",
"bge-",
"bge_m3",
"gte-",
"m3e",
"e5-",
"jina-embeddings",
)
vlm_keywords = ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")
if any(keyword in normalized_name for keyword in rerank_keywords):
return "rerank"
if any(keyword in normalized_name for keyword in embedding_keywords):
return "embedding"
if any(keyword in normalized_name for keyword in vlm_keywords):
return "vlm"
return model_type if model_type in VALID_MODEL_TYPES else "chat"
def is_dirty_chunk(content: str) -> bool:
"""Heuristic dirty-data filter for low-value chunks."""
normalized = re.sub(r"\s+", " ", (content or "")).strip()
if len(normalized) < 40:
return True
if len(re.sub(r"[^\u4e00-\u9fffA-Za-z0-9]", "", normalized)) < 24:
return True
lowered = normalized.lower()
if lowered in {"目录", "contents", "table of contents"}:
return True
lines = [line.strip() for line in (content or "").splitlines() if line.strip()]
if lines:
short_lines = sum(1 for line in lines if len(line) <= 18)
dotted_lines = sum(1 for line in lines if re.search(r"[·•…\.]{3,}|\s\d+$", line))
if short_lines / len(lines) > 0.7 and len(lines) >= 3:
return True
if dotted_lines / len(lines) > 0.4:
return True
punctuation_ratio = sum(1 for ch in normalized if not ch.isalnum() and not ("\u4e00" <= ch <= "\u9fff")) / max(len(normalized), 1)
if punctuation_ratio > 0.45:
return True
return False
def build_generation_prompt(chunk: Chunk, request: GenerateRequest) -> str:
"""Build user prompt for QA generation."""
thinking_instruction = (
"请先对内容做简短分析,识别核心事实、概念、关系与潜在考点,然后再生成问答。"
"分析过程只用于提高质量,不要在最终输出中暴露你的思维链。"
if request.thinking_mode
else "直接基于内容生成高质量问答。"
)
return (
f"{request.preset_prompt}\n\n"
"输出要求:\n"
f"1. 生成 {request.count} 组问答。\n"
"2. 只输出 JSON 数组不要输出解释、标题、Markdown。\n"
'3. 每个对象结构为 {"question":"...","answer":"...","question_type":"fact|summary|reasoning"}。\n'
"4. 问题避免重复,答案避免空泛。\n"
"5. 如果内容不足以生成高质量问答,请返回空数组 []。\n"
f"6. {thinking_instruction}\n\n"
f"Chunk 名称:{chunk.name or '未命名分片'}\n"
f"Chunk 内容:\n{chunk.content}"
)
def extract_text_from_response(data: dict) -> str:
"""Extract response text from provider response."""
choices = data.get("choices") or []
if choices:
message = choices[0].get("message") or {}
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if isinstance(item, dict)]
return "\n".join(part for part in parts if part)
return ""
def parse_generated_questions(raw_text: str) -> List[dict]:
"""Parse JSON array from model output."""
text = (raw_text or "").strip()
if not text:
return []
fenced_match = re.search(r"```json\s*(.*?)\s*```", text, flags=re.S)
if fenced_match:
text = fenced_match.group(1).strip()
if not text.startswith("["):
array_match = re.search(r"(\[\s*\{.*\}\s*\])", text, flags=re.S)
if array_match:
text = array_match.group(1)
try:
parsed = json.loads(text)
except json.JSONDecodeError:
return []
if not isinstance(parsed, list):
return []
normalized = []
for item in parsed:
if not isinstance(item, dict):
continue
question = str(item.get("question", "")).strip()
answer = str(item.get("answer", "")).strip()
question_type = str(item.get("question_type", "fact")).strip() or "fact"
if not question or not answer:
continue
normalized.append({
"question": question,
"answer": answer,
"question_type": question_type
})
return normalized
async def call_generation_model(model: ModelConfig, prompt: str) -> str:
"""Call configured chat model for question generation."""
provider = model.provider
api_base = (model.api_base or "").rstrip("/")
api_key = model.api_key or ""
model_name = model.model_name
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
payload = {
"model": model_name,
"messages": [
{
"role": "system",
"content": "你是问答数据构建助手。严格按 JSON 输出,不要输出额外说明。"
},
{
"role": "user",
"content": prompt
}
],
"temperature": 0.4,
"response_format": {"type": "json_object"}
}
async with httpx.AsyncClient(timeout=120.0) as client:
if provider == "minimax":
response = await client.post(
f"{api_base}/chat/completions_v2",
headers=headers,
json={k: v for k, v in payload.items() if k != "response_format"}
)
else:
response = await client.post(
f"{api_base}/chat/completions",
headers=headers,
json=payload
)
response.raise_for_status()
data = response.json()
content = extract_text_from_response(data)
if not content:
raise ValueError("Model returned empty content")
if content.lstrip().startswith("{"):
obj = json.loads(content)
if isinstance(obj, dict) and isinstance(obj.get("questions"), list):
return json.dumps(obj["questions"], ensure_ascii=False)
return content
async def process_generate_async(project_id: UUID, request: GenerateRequest):
"""Generate QA pairs in background."""
async with AsyncSessionLocal() as db:
try:
model_result = await db.execute(
select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None) # noqa: E711
)
model = model_result.scalar_one_or_none()
if not model:
return
model_type = normalize_model_type(model.model_type, model.model_name)
if model_type not in {"chat", "vlm"}:
raise ValidationException("Selected model must be chat/vlm type", field="model_id")
chunk_result = await db.execute(
select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
)
chunks = chunk_result.scalars().all()
if not chunks:
return
created_count = 0
skipped_count = 0
for chunk in chunks:
if request.dirty_data_filter and is_dirty_chunk(chunk.content):
skipped_count += 1
continue
prompt = build_generation_prompt(chunk, request)
raw_text = await call_generation_model(model, prompt)
qa_pairs = parse_generated_questions(raw_text)[:request.count]
if not qa_pairs:
skipped_count += 1
continue
for item in qa_pairs:
db.add(Question(
project_id=project_id,
chunk_id=chunk.id,
content=item["question"],
answer=item["answer"],
question_type=item["question_type"],
source="generated"
))
created_count += 1
await db.commit()
log_success(
"问答批量生成完成",
project_id=str(project_id),
model_id=str(model.id),
chunk_count=len(chunks),
created_questions=created_count,
skipped_chunks=skipped_count
)
except Exception as e:
log_failure(
"问答批量生成失败",
project_id=str(project_id),
model_id=str(request.model_id),
error=str(e)
)
@router.post("/generate", response_model=ApiResponse)
@@ -34,36 +306,33 @@ async def generate_questions(
request: GenerateRequest,
db: AsyncSession = Depends(get_db)
):
"""Generate questions from chunks using LLM"""
# Get chunks
result = await db.execute(
select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
"""Generate questions from chunks using LLM in background."""
model_result = await db.execute(
select(ModelConfig).where(ModelConfig.id == request.model_id, ModelConfig.project_id == None) # noqa: E711
)
chunks = result.scalars().all()
model = model_result.scalar_one_or_none()
if not model:
raise ValidationException("Selected model not found", field="model_id")
if not chunks:
model_type = normalize_model_type(model.model_type, model.model_name)
if model_type not in {"chat", "vlm"}:
raise ValidationException("Selected model must be chat/vlm type", field="model_id")
if not model.api_key:
raise ValidationException("Selected model is missing API Key", field="model_id")
chunk_result = await db.execute(
select(Chunk.id).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
)
valid_chunk_ids = [row[0] for row in chunk_result.all()]
if not valid_chunk_ids:
raise ValidationException("No valid chunks found", field="chunk_ids")
# Create sample questions (placeholder for LLM-based generation)
created_questions = []
for chunk in chunks:
for i in range(request.count):
question = Question(
project_id=project_id,
chunk_id=chunk.id,
content=f"这是关于「{chunk.name}」的问题 {i+1}",
answer=f"这是问题 {i+1} 的答案。",
question_type=request.question_types[0] if request.question_types else "fact",
source="generated"
)
db.add(question)
created_questions.append(question)
await db.commit()
request_payload = request.model_copy(update={"chunk_ids": valid_chunk_ids})
asyncio.create_task(process_generate_async(project_id, request_payload))
return ApiResponse.ok(
data={"questions": len(created_questions)},
message=f"Successfully generated {len(created_questions)} questions"
data={"chunk_count": len(valid_chunk_ids), "status": "processing"},
message="Question generation started in background"
)

View File

@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import create_engine, event
from sqlalchemy import create_engine, event, inspect, text
from sqlalchemy.pool import NullPool
from app.core.config import get_settings
@@ -65,9 +65,28 @@ async def init_db():
logger.info("Initializing database...")
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await conn.run_sync(_ensure_legacy_columns)
logger.info("Database initialized successfully")
def _ensure_legacy_columns(sync_conn):
"""Patch legacy tables with newly introduced columns."""
inspector = inspect(sync_conn)
if "model_configs" not in inspector.get_table_names():
return
columns = {column["name"] for column in inspector.get_columns("model_configs")}
if "model_type" in columns:
return
logger.info("Adding missing model_type column to model_configs table")
dialect = sync_conn.dialect.name
if dialect == "postgresql":
sync_conn.execute(text("ALTER TABLE model_configs ADD COLUMN model_type VARCHAR(50) NOT NULL DEFAULT 'chat'"))
else:
sync_conn.execute(text("ALTER TABLE model_configs ADD COLUMN model_type VARCHAR(50) NOT NULL DEFAULT 'chat'"))
async def close_db():
"""Close database connections"""
logger.info("Closing database connections...")

View File

@@ -137,7 +137,8 @@ class ModelConfig(Base, UUIDMixin, TimestampMixin):
__tablename__ = "model_configs"
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=True)
provider = Column(String(50), nullable=False) # minimax, glm, openai
provider = Column(String(50), nullable=False) # minimax, glm, openai, ali
model_type = Column(String(50), nullable=False, default="chat") # chat, vlm, embedding, rerank
model_name = Column(String(100))
api_key = Column(String(500))
api_base = Column(String(500))

View File

@@ -9,7 +9,8 @@ from uuid import UUID
class ModelBase(BaseModel):
"""Base model schema"""
provider: str = Field(..., description="Model provider: minimax, glm, openai")
provider: str = Field(..., description="Model provider: minimax, glm, openai, ali")
model_type: str = Field(default="chat", description="Model type: chat, vlm, embedding, rerank")
model_name: str = Field(..., description="Model name")
api_key: Optional[str] = Field(None, description="API key")
api_base: Optional[str] = Field(None, description="API base URL")
@@ -24,6 +25,7 @@ class ModelCreate(ModelBase):
class ModelUpdate(BaseModel):
"""Model update schema"""
provider: Optional[str] = None
model_type: Optional[str] = None
model_name: Optional[str] = None
api_key: Optional[str] = None
api_base: Optional[str] = None

View File

@@ -8,6 +8,7 @@ import httpx
import numpy as np
from typing import List, Dict, Optional
from abc import ABC, abstractmethod
from langchain_text_splitters import RecursiveCharacterTextSplitter
class EmbeddingProvider(ABC):
@@ -109,32 +110,28 @@ class EmbeddingSplitter:
def _tokenize_sentences(self, text: str) -> List[str]:
"""将文本切分为句子"""
# 中英文句末符号
# 先按换行分割,保持段落结构
paragraphs = re.split(r'\n+', text)
paragraphs = re.split(r'\n\s*\n+', text)
sentences = []
for para in paragraphs:
if not para.strip():
para = para.strip()
if not para:
continue
# 按句子符号分割
# 中文:。!?;
# 英文:. ! ? ;
parts = re.split(r'([。!?;\n]|(?<=[.!?])\s+)', para)
parts = re.split(r'(?<=[。!?;.!?])\s+|(?<=[。!?;])', para)
buffer = []
# 重新组合句子
current_sentence = ""
for part in parts:
if part in '。!?;.\n':
if current_sentence.strip():
sentences.append(current_sentence.strip())
current_sentence = ""
elif part and part.strip():
current_sentence += part
# 处理最后一个句子
if current_sentence.strip():
sentences.append(current_sentence.strip())
part = part.strip()
if not part:
continue
# 过短的片段先暂存,尽量与后一句合并,避免 embedding 粒度过碎
if len(part) < 8 and buffer:
buffer[-1] = f"{buffer[-1]} {part}".strip()
else:
buffer.append(part)
sentences.extend(buffer)
return sentences
@@ -162,51 +159,48 @@ class EmbeddingSplitter:
if not similarities:
return []
window = self.window_size
window = max(1, self.window_size)
smoothed = []
for i in range(len(similarities)):
start = max(0, i - window + 1)
end = i + 1
start = max(0, i - window)
end = min(len(similarities), i + window + 1)
window_vals = similarities[start:end]
smoothed.append(sum(window_vals) / len(window_vals))
return smoothed
def _detect_boundaries(self, similarities: List[float]) -> List[int]:
def _detect_boundaries(self, similarities: List[float], sentence_lengths: List[int]) -> List[int]:
"""检测分割点(相似度显著下降的位置)"""
if not similarities:
return [0]
# 平滑
smoothed = self._smooth_similarities(similarities)
# 计算深度分数(类似 TextTiling
depth_scores = []
for i in range(1, len(smoothed) - 1):
# 当前位置的深度 = 当前位置的值 - 平均值
# 但更准确的是:左侧平均 - 右侧平均
left_avg = sum(smoothed[max(0, i - self.window_size):i]) / self.window_size
right_avg = sum(smoothed[i:min(len(smoothed), i + self.window_size)]) / self.window_size
depth = left_avg - right_avg
depth_scores.append(depth)
# 如果没有足够的点,直接返回
if not depth_scores:
if len(smoothed) <= 1:
return [0]
# 阈值判断
mean_depth = np.mean(depth_scores)
std_depth = np.std(depth_scores)
# 找分割点depth 显著高于均值的位置
threshold = mean_depth + 0.5 * std_depth
mean_sim = float(np.mean(smoothed))
std_sim = float(np.std(smoothed))
dynamic_threshold = max(0.0, min(0.95, mean_sim - 0.5 * std_sim))
effective_threshold = max(self.similarity_threshold, dynamic_threshold)
boundaries = [0] # 起始点
for i, depth in enumerate(depth_scores):
if depth > threshold and depth > self.similarity_threshold:
boundaries.append(i + 1) # 对应相似度的下一个位置
boundaries.append(len(self._tokenize_sentences.__name__)) # 结束点
accumulated_chars = 0
for i, sim in enumerate(smoothed):
accumulated_chars += sentence_lengths[i]
left_sim = smoothed[i - 1] if i > 0 else 1.0
right_sim = smoothed[i + 1] if i < len(smoothed) - 1 else 1.0
is_local_min = sim <= left_sim and sim <= right_sim
has_enough_context = accumulated_chars >= self.min_chunk_size
oversize_guard = accumulated_chars >= self.chunk_size
if (is_local_min and has_enough_context and sim <= effective_threshold) or oversize_guard:
boundaries.append(i + 1)
accumulated_chars = 0
boundaries.append(len(sentence_lengths))
return sorted(list(set(boundaries)))
@@ -225,7 +219,12 @@ class EmbeddingSplitter:
for i in range(len(boundaries) - 1):
start = boundaries[i]
end = boundaries[i + 1]
chunk_text = ' '.join(sentences[start:end])
if start >= end:
continue
chunk_text = ' '.join(sentences[start:end]).strip()
if not chunk_text:
continue
# 如果 chunk 过大,递归分割
if len(chunk_text) > self.chunk_size * 1.5:
@@ -278,14 +277,22 @@ class EmbeddingSplitter:
merged = [chunks[0]]
for chunk in chunks[1:]:
# 如果前一个 chunk 太小,合并
if merged[-1]["char_count"] < self.min_chunk_size:
merged[-1]["content"] += " " + chunk["content"]
merged[-1]["word_count"] += chunk["word_count"]
merged[-1]["char_count"] += chunk["char_count"]
previous = merged[-1]
should_merge = (
previous["char_count"] < self.min_chunk_size or
chunk["char_count"] < self.min_chunk_size
)
if should_merge and previous["char_count"] + chunk["char_count"] <= self.chunk_size * 1.5:
previous["content"] += " " + chunk["content"]
previous["word_count"] += chunk["word_count"]
previous["char_count"] += chunk["char_count"]
else:
merged.append(chunk)
for index, chunk in enumerate(merged):
chunk["index"] = index
return merged
async def split_with_embedding(self, text: str) -> List[Dict]:
@@ -295,8 +302,8 @@ class EmbeddingSplitter:
if not sentences:
return []
# 过滤过短的句子
sentences = [s for s in sentences if len(s) >= 10]
# 过滤纯噪音片段,但保留正常短句
sentences = [s for s in sentences if len(s.strip()) >= 4]
if not sentences:
return []
@@ -312,17 +319,22 @@ class EmbeddingSplitter:
# 3. 调用 Embedding API
try:
if self.embedding_provider is None:
raise ValueError("embedding provider is not configured")
embeddings = await self.embedding_provider.get_embeddings(sentences)
except Exception as e:
# 如果 embedding 失败,降级到规则分割
print(f"Embedding failed, falling back to rule-based: {e}")
return self._fallback_split(text)
if len(embeddings) != len(sentences):
return self._fallback_split(text)
# 4. 计算相似度
similarities = self._compute_similarities(embeddings)
# 5. 检测分割点
boundaries = self._detect_boundaries(similarities)
boundaries = self._detect_boundaries(similarities, [len(sentence) for sentence in sentences])
# 6. 组装 chunks
chunks = self._assemble_chunks(sentences, boundaries)
@@ -387,7 +399,7 @@ class SemanticEmbeddingSplitter(EmbeddingSplitter):
def create_embedding_provider(provider: str, api_key: str, base_url: str, model: str = None) -> EmbeddingProvider:
"""创建 Embedding 提供商"""
if provider in ["openai", "compatible"]:
if provider in ["openai", "compatible", "ali", "glm"]:
return OpenAIEmbedding(api_key, base_url, model or "text-embedding-3-small")
elif provider == "minimax":
return MiniMaxEmbedding(api_key, base_url)

File diff suppressed because it is too large Load Diff

1
backend/uploads/.gitkeep Normal file
View File

@@ -0,0 +1 @@
# This file ensures the uploads directory is tracked in git

3657
backend/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff

BIN
backend/ygdataset.db Normal file

Binary file not shown.