"""
Chunks API Router

Endpoints for splitting an uploaded file into text chunks and for listing
the chunks that belong to a project.
"""
import asyncio
from pathlib import Path
from typing import List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from markitdown import MarkItDown
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.response import ApiResponse, PaginatedResponse
from app.core.crud import CRUDBase
from app.core.database import get_db
from app.core.exceptions import NotFoundException
from app.core.logging import log_success, log_failure
from app.models.models import Chunk, File
from app.schemas.chunk import ChunkResponse
from app.schemas.chunk import ChunkCreateSchema, ChunkUpdateSchema
from app.services.text_splitter.splitter import get_splitter

router = APIRouter()

# Initialize CRUD
chunk_crud = CRUDBase(Chunk)

# Initialize markitdown
markitdown = MarkItDown()

# Default root under which each project's "ready" directory is created.
_DATA_ROOT = Path("/data/code/YG-Datasets/data")

# File types that are converted to markdown through markitdown.
# Kept as a tuple so membership is decided by ``==``, exactly like the
# list it replaces.
_MARKITDOWN_TYPES = (
    "pdf", "docx", "doc", "pptx", "ppt", "xlsx", "xls", "htm", "html",
)


def get_project_ready_dir(project_id: str, base_dir: Optional[Path] = None) -> Path:
    """Return the project's "ready" file directory, creating it if missing.

    Args:
        project_id: Project identifier used as the directory name.
        base_dir: Optional root directory; when omitted the module default
            ``_DATA_ROOT`` is used.

    Returns:
        ``<root>/<project_id>/ready`` as a ``Path``.
    """
    root = _DATA_ROOT if base_dir is None else Path(base_dir)
    ready_dir = root / project_id / "ready"
    ready_dir.mkdir(parents=True, exist_ok=True)
    return ready_dir


class SplitRequest(BaseModel):
    """Request model for splitting text"""
    file_id: UUID
    method: str = "recursive"
    chunk_size: int = Field(500, ge=50, le=5000)
    overlap: int = Field(50, ge=0, le=500)
    separator: Optional[str] = None
    # Embedding parameters (used by the semantic_embedding method)
    embedding_provider: Optional[str] = Field(None, description="embedding provider: openai, minimax")
    embedding_api_key: Optional[str] = Field(None, description="API key for embedding")
    embedding_base_url: Optional[str] = Field(None, description="API base URL")
    embedding_model: Optional[str] = Field(None, description="Embedding model name")
    # Semantic split parameters
    similarity_threshold: float = Field(0.3, ge=0.0, le=1.0, description="Similarity threshold for semantic split")
    min_chunk_size: int = Field(100, ge=10, le=1000, description="Minimum chunk size")


async def process_file_by_type(file: File) -> str:
    """Return the text of ``file``, converting to markdown where applicable.

    Files whose ``file_type`` is listed in ``_MARKITDOWN_TYPES`` are passed
    to ``markitdown.convert``; every other file is read as UTF-8 text. Both
    operations run in the default executor so the event loop is not blocked.

    Raises:
        NotFoundException: If the file record has no ``file_path``.
    """
    if not file.file_path:
        raise NotFoundException("File", file.id)

    source_path = file.file_path
    loop = asyncio.get_running_loop()

    if file.file_type in _MARKITDOWN_TYPES:
        # Use markitdown to convert to markdown
        converted = await loop.run_in_executor(
            None, lambda: markitdown.convert(source_path)
        )
        return converted.text_content

    # Return raw text for txt, md files. Path.read_text closes the handle
    # as soon as the read finishes instead of leaving it to the GC.
    return await loop.run_in_executor(
        None, lambda: Path(source_path).read_text(encoding="utf-8")
    )


def _build_splitter_kwargs(request: SplitRequest) -> dict:
    """Translate a split request into keyword arguments for ``get_splitter``."""
    kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}

    if request.method == "custom" and request.separator:
        kwargs["separator"] = request.separator

    # The semantic_embedding method additionally needs the embedding settings.
    if request.method == "semantic_embedding":
        kwargs["embedding_provider_type"] = request.embedding_provider or "openai"
        kwargs["embedding_api_key"] = request.embedding_api_key
        # NOTE(review): the fallback URL points at minimax while the fallback
        # provider/model are openai-style - confirm this pairing is intended.
        kwargs["embedding_base_url"] = request.embedding_base_url or "https://api.minimax.chat/v1"
        kwargs["embedding_model"] = request.embedding_model or "text-embedding-3-small"
        kwargs["similarity_threshold"] = request.similarity_threshold
        kwargs["min_chunk_size"] = request.min_chunk_size

    return kwargs


def _chunk_from_split(project_id: UUID, file_id, position: int, chunk_data) -> Chunk:
    """Build a ``Chunk`` row from one splitter result.

    Defaults are computed only when the key is absent, so a result that
    carries ``name`` but no ``index`` no longer raises ``KeyError``; when
    ``index`` is missing the position in the result list is used instead.
    """
    content = chunk_data["content"]

    if "name" in chunk_data:
        name = chunk_data["name"]
    else:
        name = f"Chunk {chunk_data.get('index', position) + 1}"

    if "word_count" in chunk_data:
        word_count = chunk_data["word_count"]
    else:
        word_count = len(content.split())

    return Chunk(
        project_id=project_id,
        file_id=file_id,
        name=name,
        content=content,
        word_count=word_count,
    )


