2026-03-21 10:13:29 +08:00
|
|
|
|
"""
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank

检索策略:
1. 语义检索 (dense) - ChromaDB 向量相似度
2. 关键词检索 (sparse) - SQL LIKE
3. 混合检索 - 语义 + 关键词 结果合并去重后统一 Rerank
4. Rerank - 二次排序优化结果
5. 上下文丰富 - 自动获取同一章节/工作表内的前/后 chunk 提供完整语境
"""
|
|
|
|
|
|
|
|
|
|
|
|
import json
import logging
import re
from dataclasses import dataclass
from datetime import UTC, datetime

import chromadb
from chromadb.config import Settings as ChromaSettings
from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.services.document_service import DocumentService
|
2026-03-21 10:13:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
class SearchResult:
|
|
|
|
|
|
chunk_id: str
|
|
|
|
|
|
document_id: str
|
|
|
|
|
|
document_title: str
|
|
|
|
|
|
content: str
|
|
|
|
|
|
score: float
|
|
|
|
|
|
metadata_: str | None = None
|
|
|
|
|
|
prev_chunk: str | None = None
|
|
|
|
|
|
next_chunk: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KnowledgeService:
    """Knowledge-base retrieval service backed by ChromaDB and the SQL database.

    Every user has a dedicated Chroma collection named ``user_<user_id>``. Chunk
    text and metadata are read from ``DocumentChunk`` rows and written into that
    collection. Retrieval combines Chroma similarity search, a SQL substring
    search and a lightweight lexical/structural rerank.

    NOTE(review): the Chroma client calls below are synchronous and run on the
    event-loop thread; consider offloading them if query latency matters.
    """

    _logger = logging.getLogger(__name__)

    # Query substrings that suggest a question about tabular data; such queries
    # boost table chunks during rerank.
    _TABLE_QUERY_HINTS = ("sheet", "excel", "csv", "表", "列", "金额", "统计", "日期")

    # Word-character runs that contain no CJK ideographs (same tokens as ``\w+``
    # for non-CJK text).
    _WORD_RE = re.compile(r"[^\W\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]+")
    # Runs of CJK ideographs, which are tokenized into character bigrams.
    _CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]+")

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        # Not read by the methods below, which take user_id explicitly.
        self.user_id = user_id
        # Created lazily by the ``chroma_client`` property.
        self._chroma_client = None

    @property
    def chroma_client(self):
        """Lazily create and cache the persistent Chroma client for this instance."""
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(
                path=settings.CHROMA_PERSIST_DIR,
                settings=ChromaSettings(allow_reset=True),
            )
        return self._chroma_client

    def get_collection(self, user_id: str):
        """Return (creating if needed) the Chroma collection ``user_<user_id>``."""
        return self.chroma_client.get_or_create_collection(
            name=f"user_{user_id}",
            metadata={"user_id": user_id},
        )

    async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
        """Write a document's chunks into the user's Chroma collection.

        Does nothing when the document does not exist or has no chunks. Chroma
        computes the embeddings with the collection's embedding function.

        Args:
            document_id: Document whose chunks are indexed.
            user_id: Owner; selects the target collection.
            folder_path: Precomputed folder path; looked up from the document's
                folder when omitted.
        """
        result = await self.db.execute(select(Document).where(Document.id == document_id))
        doc = result.scalar_one_or_none()
        if not doc:
            return

        chunks = await self._load_document_chunks(document_id)
        if not chunks:
            return

        await self._index_chunks(doc, chunks, user_id, folder_path=folder_path)

    async def _load_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Load all chunks of a document ordered by ``chunk_index``."""
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def _get_owned_document(self, document_id: str, user_id: str) -> Document | None:
        """Return the document if it exists and belongs to *user_id*, else ``None``."""
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        return result.scalar_one_or_none()

    async def _index_chunks(
        self,
        document: Document,
        chunks: list[DocumentChunk],
        user_id: str,
        folder_path: str | None = None,
    ):
        """Upsert *chunks* into the user's collection and mark *document* as indexed.

        Each record stores flattened, scalar-only metadata (Chroma metadata values
        cannot be lists or ``None``). The chunk rows get their Chroma collection/id
        recorded and the document's ingestion status is set to ``"ready"``, then the
        session is committed.
        """
        if not folder_path and document.folder_id:
            folder_path = await self._get_folder_path(document.folder_id)
        folder_path = folder_path or ""

        collection = self.get_collection(user_id)
        collection_name = f"user_{user_id}"

        ids: list[str] = []
        documents: list[str] = []
        metadatas: list[dict] = []
        for chunk in chunks:
            chunk_metadata = self._parse_metadata(chunk.metadata_)

            section_path = chunk_metadata.get("section_path") or []
            if isinstance(section_path, str):
                section_path_text = section_path
            else:
                section_path_text = " / ".join(str(part) for part in section_path)

            ids.append(chunk.id)
            documents.append(chunk.content)
            metadatas.append({
                "document_id": document.id,
                "document_title": document.title or "",
                "document_filename": document.filename or "",
                "chunk_index": chunk.chunk_index,
                "file_type": document.file_type or "",
                "folder_path": folder_path,
                "content_type": chunk_metadata.get("content_type") or "text",
                "section_title": chunk_metadata.get("section_title") or "",
                "section_path": section_path_text,
                "page_number": chunk_metadata.get("page_number") or 0,
                "sheet_name": chunk_metadata.get("sheet_name") or "",
                "row_start": chunk_metadata.get("row_start") or 0,
                "row_end": chunk_metadata.get("row_end") or 0,
                "parser_version": chunk_metadata.get("parser_version") or document.parser_version or "",
                "index_version": chunk_metadata.get("index_version") or document.index_version or "",
            })
            chunk.chroma_collection = collection_name
            chunk.chroma_id = chunk.id

        # upsert rather than add: re-indexing an already indexed chunk must refresh
        # its text/metadata (e.g. a new folder_path) instead of being skipped.
        collection.upsert(ids=ids, documents=documents, metadatas=metadatas)

        document.is_indexed = True
        document.ingestion_status = "ready"
        document.ingestion_error = None
        document.indexed_at = datetime.now(UTC)
        await self.db.commit()

    async def retrieve(
        self,
        query: str,
        user_id: str,
        folder_id: str | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
    ) -> list[SearchResult]:
        """Vector retrieval with optional folder filter and rerank.

        Flow:
            1. Resolve the folder filter (the folder and all its sub-folders).
            2. Query Chroma for ``top_k * 3`` candidates.
            3. Rerank (or simply truncate) to ``top_k``.
            4. Attach previous/next related chunks to the final results only.

        Returns an empty list when ``top_k <= 0``, the folder does not exist, or
        the Chroma query fails (the failure is logged).
        """
        if top_k <= 0:
            return []

        collection = self.get_collection(user_id)

        where = None
        if folder_id:
            folder_paths = await self._get_folder_subtree_paths(folder_id)
            if not folder_paths:
                # Unknown folder: it cannot contain any indexed chunks.
                return []
            # Chroma has no prefix operator, so "this folder and below" is an
            # explicit $in over the subtree's paths.
            where = {"folder_path": {"$in": folder_paths}}

        try:
            results = collection.query(
                query_texts=[query],
                n_results=top_k * 3,
                where=where,
                include=["documents", "metadatas", "distances"],
            )
        except Exception:
            self._logger.exception("Chroma query failed for user %s", user_id)
            return []

        if not results:
            return []
        ids = (results.get("ids") or [[]])[0] or []
        if not ids:
            return []
        documents = (results.get("documents") or [[]])[0] or []
        metadatas = (results.get("metadatas") or [[]])[0] or []
        distances = (results.get("distances") or [[]])[0] or []

        candidates: list[SearchResult] = []
        for i, chunk_id in enumerate(ids):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            distance = distances[i] if i < len(distances) else 0.0
            candidates.append(SearchResult(
                chunk_id=chunk_id,
                document_id=meta.get("document_id", ""),
                document_title=meta.get("document_title", ""),
                content=(documents[i] if i < len(documents) else None) or "",
                score=1.0 - distance,
                metadata_=json.dumps(meta, ensure_ascii=False),
            ))

        if use_rerank:
            ranked = self._rerank(query, candidates, top_k)
        else:
            ranked = candidates[:top_k]

        # Context is only needed for what we return; fetching it for every
        # candidate wasted two DB round-trips per discarded hit.
        await self._attach_context(ranked)
        return ranked

    async def _attach_context(self, results: list[SearchResult]) -> None:
        """Fill ``prev_chunk``/``next_chunk`` in place, loading each document's chunks once."""
        chunks_by_document: dict[str, list[DocumentChunk]] = {}
        for r in results:
            meta = self._parse_metadata(r.metadata_)
            document_id = r.document_id
            if document_id not in chunks_by_document:
                chunks_by_document[document_id] = (
                    await self._load_document_chunks(document_id) if document_id else []
                )
            r.prev_chunk, r.next_chunk = await self._get_related_chunks(
                chunk_id=r.chunk_id,
                chunk_index=meta.get("chunk_index", 0),
                document_id=document_id,
                doc_chunks=chunks_by_document[document_id],
            )

    @classmethod
    def _tokenize(cls, text: str) -> set[str]:
        """Split *text* into lowercase tokens for keyword-overlap scoring.

        Non-CJK word runs are kept whole (identical to ``\\w+`` for such text).
        CJK text has no spaces, so each CJK run is split into overlapping
        character bigrams (a one-character run is kept as-is); otherwise a whole
        Chinese sentence becomes one token that practically never matches.
        """
        lowered = text.lower()
        tokens = set(cls._WORD_RE.findall(lowered))
        for run in cls._CJK_RE.findall(lowered):
            if len(run) == 1:
                tokens.add(run)
            else:
                tokens.update(run[i:i + 2] for i in range(len(run) - 1))
        return tokens

    def _rerank(
        self,
        query: str,
        results: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Rescore *results* and return the best *top_k* (stable for ties).

        score = semantic score * 0.7
              + share of query tokens found in the content * 0.2
              + share of query tokens found in the document title * 0.1
              + 0.25 (table_schema) / 0.15 (table_rows) when the query looks tabular
        """
        query_lower = query.lower()
        query_words = self._tokenize(query)
        denominator = max(len(query_words), 1)
        table_query = any(hint in query_lower for hint in self._TABLE_QUERY_HINTS)

        scored = []
        for r in results:
            score = r.score * 0.7

            content_words = self._tokenize(r.content or "")
            score += len(query_words & content_words) / denominator * 0.2

            if r.document_title:
                title_words = self._tokenize(r.document_title)
                score += len(query_words & title_words) / denominator * 0.1

            if table_query:
                content_type = self._parse_metadata(r.metadata_).get("content_type")
                if content_type == "table_schema":
                    score += 0.25
                elif content_type == "table_rows":
                    score += 0.15

            scored.append((score, r))

        scored.sort(key=lambda item: item[0], reverse=True)
        return [r for _, r in scored[:top_k]]

    async def _get_related_chunks(
        self,
        chunk_id: str,
        chunk_index: int,
        document_id: str,
        doc_chunks: list[DocumentChunk] | None = None,
    ) -> tuple[str | None, str | None]:
        """Return the text of the nearest preceding and following related chunks.

        A chunk is related when it shares the hit chunk's non-empty
        ``sheet_name`` or its non-empty ``section_path``; chunks with neither get
        no context.

        Args:
            chunk_id: ID of the hit chunk.
            chunk_index: Position of the hit chunk within its document.
            document_id: Document the chunk belongs to.
            doc_chunks: The document's chunks ordered by ``chunk_index``; loaded
                from the database when omitted.

        Returns:
            ``(prev_text, next_text)``; either is ``None`` when there is no related
            chunk on that side or the hit chunk is not among the document's chunks.
        """
        if doc_chunks is None:
            doc_chunks = await self._load_document_chunks(document_id)

        current_chunk = next((c for c in doc_chunks if c.id == chunk_id), None)
        if current_chunk is None:
            return None, None

        current_metadata = self._parse_metadata(current_chunk.metadata_)
        section_path = current_metadata.get("section_path") or []
        sheet_name = current_metadata.get("sheet_name")

        prev_chunk = None
        next_chunk = None
        for chunk in doc_chunks:
            if chunk.id == chunk_id:
                continue
            metadata = self._parse_metadata(chunk.metadata_)
            same_sheet = bool(sheet_name) and metadata.get("sheet_name") == sheet_name
            same_section = bool(section_path) and metadata.get("section_path") == section_path
            if not (same_sheet or same_section):
                continue
            if chunk.chunk_index < chunk_index:
                # Chunks are ascending, so the last match below is the nearest one.
                prev_chunk = chunk.content
            elif chunk.chunk_index > chunk_index:
                next_chunk = chunk.content
                break
        return prev_chunk, next_chunk

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the folder's full path (``"/root/.../name"``), or ``None`` if missing.

        Walks ``parent_id`` links upward; stops at a missing parent and guards
        against ``parent_id`` cycles in corrupt data.
        """
        result = await self.db.execute(select(Folder).where(Folder.id == folder_id))
        folder = result.scalar_one_or_none()
        if not folder:
            return None

        path_parts = [folder.name]
        visited = {folder.id}
        current_parent_id = folder.parent_id
        while current_parent_id and current_parent_id not in visited:
            visited.add(current_parent_id)
            parent_result = await self.db.execute(
                select(Folder).where(Folder.id == current_parent_id)
            )
            parent = parent_result.scalar_one_or_none()
            if not parent:
                break
            path_parts.append(parent.name)
            current_parent_id = parent.parent_id

        path_parts.reverse()
        return "/" + "/".join(path_parts)

    async def _get_folder_subtree_paths(self, folder_id: str) -> list[str]:
        """Return the path of *folder_id* followed by the paths of all descendants.

        Child paths are built as ``parent_path + "/" + name``, matching
        ``_get_folder_path``. Returns an empty list when the folder does not exist.
        """
        root_path = await self._get_folder_path(folder_id)
        if not root_path:
            return []

        paths = [root_path]
        # Maps folder id (as str) -> its path, for the current BFS level.
        frontier = {str(folder_id): root_path}
        seen = set(frontier)
        while frontier:
            children_result = await self.db.execute(
                select(Folder).where(Folder.parent_id.in_(list(frontier)))
            )
            next_frontier: dict[str, str] = {}
            for child in children_result.scalars().all():
                child_key = str(child.id)
                if child_key in seen:
                    continue
                seen.add(child_key)
                child_path = f"{frontier[str(child.parent_id)]}/{child.name}"
                paths.append(child_path)
                next_frontier[child_key] = child_path
            frontier = next_frontier
        return paths

    def _parse_metadata(self, raw_metadata: str | dict | None) -> dict:
        """Return chunk metadata as a dict.

        Dicts are returned as-is; JSON text is decoded. Empty input, invalid JSON,
        or JSON that is not an object all yield ``{}``.
        """
        if isinstance(raw_metadata, dict):
            return raw_metadata
        if not raw_metadata:
            return {}
        try:
            parsed = json.loads(raw_metadata)
        except (TypeError, ValueError):
            # ValueError covers JSONDecodeError and undecodable bytes.
            return {}
        return parsed if isinstance(parsed, dict) else {}

    async def hybrid_search(
        self,
        query: str,
        user_id: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Hybrid search: vector + keyword hits, de-duplicated, then reranked to *top_k*.

        Vector hits come first, so on duplicate chunk ids the vector hit wins.
        """
        vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
        keyword_results = await self._keyword_search(query, user_id, top_k)

        seen = set()
        merged: list[SearchResult] = []
        for r in vector_results + keyword_results:
            if r.chunk_id not in seen:
                seen.add(r.chunk_id)
                merged.append(r)

        return self._rerank(query, merged, top_k)

    async def _keyword_search(
        self,
        query: str,
        user_id: str,
        top_k: int,
    ) -> list[SearchResult]:
        """SQL substring search over the user's chunk content and document titles.

        The query is matched literally: ``%`` and ``_`` are escaped rather than
        treated as LIKE wildcards. Every hit gets a fixed score of 0.5. Empty or
        whitespace-only queries and ``top_k <= 0`` return an empty list.
        """
        if not query or not query.strip() or top_k <= 0:
            return []

        result = await self.db.execute(
            select(DocumentChunk, Document)
            .join(Document, DocumentChunk.document_id == Document.id)
            .where(Document.user_id == user_id)
            .where(
                or_(
                    DocumentChunk.content.contains(query, autoescape=True),
                    Document.title.contains(query, autoescape=True),
                )
            )
            .limit(top_k)
        )

        results = []
        for chunk, doc in result.all():
            chunk_metadata = self._parse_metadata(chunk.metadata_)
            # Same keys as the vector-hit metadata so _rerank's table boost applies.
            meta = {
                "document_id": chunk.document_id,
                "document_title": doc.title or "",
                "chunk_index": chunk.chunk_index,
                "content_type": chunk_metadata.get("content_type") or "text",
            }
            results.append(SearchResult(
                chunk_id=chunk.id,
                document_id=chunk.document_id,
                document_title=doc.title or "",
                content=chunk.content,
                score=0.5,
                # default=str: ids may be non-JSON types (e.g. UUID) depending on the model.
                metadata_=json.dumps(meta, ensure_ascii=False, default=str),
            ))
        return results

    async def delete_from_vectorstore(self, user_id: str, document_id: str):
        """Best-effort removal of a document's records from the user's collection.

        Chroma errors are logged and swallowed so callers can continue.
        """
        collection = self.get_collection(user_id)
        try:
            collection.delete(where={"document_id": document_id})
        except Exception:
            self._logger.warning(
                "Failed to delete vectors for document %s (user %s)",
                document_id,
                user_id,
                exc_info=True,
            )

    async def reindex_document(self, document_id: str, user_id: str) -> bool:
        """Rebuild a user's document via ``DocumentService`` and index it again.

        Returns ``False`` when the document does not exist or is not owned by
        *user_id*, otherwise ``True``.
        """
        document = await self._get_owned_document(document_id, user_id)
        if not document:
            return False

        await self.delete_from_vectorstore(user_id, document_id)
        document = await DocumentService(self.db, user_id=user_id).rebuild_document(document)
        await self.index_document(document.id, user_id)
        return True

    async def reindex_document_chunks(self, document_id: str, user_id: str) -> bool:
        """Re-upload a document's existing chunks to Chroma without re-parsing it.

        Returns ``False`` when the document is missing, not owned by *user_id*,
        or has no chunks; otherwise ``True``.
        """
        document = await self._get_owned_document(document_id, user_id)
        if not document:
            return False

        chunks = await self._load_document_chunks(document_id)
        if not chunks:
            return False

        await self.delete_from_vectorstore(user_id, document_id)
        await self._index_chunks(document, chunks, user_id)
        return True
|