Files
JARVIS/backend/app/services/knowledge_service.py

409 lines
14 KiB
Python
Raw Normal View History

2026-03-21 10:13:29 +08:00
"""
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
检索策略:
1. 语义检索 (dense) - ChromaDB 向量相似度
2. 关键词检索 (sparse) - SQL LIKE
3. 混合检索 - 语义 + 关键词 加权融合
4. Rerank - 二次排序优化结果
5. 上下文丰富 - 自动获取同一章节/工作表内的前后相邻 chunk，提供完整语境
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.config import settings
from app.services.document_service import DocumentService
2026-03-21 10:13:29 +08:00
import json
import logging
import re
from dataclasses import dataclass
from datetime import UTC, datetime

import chromadb
from chromadb.config import Settings as ChromaSettings
2026-03-21 10:13:29 +08:00
@dataclass
class SearchResult:
chunk_id: str
document_id: str
document_title: str
content: str
score: float
metadata_: str | None = None
prev_chunk: str | None = None
next_chunk: str | None = None
class KnowledgeService:
    """Vector knowledge-base retrieval service (ChromaDB + SQL).

    Chunk text and selected metadata are mirrored from ``DocumentChunk`` rows
    into a per-user Chroma collection named ``user_<user_id>``. Retrieval
    supports dense (vector) search, SQL keyword search, a hybrid of the two,
    and a lightweight heuristic rerank.
    """

    _logger = logging.getLogger(__name__)

    # Word tokenizer used by the heuristic reranker (compiled once, not per call).
    _WORD_RE = re.compile(r"\w+")

    # Lower-cased substrings that mark a query as being about tabular data.
    # NOTE(review): the previous list also contained two empty strings
    # (probably CJK tokens lost to an encoding problem). An empty string is a
    # substring of every query, so *every* query was treated as tabular and
    # table chunks were always boosted. They were dropped; restore the
    # intended tokens here if they are known.
    _TABLE_QUERY_HINTS = ("sheet", "excel", "csv", "金额", "统计", "日期")

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        self.user_id = user_id
        self._chroma_client = None

    @property
    def chroma_client(self):
        """Lazily create and cache the persistent Chroma client for this instance."""
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(
                path=settings.CHROMA_PERSIST_DIR,
                settings=ChromaSettings(allow_reset=True),
            )
        return self._chroma_client

    def get_collection(self, user_id: str):
        """Return (creating it if needed) the Chroma collection for ``user_id``."""
        return self.chroma_client.get_or_create_collection(
            name=f"user_{user_id}",
            metadata={"user_id": user_id},
        )

    async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
        """Vectorize every chunk of a document into the user's Chroma collection.

        Does nothing when the document does not exist or has no chunks.
        ``folder_path`` overrides the path otherwise derived from the
        document's folder.
        """
        result = await self.db.execute(
            select(Document).where(Document.id == document_id)
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return
        chunks = await self._load_document_chunks(document_id)
        if not chunks:
            return
        await self._index_chunks(doc, chunks, user_id, folder_path=folder_path)

    async def _load_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return all chunks of ``document_id`` ordered by ``chunk_index``."""
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    def _build_chunk_metadata(self, document: Document, chunk: DocumentChunk, folder_path: str) -> dict:
        """Flatten document + chunk metadata into scalar values for Chroma.

        Chroma metadata values must be scalars, so the section path list is
        joined with " / " and missing values become "" or 0.
        """
        chunk_metadata = self._parse_metadata(chunk.metadata_)
        section_path = chunk_metadata.get("section_path") or []
        if isinstance(section_path, str):
            # Joining a bare string would interleave " / " between its characters.
            section_path = [section_path]
        return {
            "document_id": document.id,
            "document_title": document.title or "",
            "document_filename": document.filename or "",
            "chunk_index": chunk.chunk_index,
            "file_type": document.file_type or "",
            "folder_path": folder_path,
            "content_type": chunk_metadata.get("content_type", "text"),
            "section_title": chunk_metadata.get("section_title") or "",
            "section_path": " / ".join(str(part) for part in section_path),
            "page_number": chunk_metadata.get("page_number") or 0,
            "sheet_name": chunk_metadata.get("sheet_name") or "",
            "row_start": chunk_metadata.get("row_start") or 0,
            "row_end": chunk_metadata.get("row_end") or 0,
            "parser_version": chunk_metadata.get("parser_version") or document.parser_version or "",
            "index_version": chunk_metadata.get("index_version") or document.index_version or "",
        }

    async def _index_chunks(
        self,
        document: Document,
        chunks: list[DocumentChunk],
        user_id: str,
        folder_path: str | None = None,
    ):
        """Write ``chunks`` into the user's collection and mark ``document`` ready.

        Uses ``upsert`` so chunk ids that already exist in the collection are
        replaced rather than ignored. Commits the session at the end.
        """
        folder_path = folder_path or (
            await self._get_folder_path(document.folder_id) if document.folder_id else ""
        )
        folder_path = folder_path or ""
        collection_name = f"user_{user_id}"
        collection = self.get_collection(user_id)

        metadatas = []
        for chunk in chunks:
            metadatas.append(self._build_chunk_metadata(document, chunk, folder_path))
            chunk.chroma_collection = collection_name
            chunk.chroma_id = chunk.id

        collection.upsert(
            ids=[chunk.id for chunk in chunks],
            documents=[chunk.content for chunk in chunks],
            metadatas=metadatas,
        )
        document.is_indexed = True
        document.ingestion_status = "ready"
        document.ingestion_error = None
        document.indexed_at = datetime.now(UTC)

        await self.db.commit()

    async def retrieve(
        self,
        query: str,
        user_id: str,
        folder_id: str | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
    ) -> list[SearchResult]:
        """Vector search with optional rerank, optionally limited to a folder subtree.

        Steps:
        1. Query Chroma for ``top_k * 3`` candidates (a widened pool for rerank).
        2. Attach neighbouring chunks from the same section/sheet as context.
        3. Rerank heuristically (or just truncate when ``use_rerank`` is False).
        4. Return at most ``top_k`` results.

        The score is ``1 - distance``; higher is better, and its scale depends
        on the collection's distance function. Vector-store failures are
        logged and yield an empty list.
        """
        if top_k <= 0:
            return []
        collection = self.get_collection(user_id)

        where = None
        if folder_id:
            folder_path = await self._get_folder_path(folder_id)
            if folder_path:
                # Chroma has no prefix operator for metadata, so match the
                # folder and all of its descendants explicitly with $in.
                subtree_paths = await self._get_folder_subtree_paths(folder_id, folder_path)
                where = {"folder_path": {"$in": subtree_paths}}

        try:
            results = collection.query(
                query_texts=[query],
                n_results=top_k * 3,
                where=where,
                include=["documents", "metadatas", "distances"],
            )
        except Exception:
            # Degrade to "no hits" so callers keep working, but never silently.
            self._logger.exception("Chroma query failed (user_id=%s)", user_id)
            return []
        if not results or not results.get("ids"):
            return []

        ids = results["ids"][0]
        documents = (results.get("documents") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        # Load each document's chunks once, however many hits it has.
        chunks_by_document: dict[str, list[DocumentChunk]] = {}
        search_results: list[SearchResult] = []
        for i, chunk_id in enumerate(ids):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            document_id = meta.get("document_id", "")
            if document_id not in chunks_by_document:
                chunks_by_document[document_id] = (
                    await self._load_document_chunks(document_id) if document_id else []
                )
            prev_chunk, next_chunk = await self._get_related_chunks(
                chunk_id=chunk_id,
                chunk_index=meta.get("chunk_index", 0),
                document_id=document_id,
                document_chunks=chunks_by_document[document_id],
            )
            distance = distances[i] if i < len(distances) else 0.0
            search_results.append(SearchResult(
                chunk_id=chunk_id,
                document_id=document_id,
                document_title=meta.get("document_title", ""),
                content=documents[i] if i < len(documents) else "",
                score=1.0 - distance,
                metadata_=json.dumps(meta, ensure_ascii=False),
                prev_chunk=prev_chunk,
                next_chunk=next_chunk,
            ))

        if use_rerank:
            return self._rerank(query, search_results, top_k)
        return search_results[:top_k]

    def _rerank(
        self,
        query: str,
        results: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Heuristically re-score ``results`` and return the best ``top_k``.

        score = 0.7 * retrieval score
              + 0.2 * fraction of query words found in the content
              + 0.1 * fraction of query words found in the document title
              + 0.25 for ``table_schema`` / 0.15 for ``table_rows`` chunks
                when the query looks like a tabular-data question.
        Ties keep their input order (stable sort).
        """
        lowered_query = query.lower()
        query_words = set(self._WORD_RE.findall(lowered_query))
        denominator = max(len(query_words), 1)
        table_query = any(hint in lowered_query for hint in self._TABLE_QUERY_HINTS)

        scored = []
        for r in results:
            score = r.score * 0.7
            content_words = set(self._WORD_RE.findall((r.content or "").lower()))
            score += len(query_words & content_words) / denominator * 0.2
            if r.document_title:
                title_words = set(self._WORD_RE.findall(r.document_title.lower()))
                score += len(query_words & title_words) / denominator * 0.1
            if table_query:
                content_type = self._parse_metadata(r.metadata_).get("content_type")
                if content_type == "table_schema":
                    score += 0.25
                elif content_type == "table_rows":
                    score += 0.15
            scored.append((score, r))

        scored.sort(key=lambda x: x[0], reverse=True)
        return [r for _, r in scored[:top_k]]

    async def _get_related_chunks(
        self,
        chunk_id: str,
        chunk_index: int,
        document_id: str,
        document_chunks: list[DocumentChunk] | None = None,
    ) -> tuple[str | None, str | None]:
        """Return the text of the nearest previous/next chunk in the same section or sheet.

        ``document_chunks`` may be passed (ordered by ``chunk_index``) to avoid
        reloading the document. ``chunk_index`` is only a fallback for when
        the stored chunk has no index. Returns ``(None, None)`` when the chunk
        is not found among the document's chunks.
        """
        if document_chunks is None:
            document_chunks = await self._load_document_chunks(document_id)
        current_chunk = next((c for c in document_chunks if c.id == chunk_id), None)
        if current_chunk is None:
            return None, None

        anchor = current_chunk.chunk_index if current_chunk.chunk_index is not None else chunk_index
        current_metadata = self._parse_metadata(current_chunk.metadata_)
        section_path = current_metadata.get("section_path") or []
        sheet_name = current_metadata.get("sheet_name")

        prev_chunk = None
        next_chunk = None
        for chunk in document_chunks:
            if chunk.id == chunk_id:
                continue
            metadata = self._parse_metadata(chunk.metadata_)
            same_sheet = bool(sheet_name) and metadata.get("sheet_name") == sheet_name
            same_section = bool(section_path) and metadata.get("section_path") == section_path
            if not (same_sheet or same_section):
                continue
            if chunk.chunk_index < anchor:
                # Chunks are ascending, so the last match is the nearest one.
                prev_chunk = chunk.content
            elif chunk.chunk_index > anchor:
                next_chunk = chunk.content
                break
        return prev_chunk, next_chunk

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the folder's absolute path ("/root/.../name"), or None if it does not exist."""
        result = await self.db.execute(
            select(Folder).where(Folder.id == folder_id)
        )
        folder = result.scalar_one_or_none()
        if not folder:
            return None

        path_parts = [folder.name]
        seen = {folder.id}
        current_parent_id = folder.parent_id
        # `seen` stops an infinite loop if corrupt data forms a parent cycle.
        while current_parent_id and current_parent_id not in seen:
            seen.add(current_parent_id)
            parent_result = await self.db.execute(
                select(Folder).where(Folder.id == current_parent_id)
            )
            parent = parent_result.scalar_one_or_none()
            if not parent:
                break
            path_parts.insert(0, parent.name)
            current_parent_id = parent.parent_id
        return "/" + "/".join(path_parts)

    async def _get_folder_subtree_paths(self, folder_id: str, folder_path: str) -> list[str]:
        """Return ``folder_path`` plus the paths of all descendant folders.

        Child paths are built as ``<parent path>/<name>``, which is the same
        shape ``_get_folder_path`` produces. The tree is walked breadth-first
        with one query per level.
        """
        paths = [folder_path]
        visited = {folder_id}
        frontier: dict[str, str] = {folder_id: folder_path}
        while frontier:
            result = await self.db.execute(
                select(Folder).where(Folder.parent_id.in_(list(frontier)))
            )
            next_frontier: dict[str, str] = {}
            for child in result.scalars().all():
                if child.id in visited:
                    continue  # guard against cycles in corrupt folder data
                visited.add(child.id)
                child_path = f"{frontier[child.parent_id]}/{child.name}"
                paths.append(child_path)
                next_frontier[child.id] = child_path
            frontier = next_frontier
        return paths

    def _parse_metadata(self, raw_metadata: str | dict | None) -> dict:
        """Return chunk metadata as a dict; invalid or non-object JSON yields {}."""
        if isinstance(raw_metadata, dict):
            return raw_metadata
        if not raw_metadata:
            return {}
        try:
            parsed = json.loads(raw_metadata)
        except (TypeError, json.JSONDecodeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    async def hybrid_search(
        self,
        query: str,
        user_id: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Hybrid search: vector hits + keyword hits, de-duplicated by chunk id, then reranked."""
        vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
        keyword_results = await self._keyword_search(query, user_id, top_k)

        seen = set()
        merged: list[SearchResult] = []
        for r in vector_results + keyword_results:
            if r.chunk_id not in seen:
                seen.add(r.chunk_id)
                merged.append(r)

        return self._rerank(query, merged, top_k)

    async def _keyword_search(
        self,
        query: str,
        user_id: str,
        top_k: int,
    ) -> list[SearchResult]:
        """Substring (SQL LIKE) search over the user's chunk content and document titles.

        ``%`` and ``_`` in the query are escaped so they match literally.
        A blank query returns no results instead of matching every chunk.
        Hits get a flat score of 0.5.
        """
        needle = query.strip()
        if not needle or top_k <= 0:
            return []

        result = await self.db.execute(
            select(DocumentChunk, Document.title)
            .join(Document)
            .where(Document.user_id == user_id)
            .where(
                or_(
                    DocumentChunk.content.contains(needle, autoescape=True),
                    Document.title.contains(needle, autoescape=True),
                )
            )
            .limit(top_k)
        )

        results = []
        for chunk, title in result.all():
            chunk_metadata = self._parse_metadata(chunk.metadata_)
            results.append(SearchResult(
                chunk_id=chunk.id,
                document_id=chunk.document_id,
                document_title=title or "",
                content=chunk.content,
                score=0.5,
                # Serialized like vector hits so the reranker can apply structure boosts.
                metadata_=json.dumps(chunk_metadata, ensure_ascii=False) if chunk_metadata else None,
            ))
        return results

    async def delete_from_vectorstore(self, user_id: str, document_id: str):
        """Best-effort removal of a document's vectors from the user's collection.

        Failures are logged, not raised, so re-indexing can proceed.
        """
        collection = self.get_collection(user_id)
        try:
            collection.delete(where={"document_id": document_id})
        except Exception:
            self._logger.warning(
                "Failed to delete vectors for document %s (user_id=%s)",
                document_id,
                user_id,
                exc_info=True,
            )

    async def _get_owned_document(self, document_id: str, user_id: str) -> Document | None:
        """Return the document if it exists and belongs to ``user_id``, else None."""
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        return result.scalar_one_or_none()

    async def reindex_document(self, document_id: str, user_id: str) -> bool:
        """Re-parse the document via ``DocumentService`` and rebuild its vectors.

        Returns False when the document is missing or not owned by ``user_id``.
        """
        document = await self._get_owned_document(document_id, user_id)
        if not document:
            return False

        await self.delete_from_vectorstore(user_id, document_id)
        document = await DocumentService(self.db, user_id=user_id).rebuild_document(document)
        await self.index_document(document.id, user_id)
        return True

    async def reindex_document_chunks(self, document_id: str, user_id: str) -> bool:
        """Rebuild vectors from the document's existing chunks, without re-parsing.

        Returns False when the document is missing, not owned by ``user_id``,
        or has no chunks.
        """
        document = await self._get_owned_document(document_id, user_id)
        if not document:
            return False

        chunks = await self._load_document_chunks(document_id)
        if not chunks:
            return False

        await self.delete_from_vectorstore(user_id, document_id)
        await self._index_chunks(document, chunks, user_id)
        return True