309 lines
9.8 KiB
Python
309 lines
9.8 KiB
Python
"""
|
||
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
|
||
|
||
检索策略:
|
||
1. 语义检索 (dense) - ChromaDB 向量相似度
|
||
2. 关键词检索 (sparse) - SQL LIKE
|
||
3. 混合检索 - 语义 + 关键词 加权融合
|
||
4. Rerank - 二次排序优化结果
|
||
5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境
|
||
"""
|
||
|
||
import asyncio
import logging
import re
from dataclasses import dataclass

import chromadb
from chromadb.config import Settings as ChromaSettings
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
|
||
|
||
|
||
@dataclass
|
||
class SearchResult:
|
||
chunk_id: str
|
||
document_id: str
|
||
document_title: str
|
||
content: str
|
||
score: float
|
||
metadata_: str | None = None
|
||
prev_chunk: str | None = None
|
||
next_chunk: str | None = None
|
||
|
||
|
||
class KnowledgeService:
    """Vector knowledge-base retrieval service (ChromaDB + SQL).

    Each user has a dedicated Chroma collection named ``user_<user_id>``.
    Chunk text is read from the SQL ``DocumentChunk`` table and embedded into
    that collection together with the metadata needed to map hits back to
    their document, position and folder.
    """

    # Vector search fetches ``top_k * multiplier`` candidates so that rerank
    # (and folder filtering) have a pool to choose from.
    _CANDIDATE_MULTIPLIER = 3
    # Chroma metadata filters have no prefix operator, so folder scoping is
    # done client-side; fetch a larger pool so enough hits survive the filter.
    _FOLDER_CANDIDATE_MULTIPLIER = 10

    # Rerank tokenizer: each CJK ideograph is its own token (CJK text has no
    # spaces, so a plain ``\w+`` would turn a whole sentence into one token);
    # every other run of word characters is one token.
    _TOKEN_RE = re.compile(
        r"[\u3400-\u4dbf\u4e00-\u9fff]|[^\W\u3400-\u4dbf\u4e00-\u9fff]+"
    )

    _log = logging.getLogger(__name__)

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        """Bind the service to an async DB session (and optionally a user id)."""
        self.db = db
        self.user_id = user_id
        self._chroma_client = None

    @property
    def chroma_client(self):
        """Lazily create the persistent Chroma client on first access."""
        if self._chroma_client is None:
            # NOTE(review): allow_reset=True permits client.reset(), which
            # wipes all stored data -- confirm this is intended outside tests.
            self._chroma_client = chromadb.PersistentClient(
                path=settings.CHROMA_PERSIST_DIR,
                settings=ChromaSettings(allow_reset=True),
            )
        return self._chroma_client

    def get_collection(self, user_id: str):
        """Return (creating it if needed) the Chroma collection of *user_id*."""
        return self.chroma_client.get_or_create_collection(
            name=f"user_{user_id}",
            metadata={"user_id": user_id},
        )

    async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
        """Embed a document's chunks into the user's Chroma collection.

        Re-indexing is safe: chunks are upserted, and vectors of chunks that no
        longer exist for this document are removed. On success the document is
        marked ``is_indexed`` and the session is committed. Nothing happens if
        the document does not exist or has no chunks.

        Args:
            document_id: ID of the ``Document`` to index.
            user_id: Owner whose collection receives the vectors.
            folder_path: Folder path stored in every chunk's metadata (used by
                folder-scoped retrieval); ``None`` is stored as ``""``.
        """
        result = await self.db.execute(
            select(Document).where(Document.id == document_id)
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return

        chunks_result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        chunks = list(chunks_result.scalars().all())
        if not chunks:
            return

        doc_id = doc.id
        ids = [chunk.id for chunk in chunks]
        documents = [chunk.content for chunk in chunks]
        # Chroma rejects None metadata values, so nullable columns fall back to "".
        metadatas = [
            {
                "document_id": doc_id,
                "document_title": doc.title or "",
                "chunk_index": chunk.chunk_index,
                "file_type": doc.file_type or "",
                "folder_path": folder_path or "",
            }
            for chunk in chunks
        ]
        new_ids = set(ids)

        def write_vectors() -> None:
            collection = self.get_collection(user_id)
            # upsert (not add): re-indexing updates existing vectors instead of
            # failing or silently keeping stale embeddings.
            collection.upsert(ids=ids, documents=documents, metadatas=metadatas)
            # Drop vectors of chunks that no longer exist (e.g. after re-chunking).
            existing = collection.get(where={"document_id": doc_id}, include=[])
            stale = [cid for cid in existing["ids"] if cid not in new_ids]
            if stale:
                collection.delete(ids=stale)

        # Chroma computes embeddings synchronously; keep it off the event loop.
        await asyncio.to_thread(write_vectors)

        doc.is_indexed = True
        await self.db.commit()

    async def retrieve(
        self,
        query: str,
        user_id: str,
        folder_id: str | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
    ) -> list[SearchResult]:
        """Vector retrieval with optional folder scoping and rerank.

        Flow:
            1. ChromaDB vector search over an enlarged candidate pool.
            2. If *folder_id* resolves to a folder, keep only chunks whose
               stored ``folder_path`` is that folder or lies beneath it.
               (An unknown *folder_id* applies no filter.)
            3. Rerank (or keep vector order when *use_rerank* is False).
            4. Keep the best *top_k* and attach their previous/next chunk text.

        Returns an empty list if the vector query fails (the error is logged).
        """
        folder_prefix = await self._get_folder_path(folder_id) if folder_id else None
        candidates, positions = await self._vector_candidates(
            query, user_id, top_k, folder_prefix
        )

        if use_rerank:
            selected = self._rerank(query, candidates, top_k)
        else:
            selected = candidates[:top_k]

        await self._attach_context(selected, positions)
        return selected

    async def _vector_candidates(
        self,
        query: str,
        user_id: str,
        top_k: int,
        folder_prefix: str | None = None,
    ) -> tuple[list[SearchResult], dict[str, tuple[str, int]]]:
        """Query Chroma and convert the hits to ``SearchResult`` objects.

        Returns the hits in Chroma's order (closest first) and a map
        ``chunk_id -> (document_id, chunk_index)`` used to fetch context later.
        ``score`` is ``1 - distance``. Returns empty results if the query fails.
        """
        if top_k <= 0:
            return [], {}
        multiplier = (
            self._FOLDER_CANDIDATE_MULTIPLIER if folder_prefix else self._CANDIDATE_MULTIPLIER
        )

        collection = await asyncio.to_thread(self.get_collection, user_id)
        try:
            raw = await asyncio.to_thread(
                collection.query,
                query_texts=[query],
                n_results=top_k * multiplier,
                include=["documents", "metadatas", "distances"],
            )
        except Exception:
            self._log.exception("Vector query failed for user %s", user_id)
            return [], {}

        def first_row(key: str) -> list:
            # Chroma returns one row per query text; we sent exactly one.
            rows = raw.get(key) if raw else None
            return list(rows[0] or []) if rows else []

        ids = first_row("ids")
        if not ids:
            return [], {}
        documents = first_row("documents")
        metadatas = first_row("metadatas")
        distances = first_row("distances")

        results: list[SearchResult] = []
        positions: dict[str, tuple[str, int]] = {}
        for i, chunk_id in enumerate(ids):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            if folder_prefix and not self._in_folder(meta.get("folder_path") or "", folder_prefix):
                continue
            document_id = meta.get("document_id", "")
            distance = distances[i] if i < len(distances) else 0.0
            results.append(SearchResult(
                chunk_id=chunk_id,
                document_id=document_id,
                document_title=meta.get("document_title", ""),
                content=documents[i] if i < len(documents) else "",
                score=1.0 - distance,
                metadata_=str(meta),
            ))
            positions[chunk_id] = (document_id, meta.get("chunk_index", 0))
        return results, positions

    @staticmethod
    def _in_folder(path: str, folder_prefix: str) -> bool:
        """Return True if *path* is *folder_prefix* itself or a path beneath it.

        Matching is on path-segment boundaries, so ``/a`` does not match ``/ab``.
        """
        base = folder_prefix.rstrip("/")
        return path in (folder_prefix, base) or path.startswith(base + "/")

    async def _attach_context(
        self,
        results: list[SearchResult],
        positions: dict[str, tuple[str, int]],
    ) -> None:
        """Fill ``prev_chunk``/``next_chunk`` in place for results with a known position."""
        for r in results:
            position = positions.get(r.chunk_id)
            if position is None:
                continue
            document_id, chunk_index = position
            r.prev_chunk, r.next_chunk = await self._get_sibling_chunks(
                chunk_id=r.chunk_id,
                chunk_index=chunk_index,
                document_id=document_id,
            )

    @classmethod
    def _tokenize(cls, text: str) -> set[str]:
        """Lower-case *text* and split it into a set of rerank tokens."""
        return set(cls._TOKEN_RE.findall(text.lower()))

    def _rerank(
        self,
        query: str,
        results: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Rerank: semantic score * 0.7 + keyword overlap * 0.2 + title overlap * 0.1.

        Overlaps are the fraction of query tokens found in the chunk content /
        document title. The ``score`` attribute of each result is not modified.
        The sort is stable, so ties keep their input order. Returns at most
        *top_k* results.
        """
        query_tokens = self._tokenize(query)
        denominator = max(len(query_tokens), 1)

        def rerank_score(r: SearchResult) -> float:
            score = r.score * 0.7
            score += len(query_tokens & self._tokenize(r.content or "")) / denominator * 0.2
            if r.document_title:
                score += len(query_tokens & self._tokenize(r.document_title)) / denominator * 0.1
            return score

        return sorted(results, key=rerank_score, reverse=True)[:top_k]

    async def _get_sibling_chunks(
        self,
        chunk_id: str,
        chunk_index: int,
        document_id: str,
    ) -> tuple[str | None, str | None]:
        """Return the content of the chunks right before and after a chunk.

        Neighbours are located by *document_id* and *chunk_index*; *chunk_id*
        is accepted for API compatibility but not used. Either element is None
        when there is no such neighbour.
        """
        result = await self.db.execute(
            select(DocumentChunk).where(
                DocumentChunk.document_id == document_id,
                DocumentChunk.chunk_index.in_((chunk_index - 1, chunk_index + 1)),
            )
        )
        by_index = {chunk.chunk_index: chunk.content for chunk in result.scalars()}
        return by_index.get(chunk_index - 1), by_index.get(chunk_index + 1)

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the folder's absolute path ``/root/.../name``, or None if unknown.

        Walks ``parent_id`` links upward; stops at a missing parent or at a
        cycle in the parent chain (which would otherwise loop forever).
        """
        result = await self.db.execute(
            select(Folder).where(Folder.id == folder_id)
        )
        folder = result.scalar_one_or_none()
        if not folder:
            return None

        path_parts = [folder.name]  # collected leaf -> root
        visited = {folder.id}
        parent_id = folder.parent_id
        while parent_id and parent_id not in visited:
            visited.add(parent_id)
            parent_result = await self.db.execute(
                select(Folder).where(Folder.id == parent_id)
            )
            parent = parent_result.scalar_one_or_none()
            if not parent:
                break
            path_parts.append(parent.name)
            parent_id = parent.parent_id

        return "/" + "/".join(reversed(path_parts))

    async def hybrid_search(
        self,
        query: str,
        user_id: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Hybrid search: vector + keyword candidates, merged, then reranked.

        Takes the ``top_k * 2`` closest vector hits and up to *top_k* keyword
        hits, de-duplicates by chunk id (first occurrence wins, vector hits
        first), reranks, and attaches previous/next chunk context to the final
        results.
        """
        vector_results, vector_positions = await self._vector_candidates(
            query, user_id, top_k * 2
        )
        vector_results = vector_results[: top_k * 2]
        keyword_results, keyword_positions = await self._keyword_candidates(
            query, user_id, top_k
        )

        merged: dict[str, SearchResult] = {}
        for r in vector_results + keyword_results:
            merged.setdefault(r.chunk_id, r)

        selected = self._rerank(query, list(merged.values()), top_k)
        await self._attach_context(selected, {**keyword_positions, **vector_positions})
        return selected

    async def _keyword_search(
        self,
        query: str,
        user_id: str,
        top_k: int,
    ) -> list[SearchResult]:
        """SQL keyword search; see ``_keyword_candidates`` for the matching rules."""
        results, _ = await self._keyword_candidates(query, user_id, top_k)
        return results

    async def _keyword_candidates(
        self,
        query: str,
        user_id: str,
        top_k: int,
    ) -> tuple[list[SearchResult], dict[str, tuple[str, int]]]:
        """Find up to *top_k* of the user's chunks whose content or document
        title contains *query* as a literal substring.

        Every hit gets a fixed score of 0.5. Also returns the
        ``chunk_id -> (document_id, chunk_index)`` map used to attach context.
        A blank query returns nothing (it would otherwise match every chunk).
        """
        if top_k <= 0 or not query or not query.strip():
            return [], {}

        # autoescape=True: '%' and '_' typed by the user are matched literally
        # instead of acting as LIKE wildcards.
        result = await self.db.execute(
            select(DocumentChunk, Document.title)
            .join(Document, Document.id == DocumentChunk.document_id)
            .where(Document.user_id == user_id)
            .where(
                or_(
                    DocumentChunk.content.contains(query, autoescape=True),
                    Document.title.contains(query, autoescape=True),
                )
            )
            # Deterministic ordering so LIMIT picks a stable set of rows.
            .order_by(DocumentChunk.document_id, DocumentChunk.chunk_index)
            .limit(top_k)
        )

        results: list[SearchResult] = []
        positions: dict[str, tuple[str, int]] = {}
        for chunk, title in result.all():
            results.append(SearchResult(
                chunk_id=chunk.id,
                document_id=chunk.document_id,
                document_title=title or "",
                content=chunk.content,
                score=0.5,
                metadata_=None,
            ))
            positions[chunk.id] = (chunk.document_id, chunk.chunk_index)
        return results, positions

    async def delete_from_vectorstore(self, user_id: str, document_id: str):
        """Best-effort removal of a document's vectors from the user's collection.

        Failures of the delete call are logged and not raised.
        """
        collection = await asyncio.to_thread(self.get_collection, user_id)
        try:
            await asyncio.to_thread(collection.delete, where={"document_id": document_id})
        except Exception:
            self._log.warning(
                "Failed to delete vectors for document %s (user %s)",
                document_id,
                user_id,
                exc_info=True,
            )
|