Normalize uploaded documents into structured markdown, add clearer parser errors for missing dependencies, and cover the ingestion flow with backend tests. This also replaces deprecated UTC timestamp helpers in the touched backend paths so the knowledge pipeline stays warning-free. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
409 lines
14 KiB
Python
409 lines
14 KiB
Python
"""
|
||
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
|
||
|
||
检索策略:
|
||
1. 语义检索 (dense) - ChromaDB 向量相似度
|
||
2. 关键词检索 (sparse) - SQL LIKE
|
||
3. 混合检索 - 语义 + 关键词 加权融合
|
||
4. Rerank - 二次排序优化结果
|
||
5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境
|
||
"""
|
||
|
||
import json
import logging
import re
from dataclasses import dataclass
from datetime import UTC, datetime

import chromadb
from chromadb.config import Settings as ChromaSettings
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.services.document_service import DocumentService
|
||
|
||
|
||
@dataclass
|
||
class SearchResult:
|
||
chunk_id: str
|
||
document_id: str
|
||
document_title: str
|
||
content: str
|
||
score: float
|
||
metadata_: str | None = None
|
||
prev_chunk: str | None = None
|
||
next_chunk: str | None = None
|
||
|
||
|
||
class KnowledgeService:
|
||
"""向量知识库检索服务"""
|
||
|
||
def __init__(self, db: AsyncSession, user_id: str | None = None):
|
||
self.db = db
|
||
self.user_id = user_id
|
||
self._chroma_client = None
|
||
|
||
@property
|
||
def chroma_client(self):
|
||
if self._chroma_client is None:
|
||
self._chroma_client = chromadb.PersistentClient(
|
||
path=settings.CHROMA_PERSIST_DIR,
|
||
settings=ChromaSettings(allow_reset=True),
|
||
)
|
||
return self._chroma_client
|
||
|
||
def get_collection(self, user_id: str):
|
||
return self.chroma_client.get_or_create_collection(
|
||
name=f"user_{user_id}",
|
||
metadata={"user_id": user_id},
|
||
)
|
||
|
||
async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
|
||
"""将文档 chunks 向量化存入 ChromaDB"""
|
||
result = await self.db.execute(
|
||
select(Document).where(Document.id == document_id)
|
||
)
|
||
doc = result.scalar_one_or_none()
|
||
if not doc:
|
||
return
|
||
|
||
chunks_result = await self.db.execute(
|
||
select(DocumentChunk)
|
||
.where(DocumentChunk.document_id == document_id)
|
||
.order_by(DocumentChunk.chunk_index)
|
||
)
|
||
chunks = list(chunks_result.scalars().all())
|
||
if not chunks:
|
||
return
|
||
|
||
await self._index_chunks(doc, chunks, user_id, folder_path=folder_path)
|
||
|
||
async def _index_chunks(
|
||
self,
|
||
document: Document,
|
||
chunks: list[DocumentChunk],
|
||
user_id: str,
|
||
folder_path: str | None = None,
|
||
):
|
||
folder_path = folder_path or (await self._get_folder_path(document.folder_id) if document.folder_id else "")
|
||
collection = self.get_collection(user_id)
|
||
|
||
ids = [chunk.id for chunk in chunks]
|
||
documents = [chunk.content for chunk in chunks]
|
||
metadatas = []
|
||
for chunk in chunks:
|
||
chunk_metadata = self._parse_metadata(chunk.metadata_)
|
||
meta = {
|
||
"document_id": document.id,
|
||
"document_title": document.title,
|
||
"document_filename": document.filename,
|
||
"chunk_index": chunk.chunk_index,
|
||
"file_type": document.file_type,
|
||
"folder_path": folder_path or "",
|
||
"content_type": chunk_metadata.get("content_type", "text"),
|
||
"section_title": chunk_metadata.get("section_title") or "",
|
||
"section_path": " / ".join(chunk_metadata.get("section_path", [])),
|
||
"page_number": chunk_metadata.get("page_number") or 0,
|
||
"sheet_name": chunk_metadata.get("sheet_name") or "",
|
||
"row_start": chunk_metadata.get("row_start") or 0,
|
||
"row_end": chunk_metadata.get("row_end") or 0,
|
||
"parser_version": chunk_metadata.get("parser_version") or document.parser_version or "",
|
||
"index_version": chunk_metadata.get("index_version") or document.index_version or "",
|
||
}
|
||
chunk.chroma_collection = f"user_{user_id}"
|
||
chunk.chroma_id = chunk.id
|
||
metadatas.append(meta)
|
||
|
||
collection.add(ids=ids, documents=documents, metadatas=metadatas)
|
||
|
||
document.is_indexed = True
|
||
document.ingestion_status = "ready"
|
||
document.ingestion_error = None
|
||
document.indexed_at = datetime.now(UTC)
|
||
await self.db.commit()
|
||
|
||
async def retrieve(
|
||
self,
|
||
query: str,
|
||
user_id: str,
|
||
folder_id: str | None = None,
|
||
top_k: int = 5,
|
||
use_rerank: bool = True,
|
||
) -> list[SearchResult]:
|
||
"""
|
||
混合检索 + Rerank,支持按文件夹过滤
|
||
|
||
流程:
|
||
1. ChromaDB 向量检索 (扩大候选集)
|
||
2. 提取父 chunk(完整上下文)
|
||
3. Rerank 二次排序
|
||
4. 返回 top_k 结果
|
||
"""
|
||
collection = self.get_collection(user_id)
|
||
|
||
# 构建过滤条件
|
||
where = None
|
||
if folder_id:
|
||
folder_path = await self._get_folder_path(folder_id)
|
||
if folder_path:
|
||
where = {"folder_path": {"$starts_with": folder_path}}
|
||
|
||
try:
|
||
results = collection.query(
|
||
query_texts=[query],
|
||
n_results=top_k * 3,
|
||
where=where,
|
||
include=["documents", "metadatas", "distances"],
|
||
)
|
||
except Exception:
|
||
return []
|
||
|
||
if not results or not results.get("ids"):
|
||
return []
|
||
|
||
ids = results["ids"][0]
|
||
documents = results["documents"][0]
|
||
metadatas = results.get("metadatas", [[]])[0]
|
||
distances = results.get("distances", [[]])[0]
|
||
|
||
search_results: list[SearchResult] = []
|
||
for i, chunk_id in enumerate(ids):
|
||
meta = metadatas[i] if i < len(metadatas) else {}
|
||
score = 1.0 - (distances[i] if i < len(distances) else 0.0)
|
||
|
||
prev_chunk, next_chunk = await self._get_related_chunks(
|
||
chunk_id=chunk_id,
|
||
chunk_index=meta.get("chunk_index", 0),
|
||
document_id=meta.get("document_id", ""),
|
||
)
|
||
|
||
search_results.append(SearchResult(
|
||
chunk_id=chunk_id,
|
||
document_id=meta.get("document_id", ""),
|
||
document_title=meta.get("document_title", ""),
|
||
content=documents[i] if i < len(documents) else "",
|
||
score=score,
|
||
metadata_=json.dumps(meta, ensure_ascii=False),
|
||
prev_chunk=prev_chunk,
|
||
next_chunk=next_chunk,
|
||
))
|
||
|
||
if use_rerank:
|
||
search_results = self._rerank(query, search_results, top_k)
|
||
else:
|
||
search_results = search_results[:top_k]
|
||
|
||
return search_results
|
||
|
||
def _rerank(
|
||
self,
|
||
query: str,
|
||
results: list[SearchResult],
|
||
top_k: int,
|
||
) -> list[SearchResult]:
|
||
"""Rerank: 语义分 * 0.7 + 关键词匹配 * 0.2 + 标题匹配 * 0.1 + 结构加权"""
|
||
import re
|
||
|
||
query_words = set(re.findall(r"\w+", query.lower()))
|
||
table_query = any(token in query.lower() for token in ["sheet", "excel", "csv", "表", "列", "金额", "统计", "日期"])
|
||
|
||
scored = []
|
||
for r in results:
|
||
score = r.score * 0.7
|
||
|
||
content_words = set(re.findall(r"\w+", r.content.lower()))
|
||
keyword_overlap = len(query_words & content_words) / max(len(query_words), 1)
|
||
score += keyword_overlap * 0.2
|
||
|
||
if r.document_title:
|
||
title_words = set(re.findall(r"\w+", r.document_title.lower()))
|
||
title_overlap = len(query_words & title_words) / max(len(query_words), 1)
|
||
score += title_overlap * 0.1
|
||
|
||
metadata = self._parse_metadata(r.metadata_)
|
||
if table_query and metadata.get("content_type") == "table_schema":
|
||
score += 0.25
|
||
elif table_query and metadata.get("content_type") == "table_rows":
|
||
score += 0.15
|
||
|
||
scored.append((score, r))
|
||
|
||
scored.sort(key=lambda x: x[0], reverse=True)
|
||
return [r for _, r in scored[:top_k]]
|
||
|
||
async def _get_related_chunks(
|
||
self,
|
||
chunk_id: str,
|
||
chunk_index: int,
|
||
document_id: str,
|
||
) -> tuple[str | None, str | None]:
|
||
"""获取结构相关的上下文 chunk"""
|
||
current_result = await self.db.execute(
|
||
select(DocumentChunk).where(DocumentChunk.id == chunk_id)
|
||
)
|
||
current_chunk = current_result.scalar_one_or_none()
|
||
if not current_chunk:
|
||
return None, None
|
||
|
||
current_metadata = self._parse_metadata(current_chunk.metadata_)
|
||
section_path = current_metadata.get("section_path") or []
|
||
sheet_name = current_metadata.get("sheet_name")
|
||
|
||
chunk_result = await self.db.execute(
|
||
select(DocumentChunk)
|
||
.where(DocumentChunk.document_id == document_id)
|
||
.order_by(DocumentChunk.chunk_index)
|
||
)
|
||
chunks = list(chunk_result.scalars().all())
|
||
|
||
prev_chunk = None
|
||
next_chunk = None
|
||
for chunk in chunks:
|
||
if chunk.id == chunk_id:
|
||
continue
|
||
metadata = self._parse_metadata(chunk.metadata_)
|
||
same_sheet = bool(sheet_name) and metadata.get("sheet_name") == sheet_name
|
||
same_section = bool(section_path) and metadata.get("section_path") == section_path
|
||
if chunk.chunk_index < chunk_index and (same_sheet or same_section):
|
||
prev_chunk = chunk.content
|
||
if chunk.chunk_index > chunk_index and (same_sheet or same_section):
|
||
next_chunk = chunk.content
|
||
break
|
||
return prev_chunk, next_chunk
|
||
|
||
async def _get_folder_path(self, folder_id: str) -> str | None:
|
||
"""获取文件夹的完整路径"""
|
||
result = await self.db.execute(
|
||
select(Folder).where(Folder.id == folder_id)
|
||
)
|
||
folder = result.scalar_one_or_none()
|
||
if not folder:
|
||
return None
|
||
|
||
path_parts = [folder.name]
|
||
current_parent_id = folder.parent_id
|
||
|
||
while current_parent_id:
|
||
parent_result = await self.db.execute(
|
||
select(Folder).where(Folder.id == current_parent_id)
|
||
)
|
||
parent = parent_result.scalar_one_or_none()
|
||
if not parent:
|
||
break
|
||
path_parts.insert(0, parent.name)
|
||
current_parent_id = parent.parent_id
|
||
|
||
return "/" + "/".join(path_parts)
|
||
|
||
def _parse_metadata(self, raw_metadata: str | dict | None) -> dict:
|
||
if isinstance(raw_metadata, dict):
|
||
return raw_metadata
|
||
if not raw_metadata:
|
||
return {}
|
||
try:
|
||
return json.loads(raw_metadata)
|
||
except (TypeError, json.JSONDecodeError):
|
||
return {}
|
||
|
||
async def hybrid_search(
|
||
self,
|
||
query: str,
|
||
user_id: str,
|
||
top_k: int = 5,
|
||
) -> list[SearchResult]:
|
||
"""混合检索: 向量 + 关键词 + Rerank"""
|
||
vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
|
||
keyword_results = await self._keyword_search(query, user_id, top_k)
|
||
|
||
seen = set()
|
||
merged: list[SearchResult] = []
|
||
for r in vector_results + keyword_results:
|
||
if r.chunk_id not in seen:
|
||
seen.add(r.chunk_id)
|
||
merged.append(r)
|
||
|
||
return self._rerank(query, merged, top_k)
|
||
|
||
async def _keyword_search(
|
||
self,
|
||
query: str,
|
||
user_id: str,
|
||
top_k: int,
|
||
) -> list[SearchResult]:
|
||
"""SQL 关键词搜索"""
|
||
result = await self.db.execute(
|
||
select(DocumentChunk)
|
||
.join(Document)
|
||
.where(Document.user_id == user_id)
|
||
.where(
|
||
or_(
|
||
DocumentChunk.content.contains(query),
|
||
Document.title.contains(query),
|
||
)
|
||
)
|
||
.limit(top_k)
|
||
)
|
||
chunks = result.scalars().all()
|
||
results = []
|
||
for chunk in chunks:
|
||
doc_result = await self.db.execute(
|
||
select(Document).where(Document.id == chunk.document_id)
|
||
)
|
||
doc = doc_result.scalar_one_or_none()
|
||
results.append(SearchResult(
|
||
chunk_id=chunk.id,
|
||
document_id=chunk.document_id,
|
||
document_title=doc.title if doc else "",
|
||
content=chunk.content,
|
||
score=0.5,
|
||
metadata_=None,
|
||
))
|
||
return results
|
||
|
||
async def delete_from_vectorstore(self, user_id: str, document_id: str):
|
||
"""从向量库删除文档"""
|
||
collection = self.get_collection(user_id)
|
||
try:
|
||
collection.delete(where={"document_id": document_id})
|
||
except Exception:
|
||
pass
|
||
|
||
async def reindex_document(self, document_id: str, user_id: str) -> bool:
|
||
result = await self.db.execute(
|
||
select(Document).where(
|
||
Document.id == document_id,
|
||
Document.user_id == user_id,
|
||
)
|
||
)
|
||
document = result.scalar_one_or_none()
|
||
if not document:
|
||
return False
|
||
|
||
await self.delete_from_vectorstore(user_id, document_id)
|
||
document = await DocumentService(self.db, user_id=user_id).rebuild_document(document)
|
||
await self.index_document(document.id, user_id)
|
||
return True
|
||
|
||
async def reindex_document_chunks(self, document_id: str, user_id: str) -> bool:
|
||
result = await self.db.execute(
|
||
select(Document).where(
|
||
Document.id == document_id,
|
||
Document.user_id == user_id,
|
||
)
|
||
)
|
||
document = result.scalar_one_or_none()
|
||
if not document:
|
||
return False
|
||
|
||
chunks_result = await self.db.execute(
|
||
select(DocumentChunk)
|
||
.where(DocumentChunk.document_id == document_id)
|
||
.order_by(DocumentChunk.chunk_index)
|
||
)
|
||
chunks = list(chunks_result.scalars().all())
|
||
if not chunks:
|
||
return False
|
||
|
||
await self.delete_from_vectorstore(user_id, document_id)
|
||
await self._index_chunks(document, chunks, user_id)
|
||
return True
|