2026-03-17 14:36:31 +08:00
|
|
|
"""
|
|
|
|
|
Chunks API Router
|
|
|
|
|
"""
|
2026-03-17 17:29:58 +08:00
|
|
|
import asyncio
|
2026-03-18 10:44:09 +08:00
|
|
|
from pathlib import Path
|
2026-03-17 14:36:31 +08:00
|
|
|
from typing import List, Optional
|
|
|
|
|
from uuid import UUID
|
2026-03-17 17:29:58 +08:00
|
|
|
from pydantic import BaseModel, Field
|
2026-03-17 14:36:31 +08:00
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from sqlalchemy import select
|
2026-03-17 17:29:58 +08:00
|
|
|
|
|
|
|
|
from app.api.response import ApiResponse, PaginatedResponse
|
2026-03-19 14:23:34 +08:00
|
|
|
from app.core.database import get_db, AsyncSessionLocal
|
2026-03-17 17:29:58 +08:00
|
|
|
from app.core.exceptions import NotFoundException
|
|
|
|
|
from app.core.crud import CRUDBase
|
2026-03-18 10:44:09 +08:00
|
|
|
from app.core.logging import log_success, log_failure
|
2026-03-17 14:36:31 +08:00
|
|
|
from app.models.models import Chunk, File
|
2026-03-17 17:29:58 +08:00
|
|
|
from app.schemas.chunk import ChunkResponse
|
2026-03-19 10:11:59 +08:00
|
|
|
from app.schemas.chunk import ChunkCreateSchema, ChunkUpdateSchema
|
2026-03-17 14:36:31 +08:00
|
|
|
from app.services.text_splitter.splitter import get_splitter
|
2026-03-18 10:44:09 +08:00
|
|
|
from markitdown import MarkItDown
|
2026-03-17 14:36:31 +08:00
|
|
|
|
|
|
|
|
# Router that collects every chunk endpoint declared in this module.
router: APIRouter = APIRouter()

# Generic CRUD helper bound to the Chunk ORM model; used by the list/get/update/delete endpoints.
chunk_crud: CRUDBase = CRUDBase(Chunk)

# One shared MarkItDown converter, reused by process_file_by_type for document formats.
markitdown: MarkItDown = MarkItDown()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_project_ready_dir(project_id: str, base_root: Optional[Path] = None) -> Path:
    """Return the project's ``ready`` directory, creating it if necessary.

    Args:
        project_id: Project identifier, used as a directory name under the root.
        base_root: Root data directory. Defaults to ``/data/code/YG-Datasets/data``
            (the previously hard-coded location), so existing callers are unaffected.

    Returns:
        ``<base_root>/<project_id>/ready`` as a ``Path``; the directory exists on return.
    """
    root = Path(base_root) if base_root is not None else Path("/data/code/YG-Datasets/data")
    ready_dir = root / project_id / "ready"
    # parents/exist_ok make this safe to call repeatedly and on a fresh project.
    ready_dir.mkdir(parents=True, exist_ok=True)
    return ready_dir
|
|
|
|
|
|
2026-03-17 14:36:31 +08:00
|
|
|
|
|
|
|
|
class SplitRequest(BaseModel):
    """Body of the split endpoint: the file to split plus splitter settings."""

    # Target file; must belong to the project in the URL.
    file_id: UUID
    # Splitter name handed to get_splitter (e.g. "recursive", "custom", "semantic_embedding").
    method: str = Field(default="recursive")
    chunk_size: int = Field(default=500, ge=50, le=5000)
    overlap: int = Field(default=50, ge=0, le=500)
    # Only forwarded to the splitter when method == "custom".
    separator: Optional[str] = Field(default=None)

    # --- Embedding settings: only read when method == "semantic_embedding" ---
    embedding_provider: Optional[str] = Field(default=None, description="embedding provider: openai, minimax")
    embedding_api_key: Optional[str] = Field(default=None, description="API key for embedding")
    embedding_base_url: Optional[str] = Field(default=None, description="API base URL")
    embedding_model: Optional[str] = Field(default=None, description="Embedding model name")

    # --- Semantic split tuning ---
    similarity_threshold: float = Field(default=0.3, ge=0.0, le=1.0, description="Similarity threshold for semantic split")
    min_chunk_size: int = Field(default=100, ge=10, le=1000, description="Minimum chunk size")
