2026-03-21 10:13:29 +08:00
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Form
|
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
from app.database import get_db
|
|
|
|
|
|
from app.models.document import Document, DocumentChunk
|
|
|
|
|
|
from app.models.user import User
|
|
|
|
|
|
from app.routers.auth import get_current_user
|
|
|
|
|
|
from app.services.document_service import DocumentService
|
|
|
|
|
|
from app.services.knowledge_service import KnowledgeService
|
2026-03-22 13:42:16 +08:00
|
|
|
|
from app.schemas.document import DocumentChunkOut, DocumentChunkUpdate, DocumentOut
|
2026-03-21 10:13:29 +08:00
|
|
|
|
from dataclasses import asdict
|
|
|
|
|
|
|
|
|
|
|
|
# Router for the knowledge-base document endpoints; every route below is
# mounted under /api/documents and grouped under one OpenAPI tag.
router = APIRouter(
    tags=["知识库"],
    prefix="/api/documents",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-22 13:42:16 +08:00
|
|
|
|
@router.get("", response_model=list[DocumentOut])
async def list_documents(
    folder_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's documents, newest first.

    Args:
        folder_id: Optional folder filter. A falsy value (None or "") means
            "all folders".
        current_user: Authenticated user; only their documents are returned.
        db: Request-scoped async database session.

    Returns:
        The matching ``Document`` rows ordered by ``created_at`` descending.
    """
    conditions = [Document.user_id == current_user.id]
    if folder_id:
        conditions.append(Document.folder_id == folder_id)

    # Multiple criteria passed to .where() are combined with AND, exactly like
    # chaining .where() calls.
    stmt = (
        select(Document)
        .where(*conditions)
        .order_by(Document.created_at.desc())
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/upload", status_code=201)
async def upload_document(
    background: BackgroundTasks,
    file: UploadFile = File(...),
    folder_id: Optional[str] = Form(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload a document; it is chunked and then indexed in the background.

    ``DocumentService.upload_document`` stores the document. Indexing into
    ChromaDB (via ``KnowledgeService.index_document``) is scheduled as a
    background task, so it runs after the response has been returned.

    Args:
        background: FastAPI background-task registry for this request.
        file: The uploaded file.
        folder_id: Optional folder to place the document in.
        current_user: Authenticated uploader.
        db: Request-scoped async database session.

    Returns:
        A dict with the new document's ``id``, ``title``, ``chunk_count`` and
        a status message.

    Raises:
        HTTPException: 400 when the document service raises ``ValueError``.
    """
    doc_svc = DocumentService(db)
    try:
        doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)
    except ValueError as error:
        raise HTTPException(status_code=400, detail=str(error)) from error

    # Snapshot the identifiers as plain values while the request is still
    # active. The background task runs after the response, in another thread,
    # and should not rely on request-scoped ORM instances being loadable then.
    user_id = current_user.id
    document_id = doc.id

    def index_task():
        """Index the uploaded document into ChromaDB in the background.

        Starlette runs plain (sync) background callables in a worker thread,
        so ``asyncio.run`` starts a fresh event loop in that thread.
        """
        import asyncio

        from app.database import async_session

        async def _index():
            async with async_session() as session:
                kb_svc = KnowledgeService(session, user_id=user_id)
                await kb_svc.index_document(document_id, user_id=user_id)

        # NOTE(review): ``async_session`` presumably shares the app-wide async
        # engine. Drivers whose pooled connections are bound to the event loop
        # that created them (e.g. asyncpg) may fail when used from this second
        # loop. Confirm against the configured driver; if so, make this task
        # ``async def`` (awaited on the app loop) and offload blocking
        # embedding work inside the service instead.
        asyncio.run(_index())

    background.add_task(index_task)

    return {"id": doc.id, "title": doc.title, "chunk_count": doc.chunk_count, "status": "上传成功,正在索引..."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{document_id}")
async def get_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single document, provided it belongs to the current user.

    Raises:
        HTTPException: 404 if no document with ``document_id`` is owned by
            the current user.
    """
    owned_doc_stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    found = (await db.execute(owned_doc_stmt)).scalar_one_or_none()

    if found:
        return found
    raise HTTPException(status_code=404, detail="文档不存在")
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-22 13:42:16 +08:00
|
|
|
|
@router.get("/{document_id}/chunks", response_model=list[DocumentChunkOut])
async def get_document_chunks(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return every chunk of a document owned by the current user.

    Chunks are returned in ascending ``chunk_index`` order.

    Raises:
        HTTPException: 404 if the document does not exist or belongs to
            another user.
    """
    # Ownership check first, so chunks of other users' documents are never
    # exposed.
    ownership = await db.execute(
        select(Document).where(
            Document.id == document_id,
            Document.user_id == current_user.id,
        )
    )
    if not ownership.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="文档不存在")

    chunk_stmt = (
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    chunk_rows = await db.execute(chunk_stmt)
    return chunk_rows.scalars().all()
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-22 13:42:16 +08:00
|
|
|
|
@router.put("/{document_id}/chunks/{chunk_id}", response_model=DocumentChunkOut)
async def update_document_chunk(
    document_id: str,
    chunk_id: str,
    payload: DocumentChunkUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Edit a chunk's content, reindex the document, and return the chunk.

    Args:
        document_id: Document that owns the chunk.
        chunk_id: Chunk to edit.
        payload: Request body carrying the new ``content``.
        current_user: Authenticated user; passed to the services for the
            ownership scope.
        db: Request-scoped async database session (shared by both services).

    Returns:
        The chunk as re-read from the database after reindexing.

    Raises:
        HTTPException: 404 when ``DocumentService.update_document_chunk``
            raises ``ValueError``; 500 when reindexing reports failure.
    """
    doc_svc = DocumentService(db)
    kb_svc = KnowledgeService(db, user_id=current_user.id)

    try:
        chunk = await doc_svc.update_document_chunk(current_user.id, document_id, chunk_id, payload.content)
    except ValueError as error:
        raise HTTPException(status_code=404, detail=str(error)) from error

    reindexed = await kb_svc.reindex_document_chunks(document_id, current_user.id)
    if not reindexed:
        raise HTTPException(status_code=500, detail="切片更新后重新索引失败")

    # A plain SELECT hands back the instance already held in the session's
    # identity map without overwriting its loaded attributes. If reindexing
    # changed columns outside the ORM unit of work, those would be returned
    # stale. populate_existing forces the row to be reloaded from the database.
    refreshed_chunk_result = await db.execute(
        select(DocumentChunk)
        .where(DocumentChunk.id == chunk.id)
        .execution_options(populate_existing=True)
    )
    return refreshed_chunk_result.scalar_one()
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-21 10:13:29 +08:00
|
|
|
|
@router.delete("/{document_id}", status_code=204)
async def delete_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's documents.

    All work is delegated to ``DocumentService.delete_document``. The handler
    returns nothing, so the client gets an empty 204 response.
    """
    await DocumentService(db).delete_document(current_user.id, document_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/search")
async def search_documents(
    query: str,
    top_k: int = 5,
    mode: str = "hybrid",
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Search the current user's knowledge base.

    Args:
        query: The search text.
        top_k: Number of results to return (default 5).
        mode: "hybrid" (default), "semantic" (retrieval with reranking) or
            "keyword". Any other value falls back to hybrid search.
        current_user: Authenticated user whose documents are searched.
        db: Request-scoped async database session.

    Returns:
        A list of dicts, one per hit, produced with ``dataclasses.asdict``
        (so the service presumably returns dataclass instances).
    """
    kb_svc = KnowledgeService(db, user_id=current_user.id)
    user_id = current_user.id

    if mode == "semantic":
        hits = await kb_svc.retrieve(query, user_id, top_k=top_k, use_rerank=True)
    elif mode == "keyword":
        hits = await kb_svc._keyword_search(query, user_id, top_k)
    else:
        hits = await kb_svc.hybrid_search(query, user_id, top_k)

    return [asdict(hit) for hit in hits]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{document_id}/content")
async def get_document_content(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a document's text content (used for AI understanding).

    Args:
        document_id: Document to read.
        current_user: Authenticated user; passed to the service for the
            ownership scope.
        db: Request-scoped async database session.

    Returns:
        ``{"content": <text>}`` using the value from
        ``DocumentService.get_document_content``.

    Raises:
        HTTPException: 404 when the service returns ``None`` (document missing
            or without content).
    """
    # DocumentService is already imported at module level; the former
    # function-local re-import was redundant.
    doc_svc = DocumentService(db)
    content = await doc_svc.get_document_content(current_user.id, document_id)

    if content is None:
        raise HTTPException(status_code=404, detail="文档不存在或无内容")

    return {"content": content}
|