Files
YG-Datasets/backend/app/api/v1/chunks/__init__.py

315 lines
10 KiB
Python
Raw Normal View History

2026-03-17 14:36:31 +08:00
"""
Chunks API Router
"""
import asyncio
from pathlib import Path
2026-03-17 14:36:31 +08:00
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
2026-03-17 14:36:31 +08:00
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.api.response import ApiResponse, PaginatedResponse
from app.core.database import get_db, AsyncSessionLocal
from app.core.exceptions import NotFoundException
from app.core.crud import CRUDBase
from app.core.logging import log_success, log_failure
2026-03-17 14:36:31 +08:00
from app.models.models import Chunk, File
from app.schemas.chunk import ChunkResponse
from app.schemas.chunk import ChunkCreateSchema, ChunkUpdateSchema
2026-03-17 14:36:31 +08:00
from app.services.text_splitter.splitter import get_splitter
from markitdown import MarkItDown
2026-03-17 14:36:31 +08:00
# Router that collects the chunk endpoints declared in this module.
router = APIRouter()
# Generic CRUD helper bound to the Chunk model.
chunk_crud = CRUDBase(Chunk)
# Shared MarkItDown converter instance, reused for every document conversion.
markitdown = MarkItDown()
def get_project_ready_dir(project_id: str, root: Optional[Path] = None) -> Path:
    """Return the directory that holds a project's "ready" files.

    The directory (and any missing parents) is created when it does not
    exist yet, so callers can write into it immediately.

    Args:
        project_id: Project identifier, used as the sub-directory name.
        root: Base data directory. When omitted, the deployment default
            ``/data/code/YG-Datasets/data`` is used; pass another location
            to relocate storage (for example in tests).

    Returns:
        ``<root>/<project_id>/ready`` as a ``Path``.
    """
    data_root = Path("/data/code/YG-Datasets/data") if root is None else Path(root)
    ready_dir = data_root / project_id / "ready"
    ready_dir.mkdir(parents=True, exist_ok=True)
    return ready_dir
2026-03-17 14:36:31 +08:00
class SplitRequest(BaseModel):
    """Request body describing how a file should be split into chunks."""

    # File to split.
    file_id: UUID

    # Splitting strategy and its generic size settings.
    method: str = "recursive"
    chunk_size: int = Field(default=500, ge=50, le=5000)
    overlap: int = Field(default=50, ge=0, le=500)

    # Only forwarded to the splitter when method is "custom".
    separator: Optional[str] = None

    # Embedding settings (used by the "semantic_embedding" method).
    embedding_provider: Optional[str] = Field(
        default=None, description="embedding provider: openai, minimax"
    )
    embedding_api_key: Optional[str] = Field(
        default=None, description="API key for embedding"
    )
    embedding_base_url: Optional[str] = Field(
        default=None, description="API base URL"
    )
    embedding_model: Optional[str] = Field(
        default=None, description="Embedding model name"
    )

    # Semantic splitting parameters.
    similarity_threshold: float = Field(
        default=0.3, ge=0.0, le=1.0, description="Similarity threshold for semantic split"
    )
    min_chunk_size: int = Field(
        default=100, ge=10, le=1000, description="Minimum chunk size"
    )
2026-03-17 14:36:31 +08:00
async def process_file_by_type(file: File) -> str:
    """Load a stored file and return its content as text.

    Document formats handled by MarkItDown are converted to Markdown;
    every other type is read as UTF-8 text. Blocking work runs in the
    default executor so the event loop stays responsive.

    Args:
        file: File record; ``file_path``, ``file_type`` and ``id`` are read.

    Returns:
        The converted Markdown, or the raw file text.

    Raises:
        NotFoundException: If the record has no ``file_path``.
    """
    if not file.file_path:
        raise NotFoundException("File", file.id)

    # Types that are routed through MarkItDown for conversion to Markdown.
    markitdown_types = ("pdf", "docx", "doc", "pptx", "ppt", "xlsx", "xls", "htm", "html")

    # Read the path on the event-loop side so the worker thread never
    # touches the ORM object.
    source_path = file.file_path

    # Normalise so that "PDF" or ".pdf" is routed the same way as "pdf";
    # non-string values are compared unchanged.
    file_type = file.file_type
    if isinstance(file_type, str):
        file_type = file_type.lower().lstrip(".")

    loop = asyncio.get_running_loop()

    if file_type in markitdown_types:
        result = await loop.run_in_executor(None, markitdown.convert, source_path)
        return result.text_content

    # Plain text (txt, md, ...): read_text closes the file handle for us.
    return await loop.run_in_executor(
        None,
        lambda: Path(source_path).read_text(encoding="utf-8")
    )
