""" Chunks API Router """ import asyncio from pathlib import Path from typing import List, Optional from uuid import UUID from pydantic import BaseModel, Field from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from app.api.response import ApiResponse, PaginatedResponse from app.core.database import get_db, AsyncSessionLocal from app.core.exceptions import NotFoundException from app.core.crud import CRUDBase from app.core.logging import log_success, log_failure from app.models.models import Chunk, File from app.schemas.chunk import ChunkResponse from app.schemas.chunk import ChunkCreateSchema, ChunkUpdateSchema from app.services.text_splitter.splitter import get_splitter from markitdown import MarkItDown router = APIRouter() # Initialize CRUD chunk_crud = CRUDBase(Chunk) # Initialize markitdown markitdown = MarkItDown() def get_project_ready_dir(project_id: str) -> Path: """获取项目的 ready 文件目录""" base_dir = Path("/data/code/YG-Datasets/data") / project_id / "ready" base_dir.mkdir(parents=True, exist_ok=True) return base_dir class SplitRequest(BaseModel): """Request model for splitting text""" file_id: UUID method: str = "recursive" chunk_size: int = Field(500, ge=50, le=5000) overlap: int = Field(50, ge=0, le=500) separator: Optional[str] = None # Embedding 相关参数(用于 semantic_embedding 方法) embedding_provider: Optional[str] = Field(None, description="embedding provider: openai, minimax") embedding_api_key: Optional[str] = Field(None, description="API key for embedding") embedding_base_url: Optional[str] = Field(None, description="API base URL") embedding_model: Optional[str] = Field(None, description="Embedding model name") # 语义分割参数 similarity_threshold: float = Field(0.3, ge=0.0, le=1.0, description="Similarity threshold for semantic split") min_chunk_size: int = Field(100, ge=10, le=1000, description="Minimum chunk size") async def process_file_by_type(file: File) -> str: """Process file based on its type, convert to markdown""" if not file.file_path: raise NotFoundException("File", file.id) # Supported types for markitdown markitdown_types = ["pdf", "docx", "doc", "pptx", "ppt", "xlsx", "xls", "htm", "html"] if file.file_type in markitdown_types: # Use markitdown to convert to markdown loop = asyncio.get_event_loop() result = await loop.run_in_executor( None, lambda: markitdown.convert(file.file_path) ) return result.text_content # Return raw text for txt, md files loop = asyncio.get_event_loop() content = await loop.run_in_executor( None, lambda: open(file.file_path, 'r', encoding='utf-8').read() ) return content async def process_split_async( project_id: UUID, request: SplitRequest, ): """Run chunk splitting in background.""" async with AsyncSessionLocal() as db: file = None try: result = await db.execute( select(File).where(File.id == request.file_id, File.project_id == project_id) ) file = result.scalar_one_or_none() if not file: return text = await process_file_by_type(file) kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap} if request.method == "custom" and request.separator: kwargs["separator"] = request.separator if request.method == "semantic_embedding": kwargs["embedding_provider_type"] = request.embedding_provider or "openai" kwargs["embedding_api_key"] = request.embedding_api_key kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1" kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small" kwargs["similarity_threshold"] = request.similarity_threshold kwargs["min_chunk_size"] = request.min_chunk_size splitter = get_splitter(request.method, **kwargs) split_results = splitter.split(text) await db.execute( Chunk.__table__.delete().where( Chunk.project_id == project_id, Chunk.file_id == file.id ) ) chunks = [] for chunk_data in split_results: db_chunk = Chunk( project_id=project_id, file_id=file.id, name=chunk_data.get("name", f"Chunk {chunk_data['index'] + 1}"), content=chunk_data["content"], word_count=chunk_data.get("word_count", len(chunk_data["content"].split())) ) db.add(db_chunk) chunks.append(db_chunk) await db.commit() ready_dir = get_project_ready_dir(str(project_id)) # 删除旧的 markdown 文件(可能有两种命名格式) old_md_files = list(ready_dir.glob(f"{file.id}*.md")) for old_file in old_md_files: try: old_file.unlink() except Exception: pass md_filename = f"{file.id}.md" md_path = ready_dir / md_filename loop = asyncio.get_event_loop() await loop.run_in_executor( None, lambda: md_path.write_text(text, encoding='utf-8') ) file.file_path = str(md_path) file.status = "completed" await db.commit() log_success( "文件分割完成", project_id=str(project_id), file_id=str(file.id), filename=file.filename, method=request.method, chunk_count=len(chunks), text_length=len(text), ready_path=str(md_path) ) except Exception as e: if file: file.status = "failed" await db.commit() log_failure( "文件分割失败", project_id=str(project_id), file_id=str(request.file_id), method=request.method, error=str(e) ) @router.post("/split", response_model=ApiResponse) async def split_text( project_id: UUID, request: SplitRequest, db: AsyncSession = Depends(get_db) ): """Split text into chunks""" try: result = await db.execute( select(File).where(File.id == request.file_id, File.project_id == project_id) ) file = result.scalar_one_or_none() if not file: raise NotFoundException("File", request.file_id) # 记录开始处理 log_success( "开始处理文件", project_id=str(project_id), file_id=str(file.id), filename=file.filename, method=request.method, chunk_size=request.chunk_size, overlap=request.overlap ) file.status = "processing" await db.commit() asyncio.create_task( process_split_async( project_id=project_id, request=request, ) ) return ApiResponse.ok( data={"file_id": str(file.id), "status": file.status}, message="Split task started, processing in background" ) except Exception as e: if 'file' in locals() and file: file.status = "failed" await db.commit() log_failure( "分割任务启动失败", project_id=str(project_id), file_id=str(request.file_id), error=str(e) ) raise @router.get("", response_model=ApiResponse) async def list_chunks( project_id: UUID, file_id: Optional[UUID] = Query(None), page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), db: AsyncSession = Depends(get_db) ): """List chunks for a project""" filters = {"project_id": project_id} if file_id: filters["file_id"] = file_id skip = (page - 1) * page_size chunks, total = await chunk_crud.get_multi( db, skip=skip, limit=page_size, filters=filters, order_by="created_at", descending=False ) chunk_responses = [ChunkResponse.model_validate(c) for c in chunks] return PaginatedResponse.ok( items=chunk_responses, page=page, page_size=page_size, total=total ) @router.get("/{chunk_id}", response_model=ApiResponse) async def get_chunk( project_id: UUID, chunk_id: UUID, db: AsyncSession = Depends(get_db) ): """Get chunk by ID""" chunk = await chunk_crud.get(db, chunk_id) if not chunk or chunk.project_id != project_id: raise NotFoundException("Chunk", chunk_id) return ApiResponse.ok(data=ChunkResponse.model_validate(chunk)) @router.put("/{chunk_id}", response_model=ApiResponse) async def update_chunk( project_id: UUID, chunk_id: UUID, chunk: ChunkUpdateSchema, db: AsyncSession = Depends(get_db) ): """Update chunk""" db_chunk = await chunk_crud.get(db, chunk_id) if not db_chunk or db_chunk.project_id != project_id: raise NotFoundException("Chunk", chunk_id) updated_chunk = await chunk_crud.update(db, db_chunk, chunk) return ApiResponse.ok( data=ChunkResponse.model_validate(updated_chunk), message="Chunk updated successfully" ) @router.delete("/{chunk_id}", response_model=ApiResponse) async def delete_chunk( project_id: UUID, chunk_id: UUID, db: AsyncSession = Depends(get_db) ): """Delete chunk""" chunk = await chunk_crud.get(db, chunk_id) if not chunk or chunk.project_id != project_id: raise NotFoundException("Chunk", chunk_id) await chunk_crud.delete(db, chunk_id) return ApiResponse.ok(message="Chunk deleted successfully")