Files
JARVIS/backend/app/services/knowledge_service.py

309 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
检索策略:
1. 语义检索 (dense) - ChromaDB 向量相似度
2. 关键词检索 (sparse) - SQL LIKE
3. 混合检索 - 语义 + 关键词 加权融合
4. Rerank - 二次排序优化结果
5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境
"""
import logging
import re
from dataclasses import dataclass

import chromadb
from chromadb.config import Settings as ChromaSettings
from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder


@dataclass
class SearchResult:
chunk_id: str
document_id: str
document_title: str
content: str
score: float
metadata_: str | None = None
prev_chunk: str | None = None
next_chunk: str | None = None
class KnowledgeService:
    """Vector knowledge-base retrieval service.

    Chunks are stored in a per-user ChromaDB collection. Retrieval supports
    semantic (vector) search, SQL keyword search, a hybrid of both, a
    lightweight lexical rerank, and enrichment of each hit with the text of
    its neighbouring chunks.
    """

    _log = logging.getLogger(__name__)

    # Word-character runs; the base tokenization used for overlap scoring.
    _WORD_RE = re.compile(r"\w+")
    # Runs of CJK ideographs (Unified Ideographs + Extension A). Such text is
    # written without spaces, so a plain word split would yield one huge token.
    _CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff]+")

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        # Stored for callers; the methods below take user_id explicitly.
        self.user_id = user_id
        self._chroma_client = None

    @property
    def chroma_client(self):
        """Lazily create and cache the persistent ChromaDB client."""
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(
                path=settings.CHROMA_PERSIST_DIR,
                settings=ChromaSettings(allow_reset=True),
            )
        return self._chroma_client

    def get_collection(self, user_id: str):
        """Return (creating it if needed) the collection named ``user_<user_id>``."""
        return self.chroma_client.get_or_create_collection(
            name=f"user_{user_id}",
            metadata={"user_id": user_id},
        )

    async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
        """Embed a document's chunks into the user's collection and mark it indexed.

        Vectors previously stored for the document are deleted first, so
        re-indexing neither collides on existing chunk ids nor leaves stale
        chunks behind after re-chunking. Returns without doing anything when
        the document does not exist or has no chunks.

        Args:
            document_id: id of the ``Document`` to index.
            user_id: owner of the document; selects the collection.
            folder_path: stored as ``folder_path`` metadata for folder
                filtering; expected in the ``/a/b`` form that
                ``_get_folder_path`` builds (TODO confirm with callers).
        """
        result = await self.db.execute(
            select(Document).where(Document.id == document_id)
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return

        chunks_result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        chunks = list(chunks_result.scalars().all())
        if not chunks:
            return

        collection = self.get_collection(user_id)
        # Remove vectors from any earlier indexing run of this document.
        collection.delete(where={"document_id": doc.id})
        collection.add(
            ids=[chunk.id for chunk in chunks],
            documents=[chunk.content for chunk in chunks],
            metadatas=[
                {
                    "document_id": doc.id,
                    # Chroma metadata values may not be None.
                    "document_title": doc.title or "",
                    "chunk_index": chunk.chunk_index,
                    "file_type": doc.file_type or "",
                    "folder_path": folder_path or "",
                }
                for chunk in chunks
            ],
        )
        doc.is_indexed = True
        await self.db.commit()

    async def retrieve(
        self,
        query: str,
        user_id: str,
        folder_id: str | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
    ) -> list[SearchResult]:
        """Vector retrieval with optional rerank and folder filtering.

        Steps:
            1. Query ChromaDB for ``top_k * 3`` candidates. When *folder_id*
               resolves to a folder, only that folder and its descendants are
               searched. If it does not resolve, no folder filter is applied.
            2. Rerank (or simply truncate) the candidates to ``top_k``.
            3. Attach the previous/next chunk text to each returned hit.

        Returns:
            Up to ``top_k`` results. Returns an empty list when ``top_k <= 0``
            or when the vector query fails (the failure is logged).
        """
        if top_k <= 0:
            return []
        collection = self.get_collection(user_id)

        where = None
        if folder_id:
            folder_path = await self._get_folder_path(folder_id)
            if folder_path:
                # Chroma metadata filters have no prefix operator, so match the
                # folder and every descendant folder by exact path.
                paths = await self._get_subtree_paths(folder_id, folder_path)
                where = {"folder_path": {"$in": paths}}

        try:
            results = collection.query(
                query_texts=[query],
                n_results=top_k * 3,  # over-fetch; rerank picks the final top_k
                where=where,
                include=["documents", "metadatas", "distances"],
            )
        except Exception:
            self._log.exception("Vector query failed for user %s", user_id)
            return []

        if not results or not results.get("ids"):
            return []

        ids = results["ids"][0]
        documents = (results.get("documents") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        candidates: list[SearchResult] = []
        chunk_indexes: dict[str, int] = {}
        for i, chunk_id in enumerate(ids):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            distance = distances[i] if i < len(distances) else 0.0
            chunk_indexes[chunk_id] = meta.get("chunk_index", 0)
            candidates.append(SearchResult(
                chunk_id=chunk_id,
                document_id=meta.get("document_id", ""),
                document_title=meta.get("document_title", ""),
                content=(documents[i] if i < len(documents) else None) or "",
                # NOTE(review): simple distance->similarity mapping; not bounded
                # to [0, 1] for Chroma's default (L2) space.
                score=1.0 - distance,
                metadata_=str(meta),
            ))

        if use_rerank:
            top = self._rerank(query, candidates, top_k)
        else:
            top = candidates[:top_k]

        # Neighbouring-chunk context is only fetched for hits actually returned.
        for hit in top:
            hit.prev_chunk, hit.next_chunk = await self._get_sibling_chunks(
                chunk_id=hit.chunk_id,
                chunk_index=chunk_indexes.get(hit.chunk_id, 0),
                document_id=hit.document_id,
            )
        return top

    @classmethod
    def _tokenize(cls, text: str) -> set[str]:
        """Return the lower-cased lexical tokens of *text* used for overlap scoring.

        Word-character runs without CJK ideographs are kept whole. Runs of CJK
        ideographs are split into overlapping character bigrams (a lone
        ideograph stays a unigram), and any non-CJK remainder of such a word
        is kept as its own token.
        """
        tokens: set[str] = set()
        for word in cls._WORD_RE.findall(text.lower()):
            if not cls._CJK_RE.search(word):
                tokens.add(word)
                continue
            tokens.update(part for part in cls._CJK_RE.split(word) if part)
            for run in cls._CJK_RE.findall(word):
                if len(run) == 1:
                    tokens.add(run)
                else:
                    tokens.update(run[j:j + 2] for j in range(len(run) - 1))
        return tokens

    def _rerank(
        self,
        query: str,
        results: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Re-order *results* and keep the best ``top_k``.

        Combined score = 0.7 * semantic score + 0.2 * fraction of query tokens
        found in the content + 0.1 * fraction of query tokens found in the
        title. The result objects themselves are not modified.
        """
        query_tokens = self._tokenize(query)
        denom = max(len(query_tokens), 1)
        scored = []
        for r in results:
            score = r.score * 0.7
            score += len(query_tokens & self._tokenize(r.content)) / denom * 0.2
            if r.document_title:
                score += len(query_tokens & self._tokenize(r.document_title)) / denom * 0.1
            scored.append((score, r))
        scored.sort(key=lambda x: x[0], reverse=True)
        return [r for _, r in scored[:top_k]]

    async def _get_sibling_chunks(
        self,
        chunk_id: str,
        chunk_index: int,
        document_id: str,
    ) -> tuple[str | None, str | None]:
        """Return the content of the chunks at ``chunk_index - 1`` and ``chunk_index + 1``.

        Either element is None when that neighbour does not exist. *chunk_id*
        is not used; it is kept for signature compatibility.
        """
        result = await self.db.execute(
            select(DocumentChunk.chunk_index, DocumentChunk.content).where(
                DocumentChunk.document_id == document_id,
                DocumentChunk.chunk_index.in_([chunk_index - 1, chunk_index + 1]),
            )
        )
        by_index = {idx: content for idx, content in result.all()}
        return by_index.get(chunk_index - 1), by_index.get(chunk_index + 1)

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Build the absolute path (``/root/.../name``) of a folder, or None if it does not exist.

        Walks ``parent_id`` links upward. The walk stops at a missing parent
        or at a parent-link cycle.
        """
        result = await self.db.execute(
            select(Folder).where(Folder.id == folder_id)
        )
        folder = result.scalar_one_or_none()
        if not folder:
            return None

        path_parts = [folder.name]
        visited = {folder.id}
        parent_id = folder.parent_id
        while parent_id and parent_id not in visited:
            visited.add(parent_id)
            parent_result = await self.db.execute(
                select(Folder).where(Folder.id == parent_id)
            )
            parent = parent_result.scalar_one_or_none()
            if not parent:
                break
            path_parts.append(parent.name)
            parent_id = parent.parent_id
        return "/" + "/".join(reversed(path_parts))

    async def _get_subtree_paths(self, folder_id: str, folder_path: str) -> list[str]:
        """Return *folder_path* plus the paths of all descendant folders.

        Child paths are built as ``<parent path>/<child name>``, matching
        ``_get_folder_path``. The tree is walked one depth level per query,
        and already-seen folders are skipped so a cycle cannot loop forever.
        """
        paths = [folder_path]
        seen = {str(folder_id)}
        level = [(folder_id, folder_path)]
        while level:
            parent_paths = {str(fid): path for fid, path in level}
            result = await self.db.execute(
                select(Folder).where(Folder.parent_id.in_([fid for fid, _ in level]))
            )
            level = []
            for child in result.scalars().all():
                key = str(child.id)
                if key in seen:
                    continue
                seen.add(key)
                child_path = f"{parent_paths[str(child.parent_id)]}/{child.name}"
                paths.append(child_path)
                level.append((child.id, child_path))
        return paths

    async def hybrid_search(
        self,
        query: str,
        user_id: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Hybrid search: vector hits + keyword hits, de-duplicated, then reranked.

        When a chunk is found by both searches, the vector hit (which carries
        the semantic score) is kept.
        """
        vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
        keyword_results = await self._keyword_search(query, user_id, top_k)

        seen = set()
        merged: list[SearchResult] = []
        for r in vector_results + keyword_results:
            if r.chunk_id not in seen:
                seen.add(r.chunk_id)
                merged.append(r)
        return self._rerank(query, merged, top_k)

    async def _keyword_search(
        self,
        query: str,
        user_id: str,
        top_k: int,
    ) -> list[SearchResult]:
        """SQL substring search over the user's chunk contents and document titles.

        ``%`` and ``_`` in *query* are matched literally (LIKE wildcards are
        escaped). Hits get a fixed score of 0.5 and include neighbouring-chunk
        context. An empty query or ``top_k <= 0`` returns an empty list.
        """
        if not query or top_k <= 0:
            return []
        result = await self.db.execute(
            select(DocumentChunk, Document.title)
            .join(Document)
            .where(Document.user_id == user_id)
            .where(
                or_(
                    DocumentChunk.content.contains(query, autoescape=True),
                    Document.title.contains(query, autoescape=True),
                )
            )
            .limit(top_k)
        )

        results = []
        for chunk, title in result.all():
            prev_chunk, next_chunk = await self._get_sibling_chunks(
                chunk_id=chunk.id,
                chunk_index=chunk.chunk_index,
                document_id=chunk.document_id,
            )
            results.append(SearchResult(
                chunk_id=chunk.id,
                document_id=chunk.document_id,
                document_title=title or "",
                content=chunk.content,
                score=0.5,
                metadata_=None,
                prev_chunk=prev_chunk,
                next_chunk=next_chunk,
            ))
        return results

    async def delete_from_vectorstore(self, user_id: str, document_id: str):
        """Remove all vectors of *document_id* from the user's collection.

        Best-effort: failures are logged rather than raised.
        """
        collection = self.get_collection(user_id)
        try:
            collection.delete(where={"document_id": document_id})
        except Exception:
            self._log.exception(
                "Failed to delete vectors of document %s for user %s", document_id, user_id
            )