|
2026-03-17 14:36:31 +08:00
|
|
|
|
|
|
|
|
|
2026-03-17 17:29:58 +08:00
|
|
|
# File types converted to markdown via MarkItDown; every other type is read as UTF-8 text.
_MARKITDOWN_TYPES = frozenset({"pdf", "docx", "doc", "pptx", "ppt", "xlsx", "xls", "htm", "html"})


async def process_file_by_type(file: File) -> str:
    """Return the text content of *file*, converted to markdown where applicable.

    Types listed in ``_MARKITDOWN_TYPES`` are converted with MarkItDown; anything
    else (e.g. txt, md) is read as UTF-8 text. Blocking work runs in the default
    executor so the event loop is not stalled.

    Raises:
        NotFoundException: if the record has no ``file_path``.
    """
    if not file.file_path:
        raise NotFoundException("File", file.id)

    loop = asyncio.get_running_loop()

    # Normalise variants such as "PDF" or ".pdf"; otherwise a binary document
    # would fall through to the plain-text branch and fail to decode.
    file_type = (file.file_type or "").lower().lstrip(".")
    if file_type in _MARKITDOWN_TYPES:
        result = await loop.run_in_executor(None, markitdown.convert, file.file_path)
        return result.text_content

    # Path.read_text opens and closes the file; a bare open().read() leaks the handle.
    return await loop.run_in_executor(
        None, lambda: Path(file.file_path).read_text(encoding="utf-8")
    )
|
2026-03-17 14:36:31 +08:00
|
|
|
|
|
|
|
|
|
2026-03-19 14:23:34 +08:00
|
|
|
def _build_splitter_kwargs(request: SplitRequest) -> dict:
    """Translate a SplitRequest into keyword arguments for ``get_splitter``."""
    kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
    # A custom separator is only forwarded for the "custom" method.
    if request.method == "custom" and request.separator:
        kwargs["separator"] = request.separator
    if request.method == "semantic_embedding":
        # NOTE(review): the provider/model defaults are OpenAI's, but the base-URL
        # default points at MiniMax; confirm this pairing is intended.
        kwargs.update(
            embedding_provider_type=request.embedding_provider or "openai",
            embedding_api_key=request.embedding_api_key,
            embedding_base_url=request.embedding_base_url or "https://api.minimax.chat/v1",
            embedding_model=request.embedding_model or "text-embedding-3-small",
            similarity_threshold=request.similarity_threshold,
            min_chunk_size=request.min_chunk_size,
        )
    return kwargs


def _build_chunk_rows(project_id: UUID, file_id, split_results) -> list:
    """Convert splitter output (dicts with at least "content") into unsaved Chunk rows."""
    rows = []
    for position, chunk_data in enumerate(split_results):
        content = chunk_data["content"]
        if "name" in chunk_data:
            name = chunk_data["name"]
        else:
            # Build the fallback lazily: an eager .get("name", f"...{chunk_data['index']}")
            # raises KeyError when "index" is absent even if "name" is present.
            name = f"Chunk {chunk_data.get('index', position) + 1}"
        rows.append(
            Chunk(
                project_id=project_id,
                file_id=file_id,
                name=name,
                content=content,
                word_count=chunk_data.get("word_count", len(content.split())),
            )
        )
    return rows


def _write_ready_markdown(project_id: UUID, file_id, text: str) -> Path:
    """Replace this file's markdown export in the project's ready dir; return the new path.

    Performs blocking filesystem I/O, so call it from a worker thread.
    """
    ready_dir = get_project_ready_dir(str(project_id))
    # Old exports may exist under two naming formats; the glob covers both
    # "<file_id>.md" and "<file_id><suffix>.md".
    for old_file in list(ready_dir.glob(f"{file_id}*.md")):
        try:
            old_file.unlink()
        except OSError:
            # Best effort: a stale export we cannot delete must not fail the split.
            pass
    md_path = ready_dir / f"{file_id}.md"
    md_path.write_text(text, encoding="utf-8")
    return md_path


