Files
YG-Datasets/backend/app/api/v1/chunks/__init__.py
Developer a280b4f014 feat(backend): 文件处理和语义分割 API 更新
- chunks API: 支持语义分割模式和 embedding 配置
- files API: 文件异步处理优化

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 10:11:59 +08:00

270 lines
8.6 KiB
Python

"""
Chunks API Router
"""
import asyncio
from pathlib import Path
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.api.response import ApiResponse, PaginatedResponse
from app.core.database import get_db
from app.core.exceptions import NotFoundException
from app.core.crud import CRUDBase
from app.core.logging import log_success, log_failure
from app.models.models import Chunk, File
from app.schemas.chunk import ChunkResponse
from app.schemas.chunk import ChunkCreateSchema, ChunkUpdateSchema
from app.services.text_splitter.splitter import get_splitter
from markitdown import MarkItDown
# Router carrying the chunk endpoints declared below with @router.post/get/put/delete.
router = APIRouter()
# Initialize CRUD: generic Chunk helper used by list_chunks, get_chunk, update_chunk and delete_chunk.
chunk_crud = CRUDBase(Chunk)
# Initialize markitdown: a single converter created at import time and reused by process_file_by_type.
# NOTE(review): process_file_by_type calls it from thread-pool workers; confirm MarkItDown.convert is thread-safe.
markitdown = MarkItDown()
def get_project_ready_dir(project_id: str, data_root: Optional[Path] = None) -> Path:
    """Return the directory holding a project's processed ("ready") files, creating it if needed.

    Args:
        project_id: Project identifier used as the directory name (callers in this
            module pass ``str(UUID)``).
        data_root: Base data directory. When omitted, the deployment default
            ``/data/code/YG-Datasets/data`` is used. A plain ``str`` is accepted too.

    Returns:
        The path ``<data_root>/<project_id>/ready``, which exists on return.
    """
    # NOTE(review): the default root is a machine-specific absolute path; consider
    # moving it into application settings.
    root = Path(data_root) if data_root is not None else Path("/data/code/YG-Datasets/data")
    ready_dir = root / project_id / "ready"
    # parents=True creates the project folder on first use; exist_ok makes repeat calls harmless.
    ready_dir.mkdir(parents=True, exist_ok=True)
    return ready_dir
class SplitRequest(BaseModel):
    """Request model for splitting text.

    Body of ``POST /split``: which file to split and which splitter settings to use.
    Field order and ``Field`` constraints define the public request schema.
    """
    # ID of the file to split; split_text looks it up within the path's project.
    file_id: UUID
    # Splitter name handed to get_splitter; split_text special-cases "custom" and
    # "semantic_embedding", and "recursive" is the default.
    method: str = "recursive"
    # Chunk size passed to the splitter (validated to 50..5000).
    chunk_size: int = Field(500, ge=50, le=5000)
    # Overlap passed to the splitter (validated to 0..500).
    # NOTE(review): nothing enforces overlap < chunk_size (e.g. 60 / 500 is accepted) — confirm the splitter copes.
    overlap: int = Field(50, ge=0, le=500)
    # Separator; split_text forwards it only when method == "custom" and it is non-empty.
    separator: Optional[str] = None
    # Embedding-related parameters (used only by the semantic_embedding method);
    # split_text substitutes its own fallbacks when these are None.
    embedding_provider: Optional[str] = Field(None, description="embedding provider: openai, minimax")
    embedding_api_key: Optional[str] = Field(None, description="API key for embedding")
    embedding_base_url: Optional[str] = Field(None, description="API base URL")
    embedding_model: Optional[str] = Field(None, description="Embedding model name")
    # Semantic-split parameters (forwarded only for semantic_embedding).
    similarity_threshold: float = Field(0.3, ge=0.0, le=1.0, description="Similarity threshold for semantic split")
    min_chunk_size: int = Field(100, ge=10, le=1000, description="Minimum chunk size")