2026-03-17 14:36:31 +08:00
def _build_splitter_kwargs(request: SplitRequest) -> dict:
    """Translate a split request into keyword arguments for ``get_splitter``."""
    kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
    if request.method == "custom" and request.separator:
        kwargs["separator"] = request.separator
    if request.method == "semantic_embedding":
        # NOTE(review): the fallback provider is "openai" while the fallback
        # base URL points at MiniMax — confirm these defaults are meant to
        # be used together.
        kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
        kwargs["embedding_api_key"] = request.embedding_api_key
        kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
        kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
        kwargs["similarity_threshold"] = request.similarity_threshold
        kwargs["min_chunk_size"] = request.min_chunk_size
    return kwargs


async def process_split_async(
    project_id: UUID,
    request: SplitRequest,
):
    """Split a file into chunks in the background and persist the result.

    Steps, all inside a dedicated database session:

    1. Load the file (scoped to ``project_id``) and convert it to text.
    2. Split the text with the requested splitter.
    3. Replace the file's existing chunks with the new ones and commit.
    4. Write the converted text to ``<ready dir>/<file id>.md``, point the
       file record at it and mark the file "completed".

    Any failure rolls back uncommitted changes, marks the file "failed"
    (when it was loaded) and is logged; exceptions are not re-raised because
    this coroutine runs as a fire-and-forget task.

    Args:
        project_id: Project that owns the file.
        request: Split parameters, including the target ``file_id``.
    """
    async with AsyncSessionLocal() as db:
        file = None
        try:
            result = await db.execute(
                select(File).where(File.id == request.file_id, File.project_id == project_id)
            )
            file = result.scalar_one_or_none()
            if not file:
                log_failure(
                    "文件分割失败",
                    project_id=str(project_id),
                    file_id=str(request.file_id),
                    method=request.method,
                    error="File not found"
                )
                return

            # Capture plain values up front so later steps (after commits)
            # do not depend on the ORM object's loaded state.
            file_id = file.id
            filename = file.filename

            text = await process_file_by_type(file)

            splitter = get_splitter(request.method, **_build_splitter_kwargs(request))
            loop = asyncio.get_running_loop()
            # split() is synchronous; run it in the executor so a long split
            # does not stall the event loop.
            split_results = await loop.run_in_executor(None, splitter.split, text)

            # Replace the file's previous chunks with the new set.
            await db.execute(
                Chunk.__table__.delete().where(
                    Chunk.project_id == project_id,
                    Chunk.file_id == file_id
                )
            )
            chunks = []
            for position, chunk_data in enumerate(split_results):
                content = chunk_data["content"]
                if "name" in chunk_data:
                    name = chunk_data["name"]
                else:
                    # Fall back to the loop position when the splitter
                    # supplies neither a name nor an index.
                    name = f"Chunk {chunk_data.get('index', position) + 1}"
                chunks.append(
                    Chunk(
                        project_id=project_id,
                        file_id=file_id,
                        name=name,
                        content=content,
                        word_count=chunk_data.get("word_count", len(content.split()))
                    )
                )
            db.add_all(chunks)
            await db.commit()

            ready_dir = get_project_ready_dir(str(project_id))
            # Remove earlier markdown files for this file (two naming formats
            # may exist). Snapshot the matches before deleting.
            for old_file in list(ready_dir.glob(f"{file_id}*.md")):
                try:
                    old_file.unlink()
                except OSError:
                    # Best-effort cleanup; a leftover file is not fatal.
                    pass

            md_path = ready_dir / f"{file_id}.md"
            await loop.run_in_executor(
                None,
                lambda: md_path.write_text(text, encoding='utf-8')
            )

            file.file_path = str(md_path)
            file.status = "completed"
            await db.commit()
            log_success(
                "文件分割完成",
                project_id=str(project_id),
                file_id=str(file_id),
                filename=filename,
                method=request.method,
                chunk_count=len(chunks),
                text_length=len(text),
                ready_path=str(md_path)
            )
        except Exception as e:
            try:
                # Discard any half-applied chunk changes and reset the
                # session before recording the failure status.
                await db.rollback()
                if file is not None:
                    file.status = "failed"
                    await db.commit()
            except Exception as status_error:
                log_failure(
                    "文件分割失败",
                    project_id=str(project_id),
                    file_id=str(request.file_id),
                    method=request.method,
                    error=str(status_error)
                )
            log_failure(
                "文件分割失败",
                project_id=str(project_id),
                file_id=str(request.file_id),
                method=request.method,
                error=str(e)
            )
# Strong references to in-flight split tasks. The event loop only keeps weak
# references to tasks, so an otherwise unreferenced task may be
# garbage-collected before it finishes.
_background_tasks: set = set()


