Normalize uploaded documents into structured markdown, add clearer parser errors for missing dependencies, and cover the ingestion flow with backend tests. This also replaces deprecated UTC timestamp helpers in the touched backend paths so the knowledge pipeline stays warning-free. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
186 lines
6.1 KiB
Python
186 lines
6.1 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Form
|
||
from typing import Optional
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from app.database import get_db
|
||
from app.models.document import Document, DocumentChunk
|
||
from app.models.user import User
|
||
from app.routers.auth import get_current_user
|
||
from app.services.document_service import DocumentService
|
||
from app.services.knowledge_service import KnowledgeService
|
||
from app.schemas.document import DocumentChunkOut, DocumentChunkUpdate, DocumentOut
|
||
from dataclasses import asdict
|
||
|
||
# Router for the document endpoints; every route below is served under /api/documents.
router = APIRouter(prefix="/api/documents", tags=["知识库"])
|
||
|
||
|
||
@router.get("", response_model=list[DocumentOut])
async def list_documents(
    folder_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's documents, newest first.

    Args:
        folder_id: Optional folder filter. When missing or empty, documents
            from every folder are returned.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session.

    Returns:
        The matching ``Document`` rows ordered by ``created_at`` descending.
    """
    # Ownership is always enforced; the folder filter is added only on request.
    criteria = [Document.user_id == current_user.id]
    if folder_id:
        criteria.append(Document.folder_id == folder_id)

    stmt = select(Document).where(*criteria).order_by(Document.created_at.desc())
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
||
|
||
|
||
@router.post("/upload", status_code=201)
async def upload_document(
    background: BackgroundTasks,
    file: UploadFile = File(...),
    folder_id: Optional[str] = Form(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload a document; it is chunked and vectorized automatically.

    The document is stored through ``DocumentService`` and a background task
    is then scheduled to index it with ``KnowledgeService``.

    Args:
        background: FastAPI background-task queue for this request.
        file: The uploaded file.
        folder_id: Optional target folder, sent as a form field.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session used for the upload itself.

    Returns:
        A dict with the new document's ``id``, ``title``, ``chunk_count`` and
        a ``status`` message.

    Raises:
        HTTPException: 400 when the document service rejects the upload with
            a ``ValueError``.
    """
    doc_svc = DocumentService(db)
    try:
        doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)
    except ValueError as error:
        raise HTTPException(status_code=400, detail=str(error)) from error

    # Snapshot plain ids now. The background task runs after this handler has
    # returned, so it should not reach back into ORM instances that belong to
    # the request-scoped session.
    doc_id = doc.id
    user_id = current_user.id

    # Index into ChromaDB in the background.
    def index_task():
        import asyncio
        from app.database import async_session

        async def _index():
            # Use a fresh session: the request's session is not meant to
            # outlive the request.
            async with async_session() as session:
                kb_svc = KnowledgeService(session, user_id=user_id)
                await kb_svc.index_document(doc_id, user_id=user_id)

        # NOTE(review): this is a sync task, so it is expected to run outside
        # the server's event loop; asyncio.run() starts a dedicated loop for it.
        # Confirm the engine behind async_session tolerates use from a second loop.
        asyncio.run(_index())

    background.add_task(index_task)
    return {"id": doc.id, "title": doc.title, "chunk_count": doc.chunk_count, "status": "上传成功,正在索引..."}
|
||
|
||
|
||
@router.get("/{document_id}")
async def get_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single document owned by the current user.

    Args:
        document_id: Identifier of the document to fetch.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session.

    Returns:
        The matching ``Document`` row.

    Raises:
        HTTPException: 404 when no document with this id belongs to the user.
    """
    # NOTE(review): unlike the list endpoint, this route declares no
    # response_model, so the ORM row is serialized as-is — confirm intended.
    owned_document = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    document = (await db.execute(owned_document)).scalar_one_or_none()
    if not document:
        raise HTTPException(status_code=404, detail="文档不存在")
    return document
|
||
|
||
|
||
@router.get("/{document_id}/chunks", response_model=list[DocumentChunkOut])
async def get_document_chunks(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all chunks of a document, ordered by chunk index.

    Args:
        document_id: Identifier of the parent document.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session.

    Returns:
        The document's ``DocumentChunk`` rows sorted by ``chunk_index``.

    Raises:
        HTTPException: 404 when no document with this id belongs to the user.
    """
    # Check ownership first so chunks of another user's document never leak.
    ownership_stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    parent = (await db.execute(ownership_stmt)).scalar_one_or_none()
    if not parent:
        raise HTTPException(status_code=404, detail="文档不存在")

    chunk_stmt = select(DocumentChunk).where(DocumentChunk.document_id == document_id)
    chunk_stmt = chunk_stmt.order_by(DocumentChunk.chunk_index)
    chunk_rows = await db.execute(chunk_stmt)
    return chunk_rows.scalars().all()