async def process_file_by_type(file: File) -> str:
    """Extract a file's text, converting document formats to Markdown.

    PDF, Office and HTML types are converted with the module-level MarkItDown
    instance. Every other type (e.g. txt, md) is read as UTF-8 text. The blocking
    work runs in the default thread-pool executor so the event loop stays responsive.

    Args:
        file: File row; only ``id``, ``file_path`` and ``file_type`` are read.

    Returns:
        The converted Markdown, or the raw file text for non-converted types.

    Raises:
        NotFoundException: if the row has no ``file_path``.
        UnicodeDecodeError: if a non-converted file is not valid UTF-8.
    """
    if not file.file_path:
        raise NotFoundException("File", file.id)

    # Copy plain values up front so the worker threads never touch the ORM instance.
    file_path = file.file_path
    # Case-insensitive match: "PDF" must not fall through to the UTF-8 text branch.
    file_type = (file.file_type or "").lower()

    # Supported types for markitdown.
    # NOTE(review): legacy "doc"/"ppt" may not be handled by markitdown — confirm.
    markitdown_types = {"pdf", "docx", "doc", "pptx", "ppt", "xlsx", "xls", "htm", "html"}

    loop = asyncio.get_running_loop()
    if file_type in markitdown_types:
        result = await loop.run_in_executor(None, markitdown.convert, file_path)
        return result.text_content

    # Path.read_text opens and closes the file itself; the previous bare
    # open(...).read() never closed its handle.
    return await loop.run_in_executor(
        None, lambda: Path(file_path).read_text(encoding="utf-8")
    )
def _build_splitter_kwargs(request: SplitRequest) -> dict:
    """Translate a SplitRequest into keyword arguments for ``get_splitter``."""
    kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
    # A custom separator is only forwarded for the "custom" method.
    if request.method == "custom" and request.separator:
        kwargs["separator"] = request.separator
    # The semantic_embedding method also needs embedding-provider settings.
    if request.method == "semantic_embedding":
        # NOTE(review): these fallbacks mix vendors (provider "openai", a MiniMax base URL,
        # an OpenAI model name); confirm the intended defaults.
        kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
        kwargs["embedding_api_key"] = request.embedding_api_key
        kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
        kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
        kwargs["similarity_threshold"] = request.similarity_threshold
        kwargs["min_chunk_size"] = request.min_chunk_size
    return kwargs


async def _restore_file_status(
    db: AsyncSession,
    file: File,
    status,
    *,
    project_id: str,
    file_id: str,
) -> None:
    """Best effort: discard uncommitted work and put ``file.status`` back after a failed split."""
    try:
        await db.rollback()
        file.status = status
        await db.commit()
    except Exception as restore_error:
        # Never mask the original failure; just record that the restore did not work.
        log_failure(
            "恢复文件状态失败",
            project_id=project_id,
            file_id=file_id,
            error=str(restore_error)
        )


@router.post("/split", response_model=ApiResponse)
async def split_text(
    project_id: UUID,
    request: SplitRequest,
    db: AsyncSession = Depends(get_db)
):
    """Split a project file into chunks and store them.

    Steps:

    1. Look up the file, scoped to ``project_id``.
    2. Mark it ``"processing"`` and commit.
    3. Extract its text and split it with the requested splitter.
    4. Write the text as a ``.md`` file into the project's ready directory.
    5. In a single commit, store the chunks and point ``file.file_path`` at the
       ``.md`` file with status ``"completed"``.

    If any step after the status change fails, uncommitted work is rolled back
    and the file's previous status is restored before the error propagates.

    Raises:
        NotFoundException: if the file does not exist in this project.
        Exception: any extraction, splitting, file-system or database error is
            logged via ``log_failure`` and re-raised.
    """
    try:
        result = await db.execute(
            select(File).where(File.id == request.file_id, File.project_id == project_id)
        )
        file = result.scalar_one_or_none()
        if not file:
            raise NotFoundException("File", request.file_id)

        # Capture plain values now: a rollback below expires the ORM instance.
        file_id = str(file.id)
        filename = file.filename
        previous_status = file.status

        # Log the start of processing.
        log_success(
            "开始处理文件",
            project_id=str(project_id),
            file_id=file_id,
            filename=filename,
            method=request.method,
            chunk_size=request.chunk_size,
            overlap=request.overlap
        )

        # Mark the file as in progress *before* the expensive extraction starts.
        file.status = "processing"
        await db.commit()

        try:
            text = await process_file_by_type(file)

            loop = asyncio.get_running_loop()
            splitter = get_splitter(request.method, **_build_splitter_kwargs(request))
            # split() is synchronous and may call a remote embedding API; run it (and
            # materialize its results) off the event loop.
            split_results = await loop.run_in_executor(None, lambda: list(splitter.split(text)))

            # Save processed markdown to the project's ready directory. Only the final
            # path component of the user-supplied filename is used, so it cannot
            # introduce extra directory levels.
            ready_dir = get_project_ready_dir(str(project_id))
            md_path = ready_dir / f"{file_id}_{Path(str(filename)).name}.md"
            await loop.run_in_executor(
                None,
                lambda: md_path.write_text(text, encoding='utf-8')
            )

            # NOTE(review): re-running a split appends a new set of chunks; earlier
            # chunks for this file are not removed — confirm that is intended.
            chunks = []
            for position, chunk_data in enumerate(split_results):
                name = chunk_data.get("name")
                if name is None:
                    # Only needed when the splitter gave no name; fall back to the
                    # enumeration position if it also gave no index.
                    name = f"Chunk {chunk_data.get('index', position) + 1}"
                db_chunk = Chunk(
                    project_id=project_id,
                    file_id=file.id,
                    name=name,
                    content=chunk_data["content"],
                    word_count=chunk_data.get("word_count", len(chunk_data["content"].split()))
                )
                db.add(db_chunk)
                chunks.append(db_chunk)

            # Point the file at its ready copy and commit chunks + file update atomically.
            file.file_path = str(md_path)
            file.status = "completed"
            await db.commit()
        except Exception:
            await _restore_file_status(
                db,
                file,
                previous_status,
                project_id=str(project_id),
                file_id=file_id,
            )
            raise

        # Log success.
        log_success(
            "文件处理完成",
            project_id=str(project_id),
            file_id=file_id,
            filename=filename,
            chunk_count=len(chunks),
            text_length=len(text),
            ready_path=str(md_path)
        )

        return ApiResponse.ok(
            data={"chunks": len(chunks)},
            message=f"Successfully split into {len(chunks)} chunks"
        )
    except Exception as e:
        # Log the failure, then let the error propagate to the framework.
        log_failure(
            "文件处理失败",
            project_id=str(project_id),
            file_id=str(request.file_id),
            error=str(e)
        )
        raise
