155 lines
4.9 KiB
Python
155 lines
4.9 KiB
Python
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Form
|
|||
|
|
from typing import Optional
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
from sqlalchemy import select
|
|||
|
|
from app.database import get_db
|
|||
|
|
from app.models.document import Document, DocumentChunk
|
|||
|
|
from app.models.user import User
|
|||
|
|
from app.routers.auth import get_current_user
|
|||
|
|
from app.services.document_service import DocumentService
|
|||
|
|
from app.services.knowledge_service import KnowledgeService
|
|||
|
|
from dataclasses import asdict
|
|||
|
|
|
|||
|
|
# Document endpoints: every route below is mounted under /api/documents and
# grouped under the "knowledge base" tag in the OpenAPI schema.
router = APIRouter(tags=["知识库"], prefix="/api/documents")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("", response_model=list)
async def list_documents(
    folder_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's documents, newest first.
    \f
    Args:
        folder_id: Optional folder filter. A falsy value (``None`` or ``""``)
            means documents from every folder are returned.
        current_user: Authenticated user; only their documents are listed.
        db: Request-scoped async database session.

    Returns:
        ``Document`` rows ordered by ``created_at`` descending.
    """
    # Ownership filter is always applied; the folder filter only when given.
    filters = [Document.user_id == current_user.id]
    if folder_id:
        filters.append(Document.folder_id == folder_id)

    stmt = (
        select(Document)
        .where(*filters)
        .order_by(Document.created_at.desc())
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def _index_document_in_background(document_id, user_id) -> None:
    """Index one uploaded document into the knowledge base (ChromaDB).

    Scheduled as a FastAPI background task by ``upload_document``. It opens
    its own session because the request-scoped session from ``get_db`` may
    already be closed by the time background tasks run.

    Args:
        document_id: Id of the document to index (captured before scheduling
            so no detached ORM instance is touched here).
        user_id: Owner of the document; forwarded to ``KnowledgeService``.
    """
    # Kept as a lazy import like the original handler did — presumably to
    # avoid an import cycle with app.database (TODO confirm).
    from app.database import async_session

    async with async_session() as session:
        kb_svc = KnowledgeService(session, user_id=user_id)
        await kb_svc.index_document(document_id, user_id=user_id)


@router.post("/upload", status_code=201)
async def upload_document(
    background: BackgroundTasks,
    file: UploadFile = File(...),
    folder_id: Optional[str] = Form(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload a document; it is chunked and vectorized automatically.
    \f
    The upload itself is handled by ``DocumentService.upload_document``;
    indexing into ChromaDB is deferred to a background task so the response
    returns immediately.

    Args:
        background: FastAPI background-task collector.
        file: The uploaded file.
        folder_id: Optional folder to place the document in.
        current_user: Authenticated uploader.
        db: Request-scoped async database session.

    Returns:
        A dict with the new document's ``id``, ``title``, ``chunk_count`` and
        a status message ("upload succeeded, indexing...").
    """
    doc_svc = DocumentService(db)
    doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)

    # Register an *async* callable: Starlette awaits it on the application's
    # event loop. Running it via asyncio.run() in a worker thread would create
    # a new loop per upload that the shared async engine's pooled connections
    # are not bound to. Plain ids are passed (evaluated now, inside the
    # request) instead of capturing ORM objects that may be detached later.
    # NOTE(review): indexing failures are only surfaced by the server's
    # background-task error logging; the document status is not updated.
    background.add_task(_index_document_in_background, doc.id, current_user.id)

    return {
        "id": doc.id,
        "title": doc.title,
        "chunk_count": doc.chunk_count,
        "status": "上传成功,正在索引...",
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{document_id}")
async def get_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single document owned by the current user.
    \f
    Args:
        document_id: Id of the requested document.
        current_user: Authenticated user; documents of other users are invisible.
        db: Request-scoped async database session.

    Raises:
        HTTPException: 404 when no document with this id belongs to the user.
    """
    # Filtering on both id and owner makes another user's document look
    # exactly like a missing one.
    ownership_query = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    document = (await db.execute(ownership_query)).scalar_one_or_none()

    if not document:
        raise HTTPException(status_code=404, detail="文档不存在")
    return document
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{document_id}/chunks")
async def get_document_chunks(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get all chunks of a document owned by the current user.
    \f
    Args:
        document_id: Id of the document whose chunks are requested.
        current_user: Authenticated user; must own the document.
        db: Request-scoped async database session.

    Returns:
        ``DocumentChunk`` rows ordered by ``chunk_index``.

    Raises:
        HTTPException: 404 when the document does not exist or is not owned
            by the current user.
    """
    owner_check = await db.execute(
        select(Document).where(
            Document.id == document_id,
            Document.user_id == current_user.id,
        )
    )
    if not owner_check.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="文档不存在")

    # Ownership was verified above, so the chunk query only needs the
    # document id.
    chunk_stmt = (
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    chunk_rows = await db.execute(chunk_stmt)
    return chunk_rows.scalars().all()
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/{document_id}", status_code=204)
async def delete_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a document belonging to the current user.
    \f
    Responds with 204 No Content. Ownership checks, not-found handling and
    any vector-store cleanup are delegated to
    ``DocumentService.delete_document`` (not visible here — confirm there).
    """
    await DocumentService(db).delete_document(current_user.id, document_id)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/search")
async def search_documents(
    query: str,
    top_k: int = 5,
    mode: str = "hybrid",
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Search the knowledge base.

    - query: the search text
    - top_k: number of results to return, default 5
    - mode: hybrid (combined) / semantic / keyword
    \f
    Any ``mode`` other than ``"semantic"`` or ``"keyword"`` falls back to
    hybrid search. Semantic mode calls ``KnowledgeService.retrieve`` with
    reranking enabled.

    Returns:
        One plain dict per result, built with ``dataclasses.asdict`` (so each
        result is assumed to be a dataclass instance).
    """
    user_id = current_user.id
    kb_svc = KnowledgeService(db, user_id=user_id)

    if mode == "semantic":
        hits = await kb_svc.retrieve(query, user_id, top_k, use_rerank=True)
    elif mode == "keyword":
        # NOTE(review): relies on a private KnowledgeService helper; consider
        # exposing a public keyword-search method.
        hits = await kb_svc._keyword_search(query, user_id, top_k)
    else:
        hits = await kb_svc.hybrid_search(query, user_id, top_k)

    return list(map(asdict, hits))
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{document_id}/content")
async def get_document_content(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get a document's text content (used for AI understanding).
    \f
    Args:
        document_id: Id of the document to read.
        current_user: Authenticated user; passed to the service for ownership.
        db: Request-scoped async database session.

    Returns:
        ``{"content": <text>}`` with the value returned by
        ``DocumentService.get_document_content``.

    Raises:
        HTTPException: 404 when the service returns ``None`` (per the error
            detail: the document does not exist or has no content).
    """
    # DocumentService comes from the module-level import; the previous
    # function-local re-import of the same name was redundant.
    content = await DocumentService(db).get_document_content(current_user.id, document_id)

    if content is None:
        raise HTTPException(status_code=404, detail="文档不存在或无内容")

    return {"content": content}
|