"""
Knowledge-base service: ChromaDB vector retrieval + hybrid retrieval + rerank.

Retrieval strategies:
1. Semantic retrieval (dense) - ChromaDB vector similarity
2. Keyword retrieval (sparse) - SQL LIKE
3. Hybrid retrieval - weighted fusion of semantic + keyword candidates
4. Rerank - second-pass ordering of the candidates
5. Context enrichment - attach the previous/next chunk of every result so the
   caller sees the surrounding text
"""
import logging
import re
from dataclasses import dataclass

import chromadb
from chromadb.config import Settings as ChromaSettings
from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder

logger = logging.getLogger(__name__)

# CJK ideograph ranges (Extension A, Unified Ideographs, Compatibility Ideographs).
_CJK_RANGES = "\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff"
# Group 1: a run of CJK ideographs; group 2: a run of any other word characters.
_TOKEN_RE = re.compile(rf"([{_CJK_RANGES}]+)|([^\W{_CJK_RANGES}]+)")


def _tokenize(text: str) -> set[str]:
    """Return the set of lowercase lexical tokens of *text* used by rerank.

    Non-CJK word runs are kept whole (identical to ``re.findall(r"\\w+", ...)``
    for text without ideographs). CJK runs are split into overlapping character
    bigrams (a lone character is kept as-is): Chinese has no spaces, so plain
    ``\\w+`` would turn a whole sentence into a single token and a Chinese query
    would practically never overlap with content.
    """
    tokens: set[str] = set()
    for cjk_run, word in _TOKEN_RE.findall(text.lower()):
        if word:
            tokens.add(word)
        elif len(cjk_run) == 1:
            tokens.add(cjk_run)
        else:
            tokens.update(cjk_run[i:i + 2] for i in range(len(cjk_run) - 1))
    return tokens


@dataclass
class SearchResult:
    chunk_id: str
    document_id: str
    document_title: str
    content: str
    # Relevance score; higher is better. Scale depends on the producer
    # (vector similarity proxy, fixed keyword score, or rerank input).
    score: float
    # str() of the Chroma metadata dict for vector hits, None for keyword hits.
    metadata_: str | None = None
    # Content of the neighbouring chunks in the same document, if any.
    prev_chunk: str | None = None
    next_chunk: str | None = None


