Files
JARVIS/backend/app/services/knowledge_service.py
DESKTOP-72TV0V4\caoxiaozhu 3ee825aa90 Add MinerU document ingestion support
Normalize uploaded documents into structured markdown, add clearer parser
errors for missing dependencies, and cover the ingestion flow with
backend tests. This also replaces deprecated UTC timestamp helpers in
the touched backend paths so the knowledge pipeline stays warning-free.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-22 13:42:16 +08:00

409 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
检索策略:
1. 语义检索 (dense) - ChromaDB 向量相似度
2. 关键词检索 (sparse) - SQL LIKE
3. 混合检索 - 语义 + 关键词 加权融合
4. Rerank - 二次排序优化结果
5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境
"""
import json
import logging
import re
from dataclasses import dataclass
from datetime import UTC, datetime

import chromadb
from chromadb.config import Settings as ChromaSettings
from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.services.document_service import DocumentService
@dataclass
class SearchResult:
chunk_id: str
document_id: str
document_title: str
content: str
score: float
metadata_: str | None = None
prev_chunk: str | None = None
next_chunk: str | None = None
class KnowledgeService:
    """Vector knowledge-base retrieval service.

    Chunk rows live in SQL (``DocumentChunk``). Their text plus flattened
    metadata are mirrored into a per-user ChromaDB collection named
    ``user_<user_id>`` for dense retrieval.
    """

    _logger = logging.getLogger(__name__)

    _WORD_RE = re.compile(r"\w+")

    # Query substrings that mark a query as "about tabular data"; such
    # queries boost table chunks in ``_rerank``.
    # NOTE(review): the original list also contained two empty-string
    # entries (probably CJK words lost to an encoding/extraction problem).
    # An empty token is a substring of every query, which made EVERY query
    # count as a table query, so they were removed. Restore the intended
    # words here if they are known.
    _TABLE_QUERY_TOKENS = ("sheet", "excel", "csv", "金额", "统计", "日期")

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        self.user_id = user_id
        self._chroma_client = None

    @property
    def chroma_client(self):
        """Lazily create the persistent ChromaDB client on first access."""
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(
                path=settings.CHROMA_PERSIST_DIR,
                settings=ChromaSettings(allow_reset=True),
            )
        return self._chroma_client

    def get_collection(self, user_id: str):
        """Return the collection ``user_<user_id>``, creating it if needed."""
        return self.chroma_client.get_or_create_collection(
            name=f"user_{user_id}",
            metadata={"user_id": user_id},
        )

    async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
        """Embed all chunks of a document into the user's Chroma collection.

        Returns without doing anything if the document does not exist or
        has no chunks.

        Args:
            document_id: ID of the ``Document`` to index.
            user_id: User whose collection receives the chunks.
            folder_path: Pre-computed folder path. When empty, it is derived
                from the document's ``folder_id``.
        """
        result = await self.db.execute(
            select(Document).where(Document.id == document_id)
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return
        chunks = await self._load_document_chunks(document_id)
        if not chunks:
            return
        await self._index_chunks(doc, chunks, user_id, folder_path=folder_path)

    async def _load_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return all chunks of a document, ordered by ``chunk_index``."""
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def _index_chunks(
        self,
        document: Document,
        chunks: list[DocumentChunk],
        user_id: str,
        folder_path: str | None = None,
    ):
        """Write ``chunks`` to Chroma, then mark ``document`` as indexed and commit.

        Uses ``upsert`` so that chunk IDs already present in the collection
        are updated rather than skipped as duplicates.
        """
        if not folder_path and document.folder_id:
            folder_path = await self._get_folder_path(document.folder_id)
        folder_path = folder_path or ""

        collection_name = f"user_{user_id}"
        collection = self.get_collection(user_id)

        metadatas = []
        for chunk in chunks:
            metadatas.append(self._build_chunk_metadata(document, chunk, folder_path))
            # Remember where this chunk lives in the vector store.
            chunk.chroma_collection = collection_name
            chunk.chroma_id = chunk.id

        collection.upsert(
            ids=[chunk.id for chunk in chunks],
            documents=[chunk.content for chunk in chunks],
            metadatas=metadatas,
        )

        document.is_indexed = True
        document.ingestion_status = "ready"
        document.ingestion_error = None
        document.indexed_at = datetime.now(UTC)
        await self.db.commit()

    def _build_chunk_metadata(self, document: Document, chunk: DocumentChunk, folder_path: str) -> dict:
        """Flatten document and chunk info into scalar Chroma metadata.

        Chroma metadata values must be str/int/float/bool, so missing values
        are coerced to ``""`` or ``0``. ``section_path`` (a list of headings)
        is joined into one " / "-separated string.
        """
        chunk_metadata = self._parse_metadata(chunk.metadata_)
        section_path = chunk_metadata.get("section_path") or []
        if isinstance(section_path, str):
            section_path_text = section_path
        else:
            section_path_text = " / ".join(str(part) for part in section_path)
        return {
            "document_id": document.id,
            "document_title": document.title or "",
            "document_filename": document.filename or "",
            "chunk_index": chunk.chunk_index,
            "file_type": document.file_type or "",
            "folder_path": folder_path or "",
            "content_type": chunk_metadata.get("content_type") or "text",
            "section_title": chunk_metadata.get("section_title") or "",
            "section_path": section_path_text,
            "page_number": chunk_metadata.get("page_number") or 0,
            "sheet_name": chunk_metadata.get("sheet_name") or "",
            "row_start": chunk_metadata.get("row_start") or 0,
            "row_end": chunk_metadata.get("row_end") or 0,
            "parser_version": chunk_metadata.get("parser_version") or document.parser_version or "",
            "index_version": chunk_metadata.get("index_version") or document.index_version or "",
        }

    async def retrieve(
        self,
        query: str,
        user_id: str,
        folder_id: str | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
    ) -> list[SearchResult]:
        """Dense retrieval with optional folder scoping and rerank.

        Flow:
        1. Query ChromaDB for ``top_k * 3`` candidates (a widened pool for rerank).
        2. Rerank the candidates, or truncate them in vector order, down to ``top_k``.
        3. Attach the previous/next related chunks to the surviving results.

        Args:
            query: Natural-language query text.
            user_id: Whose collection to search.
            folder_id: Restrict results to documents in this folder or any of
                its sub-folders. An unknown folder yields no results.
            top_k: Maximum number of results to return.
            use_rerank: Apply ``_rerank`` instead of plain vector order.

        Returns:
            At most ``top_k`` results. Returns ``[]`` for an empty collection
            or when the vector query fails; failures are logged.
        """
        if top_k <= 0:
            return []
        collection = self.get_collection(user_id)

        where = None
        if folder_id:
            folder_paths = await self._get_folder_subtree_paths(folder_id)
            if not folder_paths:
                return []
            # Chroma metadata filters have no prefix operator ($starts_with is
            # rejected, which made every folder-scoped query return nothing),
            # so match the exact set of paths in the folder's subtree.
            where = {"folder_path": {"$in": folder_paths}}

        try:
            if collection.count() == 0:
                return []
            results = collection.query(
                query_texts=[query],
                n_results=top_k * 3,
                where=where,
                include=["documents", "metadatas", "distances"],
            )
        except Exception:
            self._logger.exception("Vector query failed for user %s", user_id)
            return []

        if not results or not results.get("ids"):
            return []
        ids = results["ids"][0]
        documents = (results.get("documents") or [[]])[0] or []
        metadatas = (results.get("metadatas") or [[]])[0] or []
        distances = (results.get("distances") or [[]])[0] or []

        candidates: list[SearchResult] = []
        for i, chunk_id in enumerate(ids):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            distance = distances[i] if i < len(distances) else 0.0
            candidates.append(SearchResult(
                chunk_id=chunk_id,
                document_id=meta.get("document_id", ""),
                document_title=meta.get("document_title", ""),
                content=documents[i] if i < len(documents) else "",
                # NOTE(review): assumes distances lie roughly in [0, 1]; with
                # Chroma's default L2 space this score can go negative.
                score=1.0 - distance,
                metadata_=json.dumps(meta, ensure_ascii=False),
            ))

        if use_rerank:
            selected = self._rerank(query, candidates, top_k)
        else:
            selected = candidates[:top_k]

        # Context lookups cost DB round-trips and do not affect ranking, so
        # only fetch them for the results actually returned.
        for hit in selected:
            meta = self._parse_metadata(hit.metadata_)
            hit.prev_chunk, hit.next_chunk = await self._get_related_chunks(
                chunk_id=hit.chunk_id,
                chunk_index=meta.get("chunk_index", 0),
                document_id=meta.get("document_id", ""),
            )
        return selected

    def _rerank(
        self,
        query: str,
        results: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Re-score ``results`` and return the best ``top_k``, highest first.

        score = semantic * 0.7 + keyword overlap * 0.2 + title overlap * 0.1
                + structural bonus (+0.25 table_schema / +0.15 table_rows,
                  only for table-like queries)

        Overlap is the fraction of the query's ``\\w+`` tokens that also
        appear in the content or title. Ties keep their input order.
        """
        lowered_query = query.lower()
        query_words = set(self._WORD_RE.findall(lowered_query))
        denominator = max(len(query_words), 1)
        table_query = any(token in lowered_query for token in self._TABLE_QUERY_TOKENS)

        scored = []
        for r in results:
            score = r.score * 0.7
            content_words = set(self._WORD_RE.findall((r.content or "").lower()))
            score += len(query_words & content_words) / denominator * 0.2
            if r.document_title:
                title_words = set(self._WORD_RE.findall(r.document_title.lower()))
                score += len(query_words & title_words) / denominator * 0.1
            if table_query:
                content_type = self._parse_metadata(r.metadata_).get("content_type")
                if content_type == "table_schema":
                    score += 0.25
                elif content_type == "table_rows":
                    score += 0.15
            scored.append((score, r))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [r for _, r in scored[:top_k]]

    async def _get_related_chunks(
        self,
        chunk_id: str,
        chunk_index: int,
        document_id: str,
    ) -> tuple[str | None, str | None]:
        """Find the nearest preceding and following chunks that share the hit's sheet or section.

        Returns ``(prev_content, next_content)``. Either element is ``None``
        when no related neighbour exists. Both are ``None`` when the chunk is
        unknown, or when it has neither a sheet name nor a section path.
        """
        current_result = await self.db.execute(
            select(DocumentChunk).where(DocumentChunk.id == chunk_id)
        )
        current_chunk = current_result.scalar_one_or_none()
        if not current_chunk:
            return None, None
        current_metadata = self._parse_metadata(current_chunk.metadata_)
        section_path = current_metadata.get("section_path") or []
        sheet_name = current_metadata.get("sheet_name")
        if not section_path and not sheet_name:
            # Nothing to match on, so skip scanning the document.
            return None, None

        prev_chunk = None
        for chunk in await self._load_document_chunks(document_id):
            if chunk.id == chunk_id:
                continue
            metadata = self._parse_metadata(chunk.metadata_)
            related = (
                (bool(sheet_name) and metadata.get("sheet_name") == sheet_name)
                or (bool(section_path) and metadata.get("section_path") == section_path)
            )
            if not related:
                continue
            if chunk.chunk_index < chunk_index:
                # Chunks are ascending, so the last match wins: the nearest one before.
                prev_chunk = chunk.content
            elif chunk.chunk_index > chunk_index:
                # The first match after the hit is the nearest one after.
                return prev_chunk, chunk.content
        return prev_chunk, None

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Build a folder's absolute path ``/root/.../name`` by walking its parents.

        Returns ``None`` if the folder does not exist. The walk stops at a
        missing parent, or at a cycle in the parent chain (which would
        otherwise loop forever).
        """
        result = await self.db.execute(
            select(Folder).where(Folder.id == folder_id)
        )
        folder = result.scalar_one_or_none()
        if not folder:
            return None
        path_parts = [folder.name]
        visited = {folder.id}
        parent_id = folder.parent_id
        while parent_id and parent_id not in visited:
            visited.add(parent_id)
            parent_result = await self.db.execute(
                select(Folder).where(Folder.id == parent_id)
            )
            parent = parent_result.scalar_one_or_none()
            if not parent:
                break
            path_parts.append(parent.name)
            parent_id = parent.parent_id
        return "/" + "/".join(reversed(path_parts))

    async def _get_folder_subtree_paths(self, folder_id: str) -> list[str]:
        """Return the paths of ``folder_id`` and all of its descendant folders.

        Paths use the same format as ``_get_folder_path``, which is what
        ``_index_chunks`` stores as ``folder_path``. Returns ``[]`` when the
        folder does not exist.
        """
        root_path = await self._get_folder_path(folder_id)
        if root_path is None:
            return []
        paths = [root_path]
        # Breadth-first walk over child folders, one query per tree level.
        # NOTE(review): assumes Folder.parent_id is a mapped column; confirm.
        frontier = {folder_id: root_path}
        seen = {folder_id}
        while frontier:
            result = await self.db.execute(
                select(Folder).where(Folder.parent_id.in_(list(frontier)))
            )
            next_frontier: dict[str, str] = {}
            for child in result.scalars().all():
                if child.id in seen:
                    continue
                seen.add(child.id)
                child_path = f"{frontier[child.parent_id]}/{child.name}"
                paths.append(child_path)
                next_frontier[child.id] = child_path
            frontier = next_frontier
        return paths

    def _parse_metadata(self, raw_metadata: str | dict | None) -> dict:
        """Coerce stored metadata (a dict or JSON text) into a dict.

        Returns ``{}`` for empty, unparsable, or non-object JSON (for example
        a list or ``null``), so callers can always use ``.get``.
        """
        if isinstance(raw_metadata, dict):
            return raw_metadata
        if not raw_metadata:
            return {}
        try:
            parsed = json.loads(raw_metadata)
        except (TypeError, json.JSONDecodeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    async def hybrid_search(
        self,
        query: str,
        user_id: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Hybrid search: vector + keyword candidates, deduplicated, then reranked.

        When a chunk is found by both searches, the vector hit is kept, since
        it carries Chroma metadata and context chunks.
        """
        vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
        keyword_results = await self._keyword_search(query, user_id, top_k)
        merged: dict[str, SearchResult] = {}
        for hit in vector_results + keyword_results:
            merged.setdefault(hit.chunk_id, hit)
        return self._rerank(query, list(merged.values()), top_k)

    async def _keyword_search(
        self,
        query: str,
        user_id: str,
        top_k: int,
    ) -> list[SearchResult]:
        """Run an SQL LIKE substring search over chunk content and document titles.

        The query is matched literally: ``autoescape=True`` escapes ``%`` and
        ``_`` so user input cannot act as LIKE wildcards. A blank query returns
        ``[]`` rather than matching every chunk. Every hit gets a flat
        placeholder score of 0.5.
        """
        if not query or not query.strip() or top_k <= 0:
            return []
        result = await self.db.execute(
            select(DocumentChunk, Document)
            .join(Document, DocumentChunk.document_id == Document.id)
            .where(Document.user_id == user_id)
            .where(
                or_(
                    DocumentChunk.content.contains(query, autoescape=True),
                    Document.title.contains(query, autoescape=True),
                )
            )
            .limit(top_k)
        )
        return [
            SearchResult(
                chunk_id=chunk.id,
                document_id=chunk.document_id,
                document_title=doc.title or "",
                content=chunk.content,
                score=0.5,
                metadata_=None,
            )
            for chunk, doc in result.all()
        ]

    async def delete_from_vectorstore(self, user_id: str, document_id: str):
        """Delete all vectors of a document from the user's collection.

        This is best-effort: failures are logged and not raised, so they never
        block re-indexing or document deletion.
        """
        collection = self.get_collection(user_id)
        try:
            collection.delete(where={"document_id": document_id})
        except Exception:
            self._logger.warning(
                "Failed to delete vectors for document %s (user %s)",
                document_id,
                user_id,
                exc_info=True,
            )

    async def _get_owned_document(self, document_id: str, user_id: str) -> Document | None:
        """Return the document if it exists and belongs to ``user_id``, otherwise ``None``."""
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        return result.scalar_one_or_none()

    async def reindex_document(self, document_id: str, user_id: str) -> bool:
        """Rebuild a user's document and re-index its chunks from scratch.

        The rebuild goes through ``DocumentService.rebuild_document``
        (presumably re-parsing/re-chunking; see that service). Returns
        ``False`` if the document is missing or not owned by ``user_id``.
        """
        document = await self._get_owned_document(document_id, user_id)
        if not document:
            return False
        await self.delete_from_vectorstore(user_id, document_id)
        document = await DocumentService(self.db, user_id=user_id).rebuild_document(document)
        await self.index_document(document.id, user_id)
        return True

    async def reindex_document_chunks(self, document_id: str, user_id: str) -> bool:
        """Re-embed a document's existing chunks without rebuilding the document.

        Returns ``False`` if the document is missing, is not owned by
        ``user_id``, or has no chunks.
        """
        document = await self._get_owned_document(document_id, user_id)
        if not document:
            return False
        chunks = await self._load_document_chunks(document_id)
        if not chunks:
            return False
        await self.delete_from_vectorstore(user_id, document_id)
        await self._index_chunks(document, chunks, user_id)
        return True