@router.post("/split", response_model=ApiResponse)
async def split_text(
    project_id: UUID,
    request: SplitRequest,
    db: AsyncSession = Depends(get_db)
):
    """Start splitting a file into chunks.

    Verifies the file belongs to the project, marks it "processing" and
    schedules ``process_split_async`` as a background task; the response is
    returned without waiting for the split to finish.

    Args:
        project_id: Project that owns the file.
        request: Split parameters, including the target ``file_id``.
        db: Request-scoped database session.

    Returns:
        An ``ApiResponse`` carrying the file id and its current status.

    Raises:
        NotFoundException: If the file does not exist in the project.
    """
    file = None
    try:
        result = await db.execute(
            select(File).where(File.id == request.file_id, File.project_id == project_id)
        )
        file = result.scalar_one_or_none()
        if not file:
            raise NotFoundException("File", request.file_id)

        # Record that processing is starting.
        log_success(
            "开始处理文件",
            project_id=str(project_id),
            file_id=str(file.id),
            filename=file.filename,
            method=request.method,
            chunk_size=request.chunk_size,
            overlap=request.overlap
        )

        file.status = "processing"
        await db.commit()

        task = asyncio.create_task(
            process_split_async(
                project_id=project_id,
                request=request,
            )
        )
        # Keep the task alive until it completes, then drop the reference.
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

        return ApiResponse.ok(
            data={"file_id": str(file.id), "status": file.status},
            message="Split task started, processing in background"
        )
    except Exception as e:
        if file:
            try:
                # Reset the session first so the status update can be saved
                # even when the failure came from the database.
                await db.rollback()
                file.status = "failed"
                await db.commit()
            except Exception as status_error:
                # Do not let a failed status update mask the original error.
                log_failure(
                    "分割任务启动失败",
                    project_id=str(project_id),
                    file_id=str(request.file_id),
                    error=str(status_error)
                )
        log_failure(
            "分割任务启动失败",
            project_id=str(project_id),
            file_id=str(request.file_id),
            error=str(e)
        )
        raise
2026-03-17 14:36:31 +08:00
@router.get("", response_model=ApiResponse)
async def list_chunks(
    project_id: UUID,
    file_id: Optional[UUID] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """Return one page of a project's chunks, optionally limited to one file.

    Results are requested ordered by ``created_at`` (ascending) and wrapped
    in a paginated response together with the total count.
    """
    # Always scope to the project; narrow to a single file when requested.
    criteria = {"project_id": project_id, **({"file_id": file_id} if file_id else {})}

    offset = page_size * (page - 1)
    rows, total = await chunk_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters=criteria,
        order_by="created_at",
        descending=False
    )

    return PaginatedResponse.ok(
        items=[ChunkResponse.model_validate(row) for row in rows],
        page=page,
        page_size=page_size,
        total=total
    )
2026-03-17 14:36:31 +08:00
@router.get("/{chunk_id}", response_model=ApiResponse)
async def get_chunk(
    project_id: UUID,
    chunk_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Fetch a single chunk, provided it belongs to the given project.

    Raises:
        NotFoundException: If the chunk is missing or owned by another project.
    """
    found = await chunk_crud.get(db, chunk_id)
    if found and found.project_id == project_id:
        payload = ChunkResponse.model_validate(found)
        return ApiResponse.ok(data=payload)
    raise NotFoundException("Chunk", chunk_id)
2026-03-17 14:36:31 +08:00
@router.put("/{chunk_id}", response_model=ApiResponse)
async def update_chunk(
    project_id: UUID,
    chunk_id: UUID,
    chunk: ChunkUpdateSchema,
    db: AsyncSession = Depends(get_db)
):
    """Apply the submitted changes to a chunk of the given project.

    Raises:
        NotFoundException: If the chunk is missing or owned by another project.
    """
    existing = await chunk_crud.get(db, chunk_id)
    if existing and existing.project_id == project_id:
        saved = await chunk_crud.update(db, existing, chunk)
        payload = ChunkResponse.model_validate(saved)
        return ApiResponse.ok(data=payload, message="Chunk updated successfully")
    raise NotFoundException("Chunk", chunk_id)
@router.delete("/{chunk_id}", response_model=ApiResponse)
async def delete_chunk(
    project_id: UUID,
    chunk_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Remove a chunk, provided it belongs to the given project.

    Raises:
        NotFoundException: If the chunk is missing or owned by another project.
    """
    target = await chunk_crud.get(db, chunk_id)
    if target and target.project_id == project_id:
        await chunk_crud.delete(db, chunk_id)
        return ApiResponse.ok(message="Chunk deleted successfully")
    raise NotFoundException("Chunk", chunk_id)