async def process_split_async(
    project_id: UUID,
    request: SplitRequest,
):
    """Split one file into chunks in the background (started by ``split_text``).

    Opens its own database session via AsyncSessionLocal.

    Steps:
      1. Load the file record; return silently if it is not in this project.
      2. Extract its text, split it, and replace the file's chunks in one commit.
      3. Write the text to ``<ready_dir>/<file_id>.md``, point ``file_path`` at it,
         and set ``status="completed"``.

    Any exception marks the file ``status="failed"`` (best effort) and is logged
    via log_failure; nothing is raised to the caller.
    """
    loop = asyncio.get_running_loop()
    async with AsyncSessionLocal() as db:
        file = None
        try:
            result = await db.execute(
                select(File).where(File.id == request.file_id, File.project_id == project_id)
            )
            file = result.scalar_one_or_none()
            if not file:
                return

            # Snapshot attributes before any commit. If the session expires
            # instances on commit, reading them afterwards would need an implicit
            # lazy load, which AsyncSession cannot perform.
            file_id = file.id
            filename = file.filename

            text = await process_file_by_type(file)

            splitter = get_splitter(request.method, **_build_splitter_kwargs(request))
            # Splitting may be CPU-heavy or call a remote embedding API. Run it
            # (and materialise its results) off the event loop.
            split_results = await loop.run_in_executor(None, lambda: list(splitter.split(text)))

            # Replace this file's previous chunks with the new ones in a single commit.
            await db.execute(
                Chunk.__table__.delete().where(
                    Chunk.project_id == project_id,
                    Chunk.file_id == file_id
                )
            )
            chunks = _build_chunk_rows(project_id, file_id, split_results)
            db.add_all(chunks)
            await db.commit()

            md_path = await loop.run_in_executor(
                None, _write_ready_markdown, project_id, file_id, text
            )

            # From here on the record points at the markdown export.
            file.file_path = str(md_path)
            file.status = "completed"
            await db.commit()

            log_success(
                "文件分割完成",
                project_id=str(project_id),
                file_id=str(file_id),
                filename=filename,
                method=request.method,
                chunk_count=len(chunks),
                text_length=len(text),
                ready_path=str(md_path)
            )
        except Exception as e:
            if file is not None:
                try:
                    # After a failed flush/commit the session must be rolled back
                    # before it can be used again.
                    await db.rollback()
                    file.status = "failed"
                    await db.commit()
                except Exception:
                    # Status bookkeeping must not prevent logging the original error below.
                    pass

            log_failure(
                "文件分割失败",
                project_id=str(project_id),
                file_id=str(request.file_id),
                method=request.method,
                error=str(e)
            )
|
|
|
|
|
|
|
|
|
|
|
2026-03-17 17:29:58 +08:00
|
|
|
# Strong references to in-flight split tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task can be garbage-collected before
# it finishes. Each task removes itself from this set when it completes.
_split_tasks: set = set()