@router.post("/split", response_model=ApiResponse)
async def split_text(
    project_id: UUID,
    request: SplitRequest,
    db: AsyncSession = Depends(get_db)
):
    """Split a project file into chunks and persist them.

    Steps: load the file record, extract its text, mark it "processing",
    run the requested splitter, store one ``Chunk`` per result, write the
    extracted text as markdown into the project's ready directory, then
    point the file record at that markdown and mark it "completed".

    Raises:
        NotFoundException: If no file with ``request.file_id`` exists in
            the project, or the record has no ``file_path``.
        Exception: Any other failure is logged through ``log_failure`` and
            re-raised unchanged.
    """
    try:
        # Get file
        result = await db.execute(
            select(File).where(File.id == request.file_id, File.project_id == project_id)
        )
        file = result.scalar_one_or_none()
        if not file:
            raise NotFoundException("File", request.file_id)

        # Read these once up front so later uses do not depend on the ORM
        # instance's state after the intermediate commits.
        file_id = file.id
        filename = file.filename

        # Log the start of processing
        log_success(
            "开始处理文件",
            project_id=str(project_id),
            file_id=str(file_id),
            filename=filename,
            method=request.method,
            chunk_size=request.chunk_size,
            overlap=request.overlap
        )

        # Process file
        text = await process_file_by_type(file)

        # Update file status
        # NOTE(review): on failure below the status is left as "processing";
        # confirm whether a failure status should be recorded.
        file.status = "processing"
        await db.commit()

        # Split text
        splitter = get_splitter(request.method, **_build_splitter_kwargs(request))
        split_results = splitter.split(text)

        # Save chunks
        chunks = [
            _chunk_from_split(project_id, file_id, position, chunk_data)
            for position, chunk_data in enumerate(split_results)
        ]
        db.add_all(chunks)
        await db.commit()

        # Save processed markdown to ready directory. Only the final path
        # component of the stored filename is used so a name containing
        # directory separators cannot place the output outside ready_dir.
        ready_dir = get_project_ready_dir(str(project_id))
        safe_name = Path(str(filename)).name
        md_path = ready_dir / f"{file_id}_{safe_name}.md"

        # Write markdown content to file
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, lambda: md_path.write_text(text, encoding="utf-8")
        )

        # Update file path to ready location
        file.file_path = str(md_path)
        file.status = "completed"
        await db.commit()

        # Log success
        log_success(
            "文件处理完成",
            project_id=str(project_id),
            file_id=str(file_id),
            filename=filename,
            chunk_count=len(chunks),
            text_length=len(text),
            ready_path=str(md_path)
        )

        return ApiResponse.ok(
            data={"chunks": len(chunks)},
            message=f"Successfully split into {len(chunks)} chunks"
        )
    except Exception as e:
        # Log the failure, then let the original exception propagate.
        log_failure(
            "文件处理失败",
            project_id=str(project_id),
            file_id=str(request.file_id),
            error=str(e)
        )
        raise


@router.get("", response_model=ApiResponse)
async def list_chunks(
    project_id: UUID,
    file_id: Optional[UUID] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """List chunks for a project, optionally restricted to one file.

    Results are ordered by ``created_at`` ascending and paginated with
    ``page`` (1-based) and ``page_size``.
    """
    filters = {"project_id": project_id}
    if file_id:
        filters["file_id"] = file_id

    skip = (page - 1) * page_size
    chunks, total = await chunk_crud.get_multi(
        db,
        skip=skip,
        limit=page_size,
        filters=filters,
        order_by="created_at",
        descending=False
    )

    chunk_responses = [ChunkResponse.model_validate(c) for c in chunks]

    return PaginatedResponse.ok(
        items=chunk_responses,
        page=page,
        page_size=page_size,
        total=total
    )


async def _fetch_project_chunk(db: AsyncSession, project_id: UUID, chunk_id: UUID):
    """Load a chunk by ID and confirm it belongs to ``project_id``.

    Raises:
        NotFoundException: If the chunk is missing or its ``project_id``
            differs from the one given.
    """
    record = await chunk_crud.get(db, chunk_id)
    if record and record.project_id == project_id:
        return record
    raise NotFoundException("Chunk", chunk_id)


@router.get("/{chunk_id}", response_model=ApiResponse)
async def get_chunk(
    project_id: UUID,
    chunk_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Return a single chunk of the project, looked up by its ID."""
    found = await _fetch_project_chunk(db, project_id, chunk_id)
    payload = ChunkResponse.model_validate(found)
    return ApiResponse.ok(data=payload)


@router.put("/{chunk_id}", response_model=ApiResponse)
async def update_chunk(
    project_id: UUID,
    chunk_id: UUID,
    chunk: ChunkUpdateSchema,
    db: AsyncSession = Depends(get_db)
):
    """Apply the submitted changes to a chunk of the project."""
    existing = await _fetch_project_chunk(db, project_id, chunk_id)
    saved = await chunk_crud.update(db, existing, chunk)
    payload = ChunkResponse.model_validate(saved)
    return ApiResponse.ok(data=payload, message="Chunk updated successfully")


@router.delete("/{chunk_id}", response_model=ApiResponse)
async def delete_chunk(
    project_id: UUID,
    chunk_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Remove a chunk of the project."""
    # The lookup is only an existence/ownership check; deletion is by ID.
    await _fetch_project_chunk(db, project_id, chunk_id)
    await chunk_crud.delete(db, chunk_id)
    return ApiResponse.ok(message="Chunk deleted successfully")