|
||
|
||
|
||
@router.put("/{document_id}/chunks/{chunk_id}", response_model=DocumentChunkOut)
async def update_document_chunk(
    document_id: str,
    chunk_id: str,
    payload: DocumentChunkUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Replace a chunk's content and re-index the owning document.

    Args:
        document_id: Identifier of the parent document.
        chunk_id: Identifier of the chunk to update.
        payload: Request body carrying the new ``content``.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session.

    Returns:
        The chunk, re-read from the database after re-indexing.

    Raises:
        HTTPException: 404 when the document service raises ``ValueError``;
            500 when re-indexing reports failure.
    """
    documents = DocumentService(db)
    knowledge = KnowledgeService(db, user_id=current_user.id)

    try:
        updated = await documents.update_document_chunk(
            current_user.id,
            document_id,
            chunk_id,
            payload.content,
        )
    except ValueError as error:
        raise HTTPException(status_code=404, detail=str(error)) from error

    if not await knowledge.reindex_document_chunks(document_id, current_user.id):
        raise HTTPException(status_code=500, detail="切片更新后重新索引失败")

    # Read the row again so the response reflects the state after re-indexing.
    lookup = select(DocumentChunk).where(DocumentChunk.id == updated.id)
    return (await db.execute(lookup)).scalar_one()
|
||
|
||
|
||
@router.delete("/{document_id}", status_code=204)
async def delete_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a document.

    Args:
        document_id: Identifier of the document to delete.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session.

    Raises:
        HTTPException: 404 when the document service raises ``ValueError``.
    """
    doc_svc = DocumentService(db)
    try:
        await doc_svc.delete_document(current_user.id, document_id)
    except ValueError as error:
        # Match the sibling endpoints: a service-level ValueError becomes a
        # client error instead of an unhandled 500.
        # NOTE(review): assumes delete_document signals "not found" with
        # ValueError, as update_document_chunk does — confirm in DocumentService.
        raise HTTPException(status_code=404, detail=str(error)) from error
|
||
|
||
|
||
@router.post("/search")
async def search_documents(
    query: str,
    top_k: int = 5,
    mode: str = "hybrid",
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Search the knowledge base.

    - query: the search query
    - top_k: number of results to return, default 5
    - mode: hybrid / semantic / keyword
    """
    knowledge = KnowledgeService(db, user_id=current_user.id)
    owner_id = current_user.id

    # Pick the search strategy; any mode other than "keyword" or "semantic"
    # falls through to hybrid search.
    if mode == "keyword":
        # NOTE(review): this reaches into a private method of KnowledgeService.
        pending = knowledge._keyword_search(query, owner_id, top_k)
    elif mode == "semantic":
        pending = knowledge.retrieve(query, owner_id, top_k=top_k, use_rerank=True)
    else:
        pending = knowledge.hybrid_search(query, owner_id, top_k)

    hits = await pending
    # Results are dataclass instances; convert each to a plain dict.
    return list(map(asdict, hits))
|
||
|
||
|
||
@router.get("/{document_id}/content")
async def get_document_content(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a document's text content (used for AI understanding).

    Args:
        document_id: Identifier of the document.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session.

    Returns:
        A dict of the form ``{"content": <text>}``.

    Raises:
        HTTPException: 404 when the service returns ``None`` for this
            document (reported as "not found or no content").
    """
    # DocumentService is already imported at module level, so no local import
    # is needed here.
    doc_svc = DocumentService(db)
    content = await doc_svc.get_document_content(current_user.id, document_id)

    if content is None:
        raise HTTPException(status_code=404, detail="文档不存在或无内容")

    return {"content": content}
|