"""Knowledge-base document routes: list, upload, inspect, edit chunks, delete, search."""

import asyncio
from dataclasses import asdict
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Form
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models.document import Document, DocumentChunk
from app.models.user import User
from app.routers.auth import get_current_user
from app.services.document_service import DocumentService
from app.services.knowledge_service import KnowledgeService
from app.schemas.document import DocumentChunkOut, DocumentChunkUpdate, DocumentOut

router = APIRouter(prefix="/api/documents", tags=["知识库"])

# NOTE: route functions below deliberately carry no return annotations.
# FastAPI interprets a return annotation as the response model when
# `response_model` is not given, which would change how responses are built.


async def _get_owned_document(db: AsyncSession, document_id: str, user_id) -> Document:
    """Fetch the document with ``document_id`` that belongs to ``user_id``.

    Args:
        db: Session used to run the lookup.
        document_id: Identifier of the requested document.
        user_id: Identifier of the user who must own the document.

    Returns:
        The matching ``Document`` row.

    Raises:
        HTTPException: 404 when no document with that id exists for that user.
    """
    result = await db.execute(
        select(Document).where(
            Document.id == document_id,
            Document.user_id == user_id,
        )
    )
    doc = result.scalar_one_or_none()
    if doc is None:
        raise HTTPException(status_code=404, detail="文档不存在")
    return doc


@router.get("", response_model=list[DocumentOut])
async def list_documents(
    folder_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's documents, newest first.

    When ``folder_id`` is given (and non-empty), only documents in that
    folder are returned.
    """
    query = select(Document).where(Document.user_id == current_user.id)
    if folder_id:
        query = query.where(Document.folder_id == folder_id)
    result = await db.execute(query.order_by(Document.created_at.desc()))
    return result.scalars().all()


@router.post("/upload", status_code=201)
async def upload_document(
    background: BackgroundTasks,
    file: UploadFile = File(...),
    folder_id: Optional[str] = Form(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload a document; it is chunked and vectorized automatically.

    The upload itself is delegated to ``DocumentService.upload_document``;
    indexing through ``KnowledgeService.index_document`` is scheduled as a
    background task, so the response is returned before indexing finishes.

    Raises:
        HTTPException: 400 when the document service rejects the upload
            with a ``ValueError``.
    """
    doc_svc = DocumentService(db)
    try:
        doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)
    except ValueError as error:
        raise HTTPException(status_code=400, detail=str(error)) from error

    # Snapshot plain values now. The background task runs after the response
    # has been sent, when the request-scoped session may already be closed and
    # these ORM instances detached or expired; reading attributes from them at
    # that point (from another thread and event loop) is not safe.
    doc_id = doc.id
    user_id = current_user.id

    # Index into ChromaDB in the background.
    def index_task() -> None:
        # Kept as a function-scope import, as in the original code.
        from app.database import async_session

        async def _index() -> None:
            # A fresh session is used instead of the request's `db`.
            async with async_session() as session:
                kb_svc = KnowledgeService(session, user_id=user_id)
                await kb_svc.index_document(doc_id, user_id=user_id)

        # `index_task` is a plain (non-async) callable, so it is executed
        # outside the server's event loop and can drive its own loop here.
        asyncio.run(_index())

    background.add_task(index_task)
    return {
        "id": doc_id,
        "title": doc.title,
        "chunk_count": doc.chunk_count,
        "status": "上传成功,正在索引...",
    }


@router.get("/{document_id}")
async def get_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single document owned by the current user.

    Raises:
        HTTPException: 404 when the document does not exist for this user.
    """
    return await _get_owned_document(db, document_id, current_user.id)


@router.get("/{document_id}/chunks", response_model=list[DocumentChunkOut])
async def get_document_chunks(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all chunks of a document, ordered by ``chunk_index``.

    Raises:
        HTTPException: 404 when the document does not exist for this user.
    """
    # Ownership check only; the document row itself is not needed afterwards.
    await _get_owned_document(db, document_id, current_user.id)

    chunks_result = await db.execute(
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    return chunks_result.scalars().all()


@router.put("/{document_id}/chunks/{chunk_id}", response_model=DocumentChunkOut)
async def update_document_chunk(
    document_id: str,
    chunk_id: str,
    payload: DocumentChunkUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Replace a chunk's content, re-index the document, and return the chunk.

    Raises:
        HTTPException: 404 when the document service raises ``ValueError``
            for the update; 500 when re-indexing reports failure.
    """
    doc_svc = DocumentService(db)
    kb_svc = KnowledgeService(db, user_id=current_user.id)

    try:
        chunk = await doc_svc.update_document_chunk(
            current_user.id, document_id, chunk_id, payload.content
        )
    except ValueError as error:
        raise HTTPException(status_code=404, detail=str(error)) from error

    # Capture the id before re-indexing. NOTE(review): presumably re-indexing
    # may commit and expire `chunk`; reading the id up front avoids depending
    # on that instance's state afterwards - confirm against KnowledgeService.
    updated_chunk_id = chunk.id

    reindexed = await kb_svc.reindex_document_chunks(document_id, current_user.id)
    if not reindexed:
        raise HTTPException(status_code=500, detail="切片更新后重新索引失败")

    # Re-read the chunk so the response reflects its state after re-indexing.
    refreshed_chunk_result = await db.execute(
        select(DocumentChunk).where(DocumentChunk.id == updated_chunk_id)
    )
    return refreshed_chunk_result.scalar_one()


@router.delete("/{document_id}", status_code=204)
async def delete_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a document.

    Deletion is delegated to ``DocumentService.delete_document``; its result
    is not inspected here, so the route answers 204 whenever the call returns.
    """
    doc_svc = DocumentService(db)
    await doc_svc.delete_document(current_user.id, document_id)


@router.post("/search")
async def search_documents(
    query: str,
    top_k: int = 5,
    mode: str = "hybrid",
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Search the knowledge base.

    - query: the search query
    - top_k: number of results to return, default 5
    - mode: hybrid / semantic / keyword; any other value is treated as hybrid
    """
    kb_svc = KnowledgeService(db, user_id=current_user.id)

    if mode == "keyword":
        # NOTE(review): this reaches into a private method of KnowledgeService;
        # a public keyword-search entry point would be preferable if one exists.
        results = await kb_svc._keyword_search(query, current_user.id, top_k)
    elif mode == "semantic":
        results = await kb_svc.retrieve(query, current_user.id, top_k=top_k, use_rerank=True)
    else:
        results = await kb_svc.hybrid_search(query, current_user.id, top_k)

    # Results are dataclass instances; convert them to plain dicts for the response.
    return [asdict(r) for r in results]


@router.get("/{document_id}/content")
async def get_document_content(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a document's text content (intended for AI comprehension).

    Raises:
        HTTPException: 404 when the document service returns ``None``
            (document missing or without content).
    """
    doc_svc = DocumentService(db)
    content = await doc_svc.get_document_content(current_user.id, document_id)
    if content is None:
        raise HTTPException(status_code=404, detail="文档不存在或无内容")
    return {"content": content}