class KnowledgeService:
    """Vector knowledge-base retrieval service (one ChromaDB collection per user)."""

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        self.user_id = user_id
        self._chroma_client = None

    @property
    def chroma_client(self):
        """Lazily create and cache the persistent ChromaDB client."""
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(
                path=settings.CHROMA_PERSIST_DIR,
                settings=ChromaSettings(allow_reset=True),
            )
        return self._chroma_client

    def get_collection(self, user_id: str):
        """Return (creating it if needed) the collection ``user_<user_id>``."""
        return self.chroma_client.get_or_create_collection(
            name=f"user_{user_id}",
            metadata={"user_id": user_id},
        )

    async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
        """Embed a document's chunks into the user's ChromaDB collection.

        Vectors previously stored for the document are removed first, so
        re-indexing an edited document leaves neither stale nor duplicate
        chunks behind. Marks the document ``is_indexed`` and commits.
        Returns silently when the document or its chunks do not exist.

        Args:
            document_id: id of the ``Document`` to index.
            user_id: owner; selects the ``user_<user_id>`` collection.
            folder_path: stored as ``folder_path`` metadata for folder filtering
                (presumably the value ``_get_folder_path`` produces — verify callers).
        """
        result = await self.db.execute(
            select(Document).where(Document.id == document_id)
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return

        chunks_result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        chunks = list(chunks_result.scalars().all())
        if not chunks:
            return

        collection = self.get_collection(user_id)
        # `add` does not overwrite ids that already exist, and chunks removed
        # since the last indexing would otherwise linger: replace wholesale.
        collection.delete(where={"document_id": doc.id})
        collection.upsert(
            ids=[chunk.id for chunk in chunks],
            documents=[chunk.content for chunk in chunks],
            metadatas=[
                {
                    "document_id": doc.id,
                    # Chroma metadata values must be str/int/float/bool, not None.
                    "document_title": doc.title or "",
                    "chunk_index": chunk.chunk_index,
                    "file_type": doc.file_type or "",
                    "folder_path": folder_path or "",
                }
                for chunk in chunks
            ],
        )

        doc.is_indexed = True
        await self.db.commit()

    async def retrieve(
        self,
        query: str,
        user_id: str,
        folder_id: str | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
    ) -> list[SearchResult]:
        """Vector retrieval with optional rerank and folder filtering.

        Flow:
        1. Query ChromaDB for ``top_k * 3`` candidates (wider candidate pool).
        2. Rerank (or simply truncate) down to ``top_k``.
        3. Attach the previous/next chunk of each surviving result as context.

        When ``folder_id`` names an existing folder, only chunks indexed under
        that folder or one of its descendants are searched; an unknown
        ``folder_id`` leaves the search unfiltered. Returns ``[]`` when
        ``top_k <= 0`` or the vector query fails.
        """
        if top_k <= 0:
            return []

        collection = self.get_collection(user_id)

        where = None
        if folder_id:
            folder_paths = await self._get_folder_subtree_paths(folder_id)
            if folder_paths:
                # Chroma metadata filters have no prefix operator, so match the
                # folder's path and every descendant folder path exactly.
                where = {"folder_path": {"$in": folder_paths}}

        try:
            results = collection.query(
                query_texts=[query],
                n_results=top_k * 3,
                where=where,
                include=["documents", "metadatas", "distances"],
            )
        except Exception:
            # Best-effort: a failing vector backend yields no hits, but is logged.
            logger.warning("ChromaDB query failed for user %s", user_id, exc_info=True)
            return []

        if not results or not results.get("ids"):
            return []

        ids = results["ids"][0]
        documents = (results.get("documents") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        candidates: list[SearchResult] = []
        chunk_indexes: dict[str, int] = {}
        for i, chunk_id in enumerate(ids):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            distance = distances[i] if i < len(distances) else 0.0
            chunk_indexes[chunk_id] = meta.get("chunk_index", 0)
            candidates.append(SearchResult(
                chunk_id=chunk_id,
                document_id=meta.get("document_id", ""),
                document_title=meta.get("document_title", ""),
                content=documents[i] if i < len(documents) else "",
                # The collection is created without `hnsw:space`, so Chroma's
                # default (L2) distance applies: this is a monotonic similarity
                # proxy, not a value bounded to [0, 1].
                score=1.0 - distance,
                metadata_=str(meta),
            ))

        if use_rerank:
            top = self._rerank(query, candidates, top_k)
        else:
            top = candidates[:top_k]

        # Siblings only matter for returned results; fetch them after truncation.
        for r in top:
            r.prev_chunk, r.next_chunk = await self._get_sibling_chunks(
                chunk_id=r.chunk_id,
                chunk_index=chunk_indexes[r.chunk_id],
                document_id=r.document_id,
            )
        return top

    def _rerank(
        self,
        query: str,
        results: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Rerank: semantic score * 0.7 + keyword overlap * 0.2 + title overlap * 0.1.

        Overlap is the fraction of query tokens (see ``_tokenize``) present in
        the chunk content / document title. Ties keep their input order.
        Returns at most ``top_k`` results, best first.
        """
        query_tokens = _tokenize(query)
        denominator = max(len(query_tokens), 1)

        def combined_score(r: SearchResult) -> float:
            score = r.score * 0.7
            keyword_overlap = len(query_tokens & _tokenize(r.content or "")) / denominator
            score += keyword_overlap * 0.2
            if r.document_title:
                title_overlap = len(query_tokens & _tokenize(r.document_title)) / denominator
                score += title_overlap * 0.1
            return score

        # sorted() is stable, also with reverse=True.
        return sorted(results, key=combined_score, reverse=True)[:top_k]

    async def _get_sibling_chunks(
        self,
        chunk_id: str,
        chunk_index: int,
        document_id: str,
    ) -> tuple[str | None, str | None]:
        """Return ``(previous_content, next_content)`` around ``chunk_index``.

        Either element is ``None`` when that neighbour does not exist.
        ``chunk_id`` is kept for interface compatibility and is not used.
        """
        result = await self.db.execute(
            select(DocumentChunk).where(
                DocumentChunk.document_id == document_id,
                DocumentChunk.chunk_index.in_((chunk_index - 1, chunk_index + 1)),
            )
        )
        neighbours = {chunk.chunk_index: chunk.content for chunk in result.scalars().all()}
        return neighbours.get(chunk_index - 1), neighbours.get(chunk_index + 1)

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the folder's full path (``/root/.../name``), or None if not found."""
        result = await self.db.execute(
            select(Folder).where(Folder.id == folder_id)
        )
        folder = result.scalar_one_or_none()
        if not folder:
            return None
        return await self._build_folder_path(folder)

    async def _build_folder_path(self, folder: Folder) -> str:
        """Walk ``parent_id`` links upward and join the names into ``/a/b/c``.

        Stops at a missing parent or at a cycle in the parent links (which
        would otherwise loop forever).
        """
        path_parts = [folder.name]
        seen = {folder.id}
        current_parent_id = folder.parent_id
        while current_parent_id and current_parent_id not in seen:
            seen.add(current_parent_id)
            parent_result = await self.db.execute(
                select(Folder).where(Folder.id == current_parent_id)
            )
            parent = parent_result.scalar_one_or_none()
            if not parent:
                break
            path_parts.append(parent.name)
            current_parent_id = parent.parent_id
        return "/" + "/".join(reversed(path_parts))

    async def _get_folder_subtree_paths(self, folder_id: str) -> list[str]:
        """Return the path of ``folder_id`` followed by the paths of all descendants.

        Paths use the same ``/a/b/c`` format as ``_get_folder_path``. Returns
        ``[]`` when the folder does not exist.
        """
        result = await self.db.execute(
            select(Folder).where(Folder.id == folder_id)
        )
        root = result.scalar_one_or_none()
        if not root:
            return []

        root_path = await self._build_folder_path(root)
        paths = [root_path]
        seen = {root.id}
        frontier = {root.id: root_path}  # folder id -> full path, one tree level
        while frontier:
            children = await self.db.execute(
                select(Folder).where(Folder.parent_id.in_(list(frontier)))
            )
            next_frontier = {}
            for child in children.scalars().all():
                if child.id in seen:  # guard against cyclic parent links
                    continue
                seen.add(child.id)
                child_path = f"{frontier[child.parent_id]}/{child.name}"
                next_frontier[child.id] = child_path
                paths.append(child_path)
            frontier = next_frontier
        return paths

    async def hybrid_search(
        self,
        query: str,
        user_id: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Hybrid retrieval: vector + keyword candidates, deduplicated, then reranked.

        When a chunk is found by both searches, the vector hit (listed first) wins.
        """
        vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
        keyword_results = await self._keyword_search(query, user_id, top_k)

        merged: dict[str, SearchResult] = {}
        for r in (*vector_results, *keyword_results):
            merged.setdefault(r.chunk_id, r)
        return self._rerank(query, list(merged.values()), top_k)

    async def _keyword_search(
        self,
        query: str,
        user_id: str,
        top_k: int,
    ) -> list[SearchResult]:
        """SQL LIKE search over the user's chunk content and document titles.

        ``query`` is matched as a literal substring (``%``/``_`` are escaped).
        Each hit gets a flat score of 0.5 and its neighbouring chunks as
        context. Returns ``[]`` for a blank query or ``top_k <= 0``; without
        this guard an empty pattern would match every row.
        """
        if not query or not query.strip() or top_k <= 0:
            return []

        result = await self.db.execute(
            select(DocumentChunk, Document)
            .join(Document, DocumentChunk.document_id == Document.id)
            .where(Document.user_id == user_id)
            .where(
                or_(
                    DocumentChunk.content.contains(query, autoescape=True),
                    Document.title.contains(query, autoescape=True),
                )
            )
            .limit(top_k)
        )

        results: list[SearchResult] = []
        for chunk, doc in result.all():
            prev_chunk, next_chunk = await self._get_sibling_chunks(
                chunk_id=chunk.id,
                chunk_index=chunk.chunk_index,
                document_id=chunk.document_id,
            )
            results.append(SearchResult(
                chunk_id=chunk.id,
                document_id=chunk.document_id,
                document_title=doc.title or "",
                content=chunk.content,
                score=0.5,
                metadata_=None,
                prev_chunk=prev_chunk,
                next_chunk=next_chunk,
            ))
        return results

    async def delete_from_vectorstore(self, user_id: str, document_id: str):
        """Remove every vector of ``document_id`` from the user's collection.

        Best-effort: a vector-store failure is logged, not raised.
        """
        collection = self.get_collection(user_id)
        try:
            collection.delete(where={"document_id": document_id})
        except Exception:
            logger.warning(
                "Failed to delete document %s from vector store of user %s",
                document_id,
                user_id,
                exc_info=True,
            )