@router.get("", response_model=ApiResponse)
async def list_chunks(
    project_id: UUID,
    file_id: Optional[UUID] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """Return one page of a project's chunks, ordered by ``created_at`` ascending.

    When ``file_id`` is supplied, only chunks of that file are included.
    """
    query_filters = dict(project_id=project_id)
    if file_id:
        query_filters["file_id"] = file_id

    # Pages are 1-based, so page 1 starts at offset 0.
    offset = page_size * (page - 1)
    rows, total_count = await chunk_crud.get_multi(
        db,
        skip=offset,
        limit=page_size,
        filters=query_filters,
        order_by="created_at",
        descending=False,
    )

    return PaginatedResponse.ok(
        items=list(map(ChunkResponse.model_validate, rows)),
        page=page,
        page_size=page_size,
        total=total_count,
    )
@router.get("/{chunk_id}", response_model=ApiResponse)
async def get_chunk(
    project_id: UUID,
    chunk_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Fetch a single chunk; chunks of other projects are reported as not found."""
    record = await chunk_crud.get(db, chunk_id)
    if record and record.project_id == project_id:
        return ApiResponse.ok(data=ChunkResponse.model_validate(record))
    raise NotFoundException("Chunk", chunk_id)
@router.put("/{chunk_id}", response_model=ApiResponse)
async def update_chunk(
    project_id: UUID,
    chunk_id: UUID,
    chunk: ChunkUpdateSchema,
    db: AsyncSession = Depends(get_db)
):
    """Apply ``chunk`` to an existing chunk of this project and return the saved result."""
    existing = await chunk_crud.get(db, chunk_id)
    # Guard: the chunk must exist and belong to the project in the path.
    if not (existing and existing.project_id == project_id):
        raise NotFoundException("Chunk", chunk_id)

    saved = await chunk_crud.update(db, existing, chunk)
    payload = ChunkResponse.model_validate(saved)
    return ApiResponse.ok(data=payload, message="Chunk updated successfully")
@router.delete("/{chunk_id}", response_model=ApiResponse)
async def delete_chunk(
    project_id: UUID,
    chunk_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Delete a chunk of this project; chunks of other projects are reported as not found."""
    target = await chunk_crud.get(db, chunk_id)
    if target and target.project_id == project_id:
        await chunk_crud.delete(db, chunk_id)
        return ApiResponse.ok(message="Chunk deleted successfully")
    raise NotFoundException("Chunk", chunk_id)