155 lines
4.9 KiB
Python
155 lines
4.9 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Form
|
||
from typing import Optional
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from app.database import get_db
|
||
from app.models.document import Document, DocumentChunk
|
||
from app.models.user import User
|
||
from app.routers.auth import get_current_user
|
||
from app.services.document_service import DocumentService
|
||
from app.services.knowledge_service import KnowledgeService
|
||
from dataclasses import asdict
|
||
|
||
# Every route declared on this router shares the /api/documents prefix and the tag below.
router = APIRouter(prefix="/api/documents", tags=["知识库"])
|
||
|
||
|
||
@router.get("", response_model=list)
async def list_documents(
    folder_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's documents, newest first.

    Args:
        folder_id: When truthy, only documents in this folder are returned.
        current_user: The authenticated user, resolved by ``get_current_user``.
        db: Request-scoped async database session.

    Returns:
        The matching ``Document`` rows ordered by ``created_at`` descending.
    """
    # Ownership is always enforced; the folder filter is optional.
    conditions = [Document.user_id == current_user.id]
    if folder_id:
        conditions.append(Document.folder_id == folder_id)

    stmt = select(Document).where(*conditions)
    rows = await db.execute(stmt.order_by(Document.created_at.desc()))
    return rows.scalars().all()
|
||
|
||
|
||
async def _index_uploaded_document(document_id, user_id):
    """Index one uploaded document into the knowledge base.

    Intended to run as a background task, so it opens its own database
    session instead of reusing the request-scoped one.

    Args:
        document_id: Primary key of the document to index.
        user_id: Owner of the document; passed to ``KnowledgeService``.
    """
    # Imported locally because this module's top-level imports do not include it.
    from app.database import async_session

    async with async_session() as session:
        kb_svc = KnowledgeService(session, user_id=user_id)
        await kb_svc.index_document(document_id, user_id=user_id)


@router.post("/upload", status_code=201)
async def upload_document(
    background: BackgroundTasks,
    file: UploadFile = File(...),
    folder_id: Optional[str] = Form(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload a document and schedule it for indexing.

    The upload itself is delegated to ``DocumentService.upload_document``;
    indexing (into ChromaDB, per the original comment) is queued as a
    background task that runs after the response has been sent.

    Args:
        background: FastAPI background-task queue for this request.
        file: The uploaded file (multipart form field).
        folder_id: Optional target folder (multipart form field).
        current_user: The authenticated user, resolved by ``get_current_user``.
        db: Request-scoped async database session.

    Returns:
        A dict with the new document's ``id``, ``title``, ``chunk_count`` and
        a status message.
    """
    doc_svc = DocumentService(db)
    doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)

    # Hand the task plain id values read while the request is still active,
    # rather than ORM instances tied to the request's session.
    #
    # The task is a coroutine function, so Starlette awaits it on the running
    # event loop. The previous sync wrapper called asyncio.run() in a worker
    # thread, which drove the application's session factory from a second
    # event loop.
    background.add_task(_index_uploaded_document, doc.id, current_user.id)

    return {"id": doc.id, "title": doc.title, "chunk_count": doc.chunk_count, "status": "上传成功,正在索引..."}
|
||
|
||
|
||
@router.get("/{document_id}")
async def get_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single document owned by the current user.

    Args:
        document_id: Id of the document, taken from the URL path.
        current_user: The authenticated user, resolved by ``get_current_user``.
        db: Request-scoped async database session.

    Returns:
        The matching ``Document``.

    Raises:
        HTTPException: 404 when no document with this id belongs to the user.
    """
    owned_doc_stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    found = (await db.execute(owned_doc_stmt)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="文档不存在")
|
||
|
||
|
||
@router.get("/{document_id}/chunks")
async def get_document_chunks(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all chunks of a document owned by the current user.

    Args:
        document_id: Id of the document, taken from the URL path.
        current_user: The authenticated user, resolved by ``get_current_user``.
        db: Request-scoped async database session.

    Returns:
        The document's ``DocumentChunk`` rows ordered by ``chunk_index``.

    Raises:
        HTTPException: 404 when no document with this id belongs to the user.
    """
    # Verify ownership first; chunks are only looked up for the user's own document.
    ownership_stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    owner_check = await db.execute(ownership_stmt)
    if not owner_check.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="文档不存在")

    chunk_stmt = select(DocumentChunk).where(DocumentChunk.document_id == document_id)
    chunk_rows = await db.execute(chunk_stmt.order_by(DocumentChunk.chunk_index))
    return chunk_rows.scalars().all()
|
||
|
||
|
||
@router.delete("/{document_id}", status_code=204)
async def delete_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a document on behalf of the current user.

    The work is delegated to ``DocumentService.delete_document``; the route
    responds with 204 and no body.

    Args:
        document_id: Id of the document, taken from the URL path.
        current_user: The authenticated user, resolved by ``get_current_user``.
        db: Request-scoped async database session.
    """
    service = DocumentService(db)
    owner_id = current_user.id
    await service.delete_document(owner_id, document_id)
|
||
|
||
|
||
@router.post("/search")
async def search_documents(
    query: str,
    top_k: int = 5,
    mode: str = "hybrid",
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Search the knowledge base.

    - query: the search query
    - top_k: number of results to return, default 5
    - mode: hybrid / semantic / keyword; any other value falls back to hybrid
    """
    kb_svc = KnowledgeService(db, user_id=current_user.id)
    user_id = current_user.id

    # Pick the search coroutine for the requested mode, then await it once.
    if mode == "semantic":
        pending = kb_svc.retrieve(query, user_id, top_k, use_rerank=True)
    elif mode == "keyword":
        pending = kb_svc._keyword_search(query, user_id, top_k)
    else:
        pending = kb_svc.hybrid_search(query, user_id, top_k)

    hits = await pending
    # Results are dataclass instances; convert them to plain dicts for the response.
    return list(map(asdict, hits))
|
||
|
||
|
||
@router.get("/{document_id}/content")
async def get_document_content(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a document's text content (intended for AI consumption).

    Args:
        document_id: Id of the document, taken from the URL path.
        current_user: The authenticated user, resolved by ``get_current_user``.
        db: Request-scoped async database session.

    Returns:
        A dict of the form ``{"content": <text>}``.

    Raises:
        HTTPException: 404 when the service returns ``None`` for this document.
    """
    text = await DocumentService(db).get_document_content(current_user.id, document_id)
    if text is not None:
        return {"content": text}
    raise HTTPException(status_code=404, detail="文档不存在或无内容")
|