183 lines
5.4 KiB
Python
183 lines
5.4 KiB
Python
|
|
"""
|
||
|
|
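Chunks API router: endpoints for splitting a project's files into chunks and for listing, reading, updating and deleting stored chunks.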
|
||
|
|
"""
|
||
|
|
from typing import List, Optional
|
||
|
|
from uuid import UUID
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy import select
|
||
|
|
from app.core.database import get_db
|
||
|
|
from app.models.models import Chunk, File
|
||
|
|
from app.schemas.base import ChunkCreate, ChunkResponse
|
||
|
|
from app.services.text_splitter.splitter import get_splitter
|
||
|
|
from app.services.file_processor.pdf_processor import process_pdf
|
||
|
|
from app.services.file_processor.docx_processor import process_docx
|
||
|
|
from app.services.file_processor.excel_processor import process_csv, process_excel
|
||
|
|
|
||
|
|
# Router that collects the chunk endpoints defined in this module.
# NOTE(review): the handlers take `project_id` but none of the route paths here
# contain a `{project_id}` placeholder, so it is presumably supplied by the
# prefix this router is mounted under — confirm at the include_router call site.
router = APIRouter()
|
||
|
|
|
||
|
|
|
||
|
|
class SplitRequest(BaseModel):
    """Request body describing which file to split and how to split it."""

    # File to split. Optional in the schema, but the /split handler answers
    # 400 when it is missing.
    file_id: Optional[UUID] = None

    # Splitter name handed to get_splitter(); "custom" additionally enables
    # the `separator` field below.
    method: str = "recursive"

    # Forwarded to the splitter as `chunk_size`.
    # NOTE(review): unit (characters / words / tokens) depends on the splitter — confirm.
    chunk_size: int = 500

    # Forwarded to the splitter as `overlap`.
    overlap: int = 50

    # Only forwarded to the splitter when method == "custom" and the value is non-empty.
    separator: Optional[str] = None
|
||
|
|
|
||
|
|
|
||
|
|
class ChunkListResponse(BaseModel):
    """Envelope for a list of chunks together with its length."""

    # Serialized chunks.
    chunks: List[ChunkResponse]

    # Number of chunks in the listing.
    total: int
|
||
|
|
|
||
|
|
|
||
|
|
def process_file_by_type(file: File) -> str:
    """Extract the content of a stored file, dispatching on ``file.file_type``.

    Args:
        file: File record; its ``file_path`` and ``file_type`` attributes are read.

    Returns:
        Whatever the matching processor returns for pdf/docx/xlsx/csv files, or
        the file's contents decoded as UTF-8 for every other type.

    Raises:
        HTTPException: 400 when the record has no ``file_path``.
    """
    path = file.file_path
    if not path:
        raise HTTPException(status_code=400, detail="File path not found")

    # Dedicated extractors, keyed by file type.
    handlers = {
        "pdf": process_pdf,
        "docx": process_docx,
        "xlsx": process_excel,
        "csv": process_csv,
    }

    handler = handlers.get(file.file_type)
    if handler:
        return handler(path)

    # No dedicated extractor (e.g. txt, md): return the raw text.
    with open(path, 'r', encoding='utf-8') as fh:
        return fh.read()
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/split", response_model=dict)
async def split_text(
    project_id: UUID,
    request: SplitRequest,
    db: AsyncSession = Depends(get_db)
):
    """Split a project's file into chunks and persist them.

    Args:
        project_id: Project that must own the file; also stored on every chunk.
        request: Split parameters (file id, splitter method, sizes, separator).
        db: Async database session.

    Returns:
        A dict with the number of chunks created (``"chunks"``) and a
        human-readable ``"message"``.

    Raises:
        HTTPException: 400 when ``request.file_id`` is missing, 404 when the
            file does not exist in this project; ``process_file_by_type`` may
            raise 400 for a record without a path.
    """
    file_id = request.file_id
    if file_id is None:
        raise HTTPException(status_code=400, detail="file_id is required")

    # Look the file up scoped to the project so other projects' files are invisible.
    result = await db.execute(
        select(File).where(File.id == file_id, File.project_id == project_id)
    )
    file = result.scalar_one_or_none()
    if not file:
        raise HTTPException(status_code=404, detail="File not found")

    # Extract the text first; a failure here leaves the file status untouched.
    text = process_file_by_type(file)

    file.status = "processing"
    await db.commit()

    # Build splitter options; the separator only applies to the "custom" method.
    kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
    if request.method == "custom" and request.separator:
        kwargs["separator"] = request.separator

    splitter = get_splitter(request.method, **kwargs)
    split_results = splitter.split(text)

    chunks = []
    for position, chunk_data in enumerate(split_results):
        content = chunk_data["content"]

        # Build the fallback name lazily: a splitter that supplies "name" but
        # no "index" must not fail, and a missing "index" falls back to the
        # chunk's position in the result list.
        if "name" in chunk_data:
            name = chunk_data["name"]
        else:
            name = f"Chunk {chunk_data.get('index', position) + 1}"

        chunks.append(
            Chunk(
                project_id=project_id,
                # Use the validated request id rather than `file.id`: after the
                # commit above the ORM attributes of `file` may be expired, and
                # an AsyncSession cannot refresh them implicitly on access.
                file_id=file_id,
                name=name,
                content=content,
                word_count=chunk_data.get("word_count", len(content.split())),
            )
        )

    # Persist the chunks and the final status in one transaction so the file
    # is never left with stored chunks but a stale "processing" status.
    db.add_all(chunks)
    file.status = "completed"
    await db.commit()

    return {"chunks": len(chunks), "message": f"Successfully split into {len(chunks)} chunks"}
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/", response_model=ChunkListResponse)
async def list_chunks(
    project_id: UUID,
    file_id: Optional[UUID] = Query(None),
    db: AsyncSession = Depends(get_db)
):
    """List a project's chunks, newest first.

    Args:
        project_id: Project whose chunks are listed.
        file_id: Optional filter restricting the listing to one file.
        db: Async database session.

    Returns:
        A dict with ``"chunks"`` (serialized chunks) and ``"total"`` (their
        count), matching the ``ChunkListResponse`` schema declared on the route.
    """
    query = select(Chunk).where(Chunk.project_id == project_id)

    if file_id is not None:
        query = query.where(Chunk.file_id == file_id)

    query = query.order_by(Chunk.created_at.desc())

    result = await db.execute(query)
    chunks = result.scalars().all()

    return {
        "chunks": [ChunkResponse.model_validate(c) for c in chunks],
        "total": len(chunks),
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/{chunk_id}", response_model=ChunkResponse)
async def get_chunk(project_id: UUID, chunk_id: UUID, db: AsyncSession = Depends(get_db)):
    """Fetch a single chunk of a project by its id.

    The route declares ``ChunkResponse`` as its response model because that is
    what the handler returns; a ``ChunkResponse`` instance is not a ``dict``.

    Args:
        project_id: Project the chunk must belong to.
        chunk_id: Id of the chunk to fetch.
        db: Async database session.

    Returns:
        The chunk serialized as ``ChunkResponse``.

    Raises:
        HTTPException: 404 when no such chunk exists in the project.
    """
    result = await db.execute(
        select(Chunk).where(Chunk.id == chunk_id, Chunk.project_id == project_id)
    )
    chunk = result.scalar_one_or_none()
    if not chunk:
        raise HTTPException(status_code=404, detail="Chunk not found")
    return ChunkResponse.model_validate(chunk)
|
||
|
|
|
||
|
|
|
||
|
|
@router.put("/{chunk_id}", response_model=ChunkResponse)
async def update_chunk(
    project_id: UUID,
    chunk_id: UUID,
    chunk: ChunkCreate,
    db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to a chunk and return the stored result.

    The route declares ``ChunkResponse`` as its response model because that is
    what the handler returns; a ``ChunkResponse`` instance is not a ``dict``.

    Args:
        project_id: Project the chunk must belong to.
        chunk_id: Id of the chunk to update.
        chunk: New values; only fields explicitly set by the client are applied.
        db: Async database session.

    Returns:
        The refreshed chunk serialized as ``ChunkResponse``.

    Raises:
        HTTPException: 404 when no such chunk exists in the project.
    """
    result = await db.execute(
        select(Chunk).where(Chunk.id == chunk_id, Chunk.project_id == project_id)
    )
    db_chunk = result.scalar_one_or_none()
    if not db_chunk:
        raise HTTPException(status_code=404, detail="Chunk not found")

    # exclude_unset keeps fields the client omitted from overwriting stored values.
    # NOTE(review): every field the client sets on ChunkCreate is copied onto
    # the row — confirm the schema exposes no identity/ownership columns.
    for key, value in chunk.model_dump(exclude_unset=True).items():
        setattr(db_chunk, key, value)

    await db.commit()
    # Reload the row so the response reflects the committed state.
    await db.refresh(db_chunk)
    return ChunkResponse.model_validate(db_chunk)
|
||
|
|
|
||
|
|
|
||
|
|
@router.delete("/{chunk_id}", response_model=dict)
async def delete_chunk(project_id: UUID, chunk_id: UUID, db: AsyncSession = Depends(get_db)):
    """Remove one chunk from a project.

    Args:
        project_id: Project the chunk must belong to.
        chunk_id: Id of the chunk to delete.
        db: Async database session.

    Returns:
        A dict carrying a confirmation ``"message"``.

    Raises:
        HTTPException: 404 when no such chunk exists in the project.
    """
    lookup = select(Chunk).where(Chunk.id == chunk_id, Chunk.project_id == project_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Chunk not found")

    await db.delete(target)
    await db.commit()
    return {"message": "Chunk deleted successfully"}
|