""" 知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank 检索策略: 1. 语义检索 (dense) - ChromaDB 向量相似度 2. 关键词检索 (sparse) - SQL LIKE 3. 混合检索 - 语义 + 关键词 加权融合 4. Rerank - 二次排序优化结果 5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境 """ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, or_ from app.models.document import Document, DocumentChunk from app.models.folder import Folder from app.config import settings from app.services.document_service import DocumentService import chromadb from chromadb.config import Settings as ChromaSettings from dataclasses import dataclass from datetime import UTC, datetime import json @dataclass class SearchResult: chunk_id: str document_id: str document_title: str content: str score: float metadata_: str | None = None prev_chunk: str | None = None next_chunk: str | None = None class KnowledgeService: """向量知识库检索服务""" def __init__(self, db: AsyncSession, user_id: str | None = None): self.db = db self.user_id = user_id self._chroma_client = None @property def chroma_client(self): if self._chroma_client is None: self._chroma_client = chromadb.PersistentClient( path=settings.CHROMA_PERSIST_DIR, settings=ChromaSettings(allow_reset=True), ) return self._chroma_client def get_collection(self, user_id: str): return self.chroma_client.get_or_create_collection( name=f"user_{user_id}", metadata={"user_id": user_id}, ) async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None): """将文档 chunks 向量化存入 ChromaDB""" result = await self.db.execute( select(Document).where(Document.id == document_id) ) doc = result.scalar_one_or_none() if not doc: return chunks_result = await self.db.execute( select(DocumentChunk) .where(DocumentChunk.document_id == document_id) .order_by(DocumentChunk.chunk_index) ) chunks = list(chunks_result.scalars().all()) if not chunks: return await self._index_chunks(doc, chunks, user_id, folder_path=folder_path) async def _index_chunks( self, document: Document, chunks: list[DocumentChunk], user_id: str, folder_path: str | None = None, ): folder_path = folder_path or (await self._get_folder_path(document.folder_id) if document.folder_id else "") collection = self.get_collection(user_id) ids = [chunk.id for chunk in chunks] documents = [chunk.content for chunk in chunks] metadatas = [] for chunk in chunks: chunk_metadata = self._parse_metadata(chunk.metadata_) meta = { "document_id": document.id, "document_title": document.title, "document_filename": document.filename, "chunk_index": chunk.chunk_index, "file_type": document.file_type, "folder_path": folder_path or "", "content_type": chunk_metadata.get("content_type", "text"), "section_title": chunk_metadata.get("section_title") or "", "section_path": " / ".join(chunk_metadata.get("section_path", [])), "page_number": chunk_metadata.get("page_number") or 0, "sheet_name": chunk_metadata.get("sheet_name") or "", "row_start": chunk_metadata.get("row_start") or 0, "row_end": chunk_metadata.get("row_end") or 0, "parser_version": chunk_metadata.get("parser_version") or document.parser_version or "", "index_version": chunk_metadata.get("index_version") or document.index_version or "", } chunk.chroma_collection = f"user_{user_id}" chunk.chroma_id = chunk.id metadatas.append(meta) collection.add(ids=ids, documents=documents, metadatas=metadatas) document.is_indexed = True document.ingestion_status = "ready" document.ingestion_error = None document.indexed_at = datetime.now(UTC) await self.db.commit() async def retrieve( self, query: str, user_id: str, folder_id: str | None = None, top_k: int = 5, use_rerank: bool = True, ) -> list[SearchResult]: """ 混合检索 + Rerank,支持按文件夹过滤 流程: 1. ChromaDB 向量检索 (扩大候选集) 2. 提取父 chunk(完整上下文) 3. Rerank 二次排序 4. 返回 top_k 结果 """ collection = self.get_collection(user_id) # 构建过滤条件 where = None if folder_id: folder_path = await self._get_folder_path(folder_id) if folder_path: where = {"folder_path": {"$starts_with": folder_path}} try: results = collection.query( query_texts=[query], n_results=top_k * 3, where=where, include=["documents", "metadatas", "distances"], ) except Exception: return [] if not results or not results.get("ids"): return [] ids = results["ids"][0] documents = results["documents"][0] metadatas = results.get("metadatas", [[]])[0] distances = results.get("distances", [[]])[0] search_results: list[SearchResult] = [] for i, chunk_id in enumerate(ids): meta = metadatas[i] if i < len(metadatas) else {} score = 1.0 - (distances[i] if i < len(distances) else 0.0) prev_chunk, next_chunk = await self._get_related_chunks( chunk_id=chunk_id, chunk_index=meta.get("chunk_index", 0), document_id=meta.get("document_id", ""), ) search_results.append(SearchResult( chunk_id=chunk_id, document_id=meta.get("document_id", ""), document_title=meta.get("document_title", ""), content=documents[i] if i < len(documents) else "", score=score, metadata_=json.dumps(meta, ensure_ascii=False), prev_chunk=prev_chunk, next_chunk=next_chunk, )) if use_rerank: search_results = self._rerank(query, search_results, top_k) else: search_results = search_results[:top_k] return search_results def _rerank( self, query: str, results: list[SearchResult], top_k: int, ) -> list[SearchResult]: """Rerank: 语义分 * 0.7 + 关键词匹配 * 0.2 + 标题匹配 * 0.1 + 结构加权""" import re query_words = set(re.findall(r"\w+", query.lower())) table_query = any(token in query.lower() for token in ["sheet", "excel", "csv", "表", "列", "金额", "统计", "日期"]) scored = [] for r in results: score = r.score * 0.7 content_words = set(re.findall(r"\w+", r.content.lower())) keyword_overlap = len(query_words & content_words) / max(len(query_words), 1) score += keyword_overlap * 0.2 if r.document_title: title_words = set(re.findall(r"\w+", r.document_title.lower())) title_overlap = len(query_words & title_words) / max(len(query_words), 1) score += title_overlap * 0.1 metadata = self._parse_metadata(r.metadata_) if table_query and metadata.get("content_type") == "table_schema": score += 0.25 elif table_query and metadata.get("content_type") == "table_rows": score += 0.15 scored.append((score, r)) scored.sort(key=lambda x: x[0], reverse=True) return [r for _, r in scored[:top_k]] async def _get_related_chunks( self, chunk_id: str, chunk_index: int, document_id: str, ) -> tuple[str | None, str | None]: """获取结构相关的上下文 chunk""" current_result = await self.db.execute( select(DocumentChunk).where(DocumentChunk.id == chunk_id) ) current_chunk = current_result.scalar_one_or_none() if not current_chunk: return None, None current_metadata = self._parse_metadata(current_chunk.metadata_) section_path = current_metadata.get("section_path") or [] sheet_name = current_metadata.get("sheet_name") chunk_result = await self.db.execute( select(DocumentChunk) .where(DocumentChunk.document_id == document_id) .order_by(DocumentChunk.chunk_index) ) chunks = list(chunk_result.scalars().all()) prev_chunk = None next_chunk = None for chunk in chunks: if chunk.id == chunk_id: continue metadata = self._parse_metadata(chunk.metadata_) same_sheet = bool(sheet_name) and metadata.get("sheet_name") == sheet_name same_section = bool(section_path) and metadata.get("section_path") == section_path if chunk.chunk_index < chunk_index and (same_sheet or same_section): prev_chunk = chunk.content if chunk.chunk_index > chunk_index and (same_sheet or same_section): next_chunk = chunk.content break return prev_chunk, next_chunk async def _get_folder_path(self, folder_id: str) -> str | None: """获取文件夹的完整路径""" result = await self.db.execute( select(Folder).where(Folder.id == folder_id) ) folder = result.scalar_one_or_none() if not folder: return None path_parts = [folder.name] current_parent_id = folder.parent_id while current_parent_id: parent_result = await self.db.execute( select(Folder).where(Folder.id == current_parent_id) ) parent = parent_result.scalar_one_or_none() if not parent: break path_parts.insert(0, parent.name) current_parent_id = parent.parent_id return "/" + "/".join(path_parts) def _parse_metadata(self, raw_metadata: str | dict | None) -> dict: if isinstance(raw_metadata, dict): return raw_metadata if not raw_metadata: return {} try: return json.loads(raw_metadata) except (TypeError, json.JSONDecodeError): return {} async def hybrid_search( self, query: str, user_id: str, top_k: int = 5, ) -> list[SearchResult]: """混合检索: 向量 + 关键词 + Rerank""" vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False) keyword_results = await self._keyword_search(query, user_id, top_k) seen = set() merged: list[SearchResult] = [] for r in vector_results + keyword_results: if r.chunk_id not in seen: seen.add(r.chunk_id) merged.append(r) return self._rerank(query, merged, top_k) async def _keyword_search( self, query: str, user_id: str, top_k: int, ) -> list[SearchResult]: """SQL 关键词搜索""" result = await self.db.execute( select(DocumentChunk) .join(Document) .where(Document.user_id == user_id) .where( or_( DocumentChunk.content.contains(query), Document.title.contains(query), ) ) .limit(top_k) ) chunks = result.scalars().all() results = [] for chunk in chunks: doc_result = await self.db.execute( select(Document).where(Document.id == chunk.document_id) ) doc = doc_result.scalar_one_or_none() results.append(SearchResult( chunk_id=chunk.id, document_id=chunk.document_id, document_title=doc.title if doc else "", content=chunk.content, score=0.5, metadata_=None, )) return results async def delete_from_vectorstore(self, user_id: str, document_id: str): """从向量库删除文档""" collection = self.get_collection(user_id) try: collection.delete(where={"document_id": document_id}) except Exception: pass async def reindex_document(self, document_id: str, user_id: str) -> bool: result = await self.db.execute( select(Document).where( Document.id == document_id, Document.user_id == user_id, ) ) document = result.scalar_one_or_none() if not document: return False await self.delete_from_vectorstore(user_id, document_id) document = await DocumentService(self.db, user_id=user_id).rebuild_document(document) await self.index_document(document.id, user_id) return True async def reindex_document_chunks(self, document_id: str, user_id: str) -> bool: result = await self.db.execute( select(Document).where( Document.id == document_id, Document.user_id == user_id, ) ) document = result.scalar_one_or_none() if not document: return False chunks_result = await self.db.execute( select(DocumentChunk) .where(DocumentChunk.document_id == document_id) .order_by(DocumentChunk.chunk_index) ) chunks = list(chunks_result.scalars().all()) if not chunks: return False await self.delete_from_vectorstore(user_id, document_id) await self._index_chunks(document, chunks, user_id) return True