Files
JARVIS/backend/app/routers/document.py

155 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Form
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models.document import Document, DocumentChunk
from app.models.user import User
from app.routers.auth import get_current_user
from app.services.document_service import DocumentService
from app.services.knowledge_service import KnowledgeService
from dataclasses import asdict
# Every endpoint below is mounted under /api/documents and grouped in the
# OpenAPI docs under the "知识库" (knowledge base) tag.
router = APIRouter(
    prefix="/api/documents",
    tags=["知识库"],
)
@router.get("", response_model=list)
async def list_documents(
    folder_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the caller's documents, newest first.

    Args:
        folder_id: Optional folder filter. A falsy value (None or "") means
            "documents in every folder".

    Returns:
        The matching ``Document`` rows ordered by ``created_at`` descending.
    """
    filters = [Document.user_id == current_user.id]
    if folder_id:
        filters.append(Document.folder_id == folder_id)

    stmt = select(Document).where(*filters).order_by(Document.created_at.desc())
    rows = await db.execute(stmt)
    return rows.scalars().all()
@router.post("/upload", status_code=201)
async def upload_document(
    background: BackgroundTasks,
    file: UploadFile = File(...),
    folder_id: Optional[str] = Form(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload a document and schedule its vector indexing in the background.

    The document is stored via ``DocumentService.upload_document`` (which
    reports a ``chunk_count``). Indexing through ``KnowledgeService`` runs
    after the response is sent, using its own database session.

    Args:
        background: FastAPI background-task collector.
        file: The uploaded file.
        folder_id: Optional target folder.

    Returns:
        A dict with the new document's ``id``, ``title``, ``chunk_count`` and a
        status message saying indexing is in progress.
    """
    doc_svc = DocumentService(db)
    doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)

    # Capture plain values now. The request-scoped session (and the ORM
    # objects attached to it) may already be closed or expired when the
    # background task runs.
    user_id = current_user.id
    doc_id = doc.id

    async def index_task() -> None:
        """Index the uploaded document (ChromaDB per the original comment)."""
        from app.database import async_session

        async with async_session() as session:
            kb_svc = KnowledgeService(session, user_id=user_id)
            await kb_svc.index_document(doc_id, user_id=user_id)

    # An async task is awaited on the application's own event loop. A sync
    # wrapper calling asyncio.run() would instead run in a worker thread on a
    # *second* loop while sharing the app's async engine/connection pool,
    # which is bound to the main loop.
    # NOTE(review): if index_document does long synchronous (CPU-bound) work,
    # offload that inside the service (e.g. asyncio.to_thread) so it does not
    # stall this loop.
    background.add_task(index_task)
    return {"id": doc.id, "title": doc.title, "chunk_count": doc.chunk_count, "status": "上传成功,正在索引..."}
@router.get("/{document_id}")
async def get_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one document owned by the caller.

    Raises:
        HTTPException: 404 when no document with this id belongs to the caller.
    """
    # Filtering on user_id as well as id makes another user's document
    # indistinguishable from a missing one: both yield 404.
    stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    found = (await db.execute(stmt)).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="文档不存在")
    return found
@router.get("/{document_id}/chunks")
async def get_document_chunks(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all chunks of a caller-owned document, ordered by chunk_index.

    Raises:
        HTTPException: 404 if the document does not exist or belongs to
            another user.
    """
    # The chunk query below filters only on document_id, so ownership must be
    # verified against Document first.
    ownership = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    if (await db.execute(ownership)).scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="文档不存在")

    ordered_chunks = (
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    chunk_rows = await db.execute(ordered_chunks)
    return chunk_rows.scalars().all()
@router.delete("/{document_id}", status_code=204)
async def delete_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the caller's documents; responds 204 with no body.

    Ownership checks and not-found handling are left to
    ``DocumentService.delete_document`` (not visible here; presumably it
    scopes by the user id passed in — confirm).
    """
    await DocumentService(db).delete_document(current_user.id, document_id)
@router.post("/search")
async def search_documents(
    query: str,
    top_k: int = 5,
    mode: str = "hybrid",
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Search the caller's knowledge base.

    Args:
        query: Search text.
        top_k: Maximum number of results (default 5).
        mode: ``"keyword"`` for keyword search, ``"semantic"`` for retrieval
            with reranking; any other value (including the default
            ``"hybrid"``) runs hybrid search.

    Returns:
        A list of dicts, one per result, produced with ``dataclasses.asdict``.
    """
    uid = current_user.id
    kb_svc = KnowledgeService(db, user_id=uid)

    if mode == "semantic":
        hits = await kb_svc.retrieve(query, uid, top_k, use_rerank=True)
    elif mode == "keyword":
        # NOTE(review): relies on a private KnowledgeService method.
        hits = await kb_svc._keyword_search(query, uid, top_k)
    else:
        hits = await kb_svc.hybrid_search(query, uid, top_k)

    return list(map(asdict, hits))
@router.get("/{document_id}/content")
async def get_document_content(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a document's text content, for the AI to read.

    Returns:
        ``{"content": <text>}`` as produced by
        ``DocumentService.get_document_content``.

    Raises:
        HTTPException: 404 when the service returns None (per the error
            message: the document does not exist or has no content).
    """
    doc_svc = DocumentService(db)
    content = await doc_svc.get_document_content(current_user.id, document_id)
    if content is None:
        raise HTTPException(status_code=404, detail="文档不存在或无内容")
    return {"content": content}