@router.post("/split", response_model=ApiResponse)
async def split_text(
    project_id: UUID,
    request: SplitRequest,
    db: AsyncSession = Depends(get_db)
):
    """Mark a file as "processing" and start splitting it in the background.

    Returns immediately with the file id and its new status. The work itself is
    done by process_split_async in a separately scheduled task.

    Raises:
        NotFoundException: if the file does not exist in this project.
        Any other exception is logged, the file (if loaded) is marked "failed",
        and the exception is re-raised.
    """
    file = None
    try:
        result = await db.execute(
            select(File).where(File.id == request.file_id, File.project_id == project_id)
        )
        file = result.scalar_one_or_none()
        if not file:
            raise NotFoundException("File", request.file_id)

        # Record that processing has started.
        log_success(
            "开始处理文件",
            project_id=str(project_id),
            file_id=str(file.id),
            filename=file.filename,
            method=request.method,
            chunk_size=request.chunk_size,
            overlap=request.overlap
        )

        # Read the id before committing, so we never touch a possibly expired instance.
        file_id = file.id
        file.status = "processing"
        await db.commit()

        task = asyncio.create_task(
            process_split_async(
                project_id=project_id,
                request=request,
            )
        )
        _split_tasks.add(task)
        task.add_done_callback(_split_tasks.discard)

        return ApiResponse.ok(
            data={"file_id": str(file_id), "status": "processing"},
            message="Split task started, processing in background"
        )
    except Exception as e:
        if file is not None:
            try:
                # The session may need a rollback after a failed flush/commit.
                await db.rollback()
                file.status = "failed"
                await db.commit()
            except Exception:
                # Do not let status bookkeeping replace the original exception re-raised below.
                pass

        log_failure(
            "分割任务启动失败",
            project_id=str(project_id),
            file_id=str(request.file_id),
            error=str(e)
        )
        raise
|
2026-03-17 14:36:31 +08:00
|
|
|
|
|
|
|
|
|
2026-03-17 17:29:58 +08:00
|
|
|
@router.get("", response_model=ApiResponse)
async def list_chunks(
    project_id: UUID,
    file_id: Optional[UUID] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """Return one page of a project's chunks, ordered by created_at ascending.

    When ``file_id`` is given, only chunks produced from that file are listed.
    """
    criteria = {"project_id": project_id}
    if file_id is not None:
        criteria["file_id"] = file_id

    offset = (page - 1) * page_size
    rows, total = await chunk_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters=criteria,
        order_by="created_at",
        descending=False,
    )

    return PaginatedResponse.ok(
        items=[ChunkResponse.model_validate(row) for row in rows],
        page=page,
        page_size=page_size,
        total=total,
    )
|
2026-03-17 14:36:31 +08:00
|
|
|
|
|
|
|
|
|
2026-03-17 17:29:58 +08:00
|
|
|
@router.get("/{chunk_id}", response_model=ApiResponse)
async def get_chunk(
    project_id: UUID,
    chunk_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Fetch a single chunk by id.

    A chunk that exists but belongs to another project is reported as not found.
    """
    found = await chunk_crud.get(db, chunk_id)
    if not (found and found.project_id == project_id):
        raise NotFoundException("Chunk", chunk_id)
    payload = ChunkResponse.model_validate(found)
    return ApiResponse.ok(data=payload)
|
2026-03-17 14:36:31 +08:00
|
|
|
|
2026-03-17 17:29:58 +08:00
|
|
|
|
|
|
|
|
@router.put("/{chunk_id}", response_model=ApiResponse)
async def update_chunk(
    project_id: UUID,
    chunk_id: UUID,
    chunk: ChunkUpdateSchema,
    db: AsyncSession = Depends(get_db)
):
    """Apply ``chunk`` to an existing chunk of this project via chunk_crud.update.

    Raises NotFoundException when the chunk is missing or belongs to another project.
    """
    existing = await chunk_crud.get(db, chunk_id)
    if not (existing and existing.project_id == project_id):
        raise NotFoundException("Chunk", chunk_id)

    saved = await chunk_crud.update(db, existing, chunk)
    payload = ChunkResponse.model_validate(saved)
    return ApiResponse.ok(data=payload, message="Chunk updated successfully")
|
|
|
|
|
|
|
|
|
|
|
2026-03-17 17:29:58 +08:00
|
|
|
@router.delete("/{chunk_id}", response_model=ApiResponse)
async def delete_chunk(
    project_id: UUID,
    chunk_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Delete a chunk of this project.

    Raises NotFoundException when the chunk is missing or belongs to another project.
    """
    target = await chunk_crud.get(db, chunk_id)
    if not (target and target.project_id == project_id):
        raise NotFoundException("Chunk", chunk_id)

    await chunk_crud.delete(db, chunk_id)
    return ApiResponse.ok(message="Chunk deleted successfully")
|