# Source file: X-Financial/server/src/app/services/knowledge_rag.py
# (repository viewer listing: 1331 lines, 49 KiB, Python)

from __future__ import annotations
import asyncio
import json
import os
import re
import socket
import threading
from dataclasses import dataclass
from datetime import UTC, datetime
from functools import partial
from http import HTTPStatus
from pathlib import Path
from time import perf_counter
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import quote
from urllib.request import Request, urlopen
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.core.logging import get_logger
from app.db.session import get_session_factory
from app.services.settings import SettingsService
# Module-level logger used by every function and class below.
logger = get_logger("app.services.knowledge_rag")
# Qdrant base URLs. _build_runtime_signature falls back to
# _resolve_default_qdrant_url() when the QDRANT_URL environment variable is empty.
# NOTE(review): presumably that helper picks between these two values — confirm.
DEFAULT_QDRANT_URL = "http://127.0.0.1:6333"
CONTAINER_QDRANT_URL = "http://qdrant:6333"
# Fallbacks for the LIGHTRAG_WORKSPACE / LIGHTRAG_QUERY_MODE environment variables.
DEFAULT_LIGHTRAG_WORKSPACE = "x_financial_knowledge"
DEFAULT_LIGHTRAG_QUERY_MODE = "naive"
# HTTP timeouts (seconds) passed to _send_json_request for chat/rerank and embedding calls.
DEFAULT_LLM_TIMEOUT_SECONDS = 180
DEFAULT_EMBEDDING_TIMEOUT_SECONDS = 120
# Character limits applied to each search hit's "content" and "excerpt" fields.
MAX_KNOWLEDGE_HIT_CONTENT_LENGTH = 2200
MAX_KNOWLEDGE_HIT_EXCERPT_LENGTH = 220
# NOTE(review): not referenced in the visible part of the file; presumably used
# by _extract_query_terms — confirm.
MAX_QUERY_TERMS = 12
QUERY_TERM_STOPWORDS = {
"什么",
"多少",
"哪些",
"怎么",
"如何",
"请问",
"一下",
"关于",
"规定",
"标准",
"可以",
"是否",
"一个",
"哪些人",
}
# Substrings that, when present in a query, make _build_hits_from_query_data
# pass prefers_tabular_evidence=True to the hit scorer.
TABLE_OR_STANDARD_QUERY_HINTS = (
"标准",
"金额",
"限额",
"补贴",
"住宿",
"餐费",
"交通",
"报销",
"档位",
"额度",
)
# NOTE(review): not referenced in the visible part of the file; presumably
# consumed by the hit-scoring helpers defined further down — confirm.
STRUCTURED_APPENDIX_LEADING_MARKERS = (
"# 章节导航",
"# 重点章节摘录",
"# 问答线索补充",
"# 结构化表格补充",
)
STRUCTURED_APPENDIX_LEADING_WINDOW = 220
# Process-wide runtime singleton; the lock guards both the instance and the
# signature it was built from (see KnowledgeRagService._get_runtime).
_runtime_lock = threading.RLock()
_runtime_instance: _LightRagRuntime | None = None
_runtime_signature: tuple[Any, ...] | None = None
class KnowledgeRagError(RuntimeError):
    """Error raised by this module for runtime, configuration and model-endpoint failures."""
@dataclass(frozen=True, slots=True)
class RuntimeModelConfig:
slot: str
provider: str
model: str
endpoint: str
api_key: str
capability: str
class _LightRagRuntime:
def __init__(
    self,
    *,
    working_dir: Path,
    workspace: str,
    qdrant_url: str,
    qdrant_api_key: str,
    primary_chat: RuntimeModelConfig,
    backup_chat: RuntimeModelConfig | None,
    embedding: RuntimeModelConfig,
    reranker: RuntimeModelConfig | None,
) -> None:
    """Record the runtime configuration, then build and initialize LightRAG.

    Args:
        working_dir: Directory handed to LightRAG as its working directory.
        workspace: LightRAG workspace name.
        qdrant_url: Qdrant base URL exported to the environment on build.
        qdrant_api_key: Qdrant API key exported to the environment on build.
        primary_chat: Chat model tried first for completions.
        backup_chat: Optional chat model tried when the primary one fails.
        embedding: Embedding model configuration.
        reranker: Optional reranker model configuration.
    """
    # Storage location and vector-store connection.
    self.working_dir, self.workspace = working_dir, workspace
    self.qdrant_url, self.qdrant_api_key = qdrant_url, qdrant_api_key
    # Model slots.
    self.primary_chat, self.backup_chat = primary_chat, backup_chat
    self.embedding, self.reranker = embedding, reranker
    # Memo for _graph_has_content(); None means "not determined yet".
    self._graph_has_content_cache: bool | None = None
    # Building probes the embedding endpoint; initialization opens the storages.
    self._rag = self._build_rag()
    self._initialize()
@property
def rag(self):
    """Return the LightRAG instance created by ``_build_rag``."""
    return self._rag
def _build_rag(self):
    """Create the LightRAG instance wired to this runtime's model endpoints.

    The embedding endpoint is probed once to learn the vector dimension.
    LightRAG receives async callables that delegate to the blocking HTTP
    helpers through ``asyncio.to_thread``.

    Returns:
        A ``LightRAG`` instance using JSON KV/doc-status storage, NetworkX
        graph storage and Qdrant vector storage.

    Raises:
        KnowledgeRagError: If the ``lightrag`` package cannot be imported, or
            if the embedding probe fails.
    """
    try:
        from lightrag import LightRAG
        from lightrag.utils import EmbeddingFunc
    except ImportError as exc:  # pragma: no cover - exercised in runtime env
        raise KnowledgeRagError(
            "LightRAG 依赖未安装,请先在 server 环境执行依赖安装。"
        ) from exc

    self.working_dir.mkdir(parents=True, exist_ok=True)

    # Export the Qdrant connection through the process environment.
    # NOTE(review): presumably the Qdrant storage backend reads these — confirm.
    for env_name, env_value in (
        ("QDRANT_URL", self.qdrant_url),
        ("QDRANT_API_KEY", self.qdrant_api_key),
    ):
        if env_value:
            os.environ[env_name] = env_value

    embedding_dim = self._probe_embedding_dimension(self.embedding)
    logger.info(
        "Initialize LightRAG runtime workspace=%s qdrant=%s embedding_model=%s dim=%s",
        self.workspace,
        self.qdrant_url,
        self.embedding.model,
        embedding_dim,
    )

    async def embedding_func(texts: list[str]) -> Any:
        # Run the blocking embedding request off the event loop.
        return await asyncio.to_thread(self._embed_sync, texts)

    async def llm_model_func(
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] | None = None,
        keyword_extraction: bool = False,
        **kwargs: Any,
    ) -> str:
        # Extra keyword arguments are forwarded as one dict to _complete_sync.
        return await asyncio.to_thread(
            self._complete_sync,
            prompt,
            system_prompt,
            history_messages or [],
            keyword_extraction,
            kwargs,
        )

    async def rerank_model_func(
        query: str,
        documents: list[str],
        top_n: int | None = None,
        **_kwargs: Any,
    ) -> list[dict[str, Any]]:
        return await asyncio.to_thread(self._rerank_sync, query, documents, top_n)

    wrapped_embedding = EmbeddingFunc(
        embedding_dim=embedding_dim,
        func=embedding_func,
        max_token_size=8192,
        model_name=self.embedding.model,
        supports_asymmetric=False,
    )
    # Only register a reranker callable when a reranker slot is configured.
    optional_reranker = None if self.reranker is None else rerank_model_func

    return LightRAG(
        working_dir=str(self.working_dir),
        workspace=self.workspace,
        kv_storage="JsonKVStorage",
        graph_storage="NetworkXStorage",
        vector_storage="QdrantVectorDBStorage",
        doc_status_storage="JsonDocStatusStorage",
        llm_model_name=self.primary_chat.model,
        llm_model_func=llm_model_func,
        embedding_func=wrapped_embedding,
        rerank_model_func=optional_reranker,
        enable_llm_cache=False,
        enable_llm_cache_for_entity_extract=False,
    )
def _initialize(self) -> None:
    """Run LightRAG's ``initialize_storages`` coroutine to completion."""
    from lightrag.utils import always_get_an_event_loop

    always_get_an_event_loop().run_until_complete(self._rag.initialize_storages())
def finalize(self) -> None:
    """Run LightRAG's ``finalize_storages`` coroutine to completion."""
    from lightrag.utils import always_get_an_event_loop

    always_get_an_event_loop().run_until_complete(self._rag.finalize_storages())
def query_data(self, query: str, *, conversation_history: list[dict[str, str]] | None = None) -> dict[str, Any]:
    """Retrieve structured context for ``query`` from LightRAG.

    The mode comes from the ``LIGHTRAG_QUERY_MODE`` environment variable.
    A non-naive mode is downgraded to ``"naive"`` up front when the graph
    file has no content, and retried once in ``"naive"`` mode if it raises.

    Args:
        query: The user question.
        conversation_history: Optional prior turns forwarded to LightRAG.

    Returns:
        Whatever ``LightRAG.query_data`` returns.

    Raises:
        Exception: Whatever the naive-mode query raises; there is no further
            fallback once naive mode has failed.
    """
    from lightrag import QueryParam

    def build_param(selected_mode: str):
        # Single definition of the retrieval parameters, so the primary and
        # fallback attempts cannot drift apart.
        return QueryParam(
            mode=selected_mode,
            top_k=8,
            chunk_top_k=10,
            only_need_context=True,
            response_type="Multiple Paragraphs",
            conversation_history=conversation_history or [],
            include_references=True,
        )

    configured_mode = (
        os.environ.get("LIGHTRAG_QUERY_MODE", DEFAULT_LIGHTRAG_QUERY_MODE).strip()
        or DEFAULT_LIGHTRAG_QUERY_MODE
    )
    # Graph-based modes are pointless on an empty graph.
    mode = "naive" if configured_mode != "naive" and not self._graph_has_content() else configured_mode
    started_at = perf_counter()
    try:
        result = self._rag.query_data(query, build_param(mode))
    except Exception as exc:
        if mode == "naive":
            raise
        # Keep the cause in the log; otherwise mode failures are undiagnosable.
        logger.warning("LightRAG query mode=%s failed, retry with naive mode: %s", mode, exc)
        mode = "naive"
        result = self._rag.query_data(query, build_param(mode))
    logger.info("LightRAG query completed mode=%s elapsed=%.2fs", mode, perf_counter() - started_at)
    return result
def _graph_has_content(self) -> bool:
    """Report whether the workspace graph file contains a node or an edge.

    Only a positive answer is memoized. The runtime is a long-lived
    singleton, so caching "empty" would keep graph modes disabled after the
    first documents are indexed; an empty or missing file is cheap to
    re-check on the next call.

    Returns:
        ``True`` if ``graph_chunk_entity_relation.graphml`` under the
        workspace directory contains ``"<node "`` or ``"<edge "``.
        ``False`` if it does not, or if the file is missing or unreadable.
    """
    if self._graph_has_content_cache:
        return True
    graph_path = self.working_dir / self.workspace / "graph_chunk_entity_relation.graphml"
    try:
        # Stream the file and stop at the first element; neither marker can
        # span a line break, so a per-line check is sufficient.
        with graph_path.open(encoding="utf-8") as handle:
            has_content = any("<node " in line or "<edge " in line for line in handle)
    except (OSError, UnicodeDecodeError):
        has_content = False
    self._graph_has_content_cache = True if has_content else None
    return has_content
def insert_documents(
    self,
    *,
    texts: list[str],
    document_ids: list[str],
    file_paths: list[str],
) -> str:
    """Insert documents into LightRAG and return the value of ``LightRAG.insert``.

    Args:
        texts: Document bodies.
        document_ids: Ids passed to LightRAG as ``ids``.
        file_paths: Source paths passed to LightRAG as ``file_paths``.
    """
    track_id = self._rag.insert(texts, ids=document_ids, file_paths=file_paths)
    return track_id
def get_document_statuses(self, document_ids: list[str]) -> dict[str, Any]:
    """Return the result of LightRAG's ``aget_docs_by_ids`` for ``document_ids``."""
    from lightrag.utils import always_get_an_event_loop

    lookup = self._rag.aget_docs_by_ids(document_ids)
    return always_get_an_event_loop().run_until_complete(lookup)
def delete_document(self, document_id: str) -> None:
    """Delete one document from LightRAG.

    Raises:
        KnowledgeRagError: If the deletion result's ``status`` is neither
            ``"success"`` nor ``"not_found"``.
    """
    from lightrag.utils import always_get_an_event_loop

    outcome = always_get_an_event_loop().run_until_complete(
        self._rag.adelete_by_doc_id(document_id)
    )
    status = str(getattr(outcome, "status", "") or "")
    if status in {"success", "not_found"}:
        return
    # Prefer the backend's own message and fall back to a generic one.
    message = str(getattr(outcome, "message", "") or "LightRAG 删除文档失败。")
    raise KnowledgeRagError(message)
def _probe_embedding_dimension(self, config: RuntimeModelConfig) -> int:
    """Embed a fixed probe string and return the length of the resulting vector.

    Raises:
        KnowledgeRagError: If no list-shaped vector comes back or it is empty.
    """
    vectors = self._request_embeddings(config, ["dimension probe"])
    first_vector = vectors[0] if vectors else None
    if not isinstance(first_vector, list):
        raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。")
    dimension = len(first_vector)
    if dimension > 0:
        return dimension
    raise KnowledgeRagError("embedding 模型返回了无效的向量维度。")
def _embed_sync(self, texts: list[str]) -> Any:
    """Embed ``texts`` with the configured embedding model.

    Returns:
        A float NumPy array built from the vectors the endpoint returned.
    """
    import numpy as np

    return np.array(self._request_embeddings(self.embedding, texts), dtype=float)
def _rerank_sync(
    self,
    query: str,
    documents: list[str],
    top_n: int | None,
) -> list[dict[str, Any]]:
    """Rerank ``documents`` against ``query`` with the configured reranker.

    Returns:
        Normalized ``{"index", "relevance_score"}`` entries, or an empty list
        when no reranker is configured.

    Raises:
        KnowledgeRagError: If the reranker answers with a status code >= 400.
    """
    reranker = self.reranker
    if reranker is None:
        return []
    status_code, body = self._request_rerank(
        reranker,
        query=query,
        documents=documents,
        top_n=top_n,
    )
    if status_code >= HTTPStatus.BAD_REQUEST:
        raise KnowledgeRagError(f"reranker 模型返回异常状态码 {status_code}")
    return _extract_rerank_results(body, provider=reranker.provider)
def _complete_sync(
    self,
    prompt: str,
    system_prompt: str | None,
    history_messages: list[dict[str, Any]],
    keyword_extraction: bool,
    kwargs: dict[str, Any],
) -> str:
    """Request a chat completion, trying the primary model and then the backup.

    Args:
        prompt: User prompt text.
        system_prompt: Optional system prompt.
        history_messages: Prior chat messages inserted before the prompt.
        keyword_extraction: Accepted for LightRAG's call signature; unused.
        kwargs: Extra options; only ``max_tokens`` (default 1200) and
            ``temperature`` (default 0.1) are read.

    Returns:
        The completion text from the first model that succeeds.

    Raises:
        KnowledgeRagError: If every configured chat model fails.
    """
    del keyword_extraction
    last_error: Exception | None = None
    candidates = [config for config in (self.primary_chat, self.backup_chat) if config is not None]
    for config in candidates:
        try:
            # An explicit temperature of 0 is a valid request and must not be
            # replaced by the default, so test for None instead of falsiness.
            requested_temperature = kwargs.get("temperature")
            temperature = 0.1 if requested_temperature is None else float(requested_temperature)
            return self._request_chat_completion(
                config,
                prompt=prompt,
                system_prompt=system_prompt,
                history_messages=history_messages,
                max_tokens=int(kwargs.get("max_tokens") or 1200),
                temperature=temperature,
            )
        except Exception as exc:  # pragma: no cover - runtime fallback
            last_error = exc
            logger.warning(
                "LightRAG LLM request failed slot=%s provider=%s model=%s: %s",
                config.slot,
                config.provider,
                config.model,
                exc,
            )
    raise KnowledgeRagError(f"LightRAG 调用知识模型失败:{last_error or '没有可用模型配置'}")
def _request_chat_completion(
    self,
    config: RuntimeModelConfig,
    *,
    prompt: str,
    system_prompt: str | None,
    history_messages: list[dict[str, Any]],
    max_tokens: int,
    temperature: float,
) -> str:
    """POST one chat-completion request in the dialect of ``config.provider``.

    "Azure OpenAI" uses the deployment URL with an ``api-key`` header,
    "Ollama" uses ``api/chat`` without credentials, and every other provider
    is treated as OpenAI-compatible with a bearer token.

    Returns:
        The assistant text extracted from the response body.

    Raises:
        KnowledgeRagError: On transport errors, or a status code >= 400.
    """
    system_part = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages: list[dict[str, Any]] = [
        *system_part,
        *history_messages,
        {"role": "user", "content": prompt},
    ]

    provider = config.provider
    if provider == "Azure OpenAI":
        # The deployment is part of the URL, so the payload carries no model.
        url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/chat/completions?api-version={AZURE_API_VERSION}"
        payload = {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
    elif provider == "Ollama":
        url = _ensure_path(_normalize_endpoint(config.endpoint), "api/chat")
        payload = {
            "model": config.model,
            "messages": messages,
            "stream": False,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
    else:
        url = _ensure_path(_normalize_endpoint(config.endpoint), "chat/completions")
        payload = {
            "model": config.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        headers = _build_headers(config.api_key, use_bearer=True)

    status_code, body = _send_json_request(
        "POST",
        url,
        headers=headers,
        payload=payload,
        timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS,
    )
    if status_code >= HTTPStatus.BAD_REQUEST:
        raise KnowledgeRagError(f"知识模型返回异常状态码 {status_code}")
    return _extract_chat_text(body, provider=provider)
def _request_embeddings(self, config: RuntimeModelConfig, texts: list[str]) -> list[list[float]]:
    """POST an embedding request for ``texts`` and return the parsed vectors.

    The URL, payload and headers depend on ``config.provider``:
    "Azure OpenAI", "Ollama", or OpenAI-compatible for anything else.

    Raises:
        KnowledgeRagError: On transport errors, a status code >= 400, or an
            unrecognized response shape.
    """
    provider = config.provider
    if provider == "Azure OpenAI":
        url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}"
        payload = {"input": texts}
        headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
    elif provider == "Ollama":
        url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed")
        payload = {"model": config.model, "input": texts}
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
    else:
        url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings")
        payload = {"model": config.model, "input": texts}
        headers = _build_headers(config.api_key, use_bearer=True)

    status_code, body = _send_json_request(
        "POST",
        url,
        headers=headers,
        payload=payload,
        timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
    )
    if status_code >= HTTPStatus.BAD_REQUEST:
        raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}")
    return _extract_embedding_vectors(body, provider=provider)
def _request_rerank(
    self,
    config: RuntimeModelConfig,
    *,
    query: str,
    documents: list[str],
    top_n: int | None,
) -> tuple[int, Any]:
    """POST a rerank request in the dialect of ``config.provider``.

    Returns:
        ``(status_code, parsed_body)`` exactly as ``_send_json_request``
        returns it; callers validate the status and parse the body.
    """
    provider = config.provider
    # The Ali request builder places top_n itself; the other two dialects
    # take it as a top-level payload field.
    top_n_at_top_level = True
    if provider == "Azure OpenAI":
        url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/rerank?api-version={AZURE_API_VERSION}"
        payload: dict[str, Any] = {
            "query": query,
            "documents": documents,
        }
        headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
    elif provider == "Ali":
        url, payload = _build_ali_rerank_request(
            config.model,
            query=query,
            documents=documents,
            top_n=top_n,
        )
        headers = _build_headers(config.api_key, use_bearer=True)
        top_n_at_top_level = False
    else:
        url = _ensure_path(_normalize_endpoint(config.endpoint), "rerank")
        payload = {
            "model": config.model,
            "query": query,
            "documents": documents,
        }
        headers = _build_headers(config.api_key, use_bearer=True)

    if top_n_at_top_level and top_n is not None:
        payload["top_n"] = top_n
    return _send_json_request(
        "POST",
        url,
        headers=headers,
        payload=payload,
        timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS,
    )
class KnowledgeRagService:
def __init__(self, db: Session | None = None, storage_root: Path | None = None) -> None:
    """Bind the service to an optional DB session and a storage root.

    Args:
        db: Optional SQLAlchemy session; when omitted, a short-lived session
            is opened on demand to load model settings.
        storage_root: Root directory for knowledge storage; defaults to the
            application settings' ``resolved_storage_root_dir``.
    """
    self.db = db
    chosen_root = storage_root or get_settings().resolved_storage_root_dir
    self.storage_root = Path(chosen_root)
def query_knowledge(
    self,
    query: str,
    *,
    conversation_history: list[dict[str, str]] | None = None,
    limit: int = 5,
) -> dict[str, Any]:
    """Search the knowledge base and return a ``knowledge_search`` result dict.

    This method never raises for retrieval failures; they are logged and
    reported through the ``message`` field of an empty result.

    Args:
        query: The user question; blank input yields an empty result.
        conversation_history: Optional prior turns forwarded to the runtime.
        limit: Maximum number of hits to return.

    Returns:
        A dict with ``result_type``, ``query``, ``record_count``, ``hits``,
        ``references`` and ``message``. ``raw_references`` is added once the
        runtime has answered, and ``metadata`` when hits were found.
    """

    def empty_result(query_text: str, message: str, **extra: Any) -> dict[str, Any]:
        # Key order mirrors the populated result so serialized output is stable.
        return {
            "result_type": "knowledge_search",
            "query": query_text,
            "record_count": 0,
            "hits": [],
            "references": [],
            **extra,
            "message": message,
        }

    normalized_query = str(query or "").strip()
    if not normalized_query:
        return empty_result("", "请先输入要检索的知识库问题。")

    try:
        raw = self._get_runtime().query_data(normalized_query, conversation_history=conversation_history)
    except Exception as exc:
        logger.warning("Knowledge query failed: %s", exc)
        return empty_result(normalized_query, f"知识库检索暂不可用:{exc}")

    raw_is_dict = isinstance(raw, dict)
    data = raw.get("data") if raw_is_dict else {}
    if not isinstance(data, dict):
        # A missing or malformed "data" section is treated as empty.
        data = {}
    chunks = list(data.get("chunks") or [])
    entities = list(data.get("entities") or [])
    references = list(data.get("references") or [])

    hits = self._build_hits_from_query_data(
        query=normalized_query,
        chunks=chunks,
        entities=entities,
        limit=limit,
    )
    if not hits:
        return empty_result(
            normalized_query,
            "当前知识库中没有检索到与本次问题直接匹配的内容。",
            raw_references=references,
        )

    hit_codes = (str(item.get("code") or "").strip() for item in hits)
    return {
        "result_type": "knowledge_search",
        "query": normalized_query,
        "record_count": len(hits),
        "hits": hits,
        "references": [code for code in hit_codes if code],
        "raw_references": references,
        "metadata": raw.get("metadata") if raw_is_dict else {},
        "message": f"已从知识库中检索到 {len(hits)} 条相关内容。",
    }
def index_documents(
    self,
    *,
    document_ids: list[str],
    force: bool = False,
) -> dict[str, Any]:
    """Index knowledge documents into LightRAG and report per-document outcome.

    The work is done in phases so that a failure while reading a document
    cannot leave an already-indexed document deleted:

    1. Resolve every document's text and stored path.
    2. With ``force``, delete the existing index entries (best effort).
    3. Insert all documents and read back their statuses.

    Args:
        document_ids: Ids of the documents to index; blanks and duplicates
            are dropped, keeping first-seen order.
        force: When true, delete existing index entries before inserting.

    Returns:
        A dict with ``track_id``, ``requested_document_ids``,
        ``succeeded_document_ids``, ``failed_documents`` and
        ``status_snapshot``.

    Raises:
        ValueError: If no usable document id was supplied.
    """
    # LightRAG receives explicit ids, so a repeated id would be inserted
    # twice; de-duplicate while preserving order.
    cleaned_ids = (str(item).strip() for item in document_ids)
    normalized_ids = list(dict.fromkeys(text for text in cleaned_ids if text))
    if not normalized_ids:
        raise ValueError("没有可供索引的知识文档。")

    from app.services.knowledge import KnowledgeService
    from app.services.knowledge_normalizer import KnowledgeNormalizationService

    knowledge_service = KnowledgeService(storage_root=self.storage_root, db=self.db)
    normalization_service = (
        KnowledgeNormalizationService(self.db) if self.db is not None else None
    )
    runtime = self._get_runtime()
    existing_statuses = runtime.get_document_statuses(normalized_ids) or {}

    # Phase 1: gather texts and paths before touching the index.
    texts: list[str] = []
    file_paths: list[str] = []
    for document_id in normalized_ids:
        entry = knowledge_service.get_document_entry(document_id)
        text = knowledge_service.extract_document_text(document_id)
        if normalization_service is not None:
            text = normalization_service.build_enriched_text(text)
        texts.append(text)
        stored_path = knowledge_service.library_root / entry["folder"] / entry["stored_name"]
        file_paths.append(str(stored_path.resolve()))

    # Phase 2: drop stale index entries; a failure is logged, not fatal.
    if force:
        for document_id in normalized_ids:
            if document_id not in existing_statuses:
                continue
            try:
                runtime.delete_document(document_id)
            except Exception as exc:
                logger.warning("Delete existing LightRAG document failed doc_id=%s: %s", document_id, exc)

    # Phase 3: insert and collect the resulting statuses.
    track_id = runtime.insert_documents(
        texts=texts,
        document_ids=normalized_ids,
        file_paths=file_paths,
    )
    statuses = runtime.get_document_statuses(normalized_ids) or {}

    succeeded_document_ids: list[str] = []
    failed_documents: list[dict[str, str]] = []
    for document_id in normalized_ids:
        status_obj = statuses.get(document_id)
        if self.is_query_ready_status(status_obj):
            succeeded_document_ids.append(document_id)
            continue
        failed_documents.append(
            {
                "document_id": document_id,
                "status": self._status_value(status_obj) or "unknown",
                "error": self._status_error(status_obj),
            }
        )

    return {
        "track_id": track_id,
        "requested_document_ids": normalized_ids,
        "succeeded_document_ids": succeeded_document_ids,
        "failed_documents": failed_documents,
        "status_snapshot": {
            document_id: self._serialize_status(status_obj)
            for document_id, status_obj in statuses.items()
        },
    }
def get_document_status_map(self, document_ids: list[str] | None = None) -> dict[str, dict[str, Any]]:
    """Return serialized index statuses keyed by document id.

    Returns an empty dict when no ids are supplied or when the runtime
    lookup fails; the failure is logged as a warning.
    """
    stripped_ids = (str(item).strip() for item in document_ids or [])
    target_ids = [text for text in stripped_ids if text]
    if not target_ids:
        return {}
    try:
        statuses = self._get_runtime().get_document_statuses(target_ids)
    except Exception as exc:
        logger.warning("Load LightRAG document statuses failed: %s", exc)
        return {}
    serialized: dict[str, dict[str, Any]] = {}
    for document_id, status_obj in statuses.items():
        serialized[document_id] = self._serialize_status(status_obj)
    return serialized
def delete_document(self, document_id: str) -> None:
    """Remove a document from the index on a best-effort basis.

    Blank ids are ignored. Any failure is logged as a warning and swallowed.
    """
    normalized_id = str(document_id or "").strip()
    if normalized_id:
        try:
            self._get_runtime().delete_document(normalized_id)
        except Exception as exc:
            logger.warning("Delete LightRAG document ignored doc_id=%s: %s", normalized_id, exc)
def _get_runtime(self) -> _LightRagRuntime:
    """Return the process-wide LightRAG runtime, rebuilding it on config change.

    The runtime is cached in module globals together with the signature it
    was built from. When the signature differs, the old runtime is finalized
    and a new one is built.

    Raises:
        KnowledgeRagError: If the model configuration is incomplete or the
            runtime cannot be constructed.
    """
    global _runtime_instance, _runtime_signature
    signature, runtime_kwargs = self._build_runtime_signature()
    with _runtime_lock:
        if _runtime_instance is not None and _runtime_signature == signature:
            return _runtime_instance
        if _runtime_instance is not None:
            # Drop the cached reference BEFORE finalizing. If building the
            # replacement fails below, a finalized runtime must not stay
            # cached where a later call could hand it out again.
            stale_runtime = _runtime_instance
            _runtime_instance = None
            _runtime_signature = None
            try:
                stale_runtime.finalize()
            except Exception as exc:  # pragma: no cover - best effort cleanup
                logger.warning("Finalize previous LightRAG runtime failed: %s", exc)
        fresh_runtime = _LightRagRuntime(**runtime_kwargs)
        _runtime_instance = fresh_runtime
        _runtime_signature = signature
        return fresh_runtime
def _build_runtime_signature(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
    """Compute the runtime cache key and the matching constructor kwargs.

    Returns:
        ``(signature, runtime_kwargs)``. The signature is a tuple of the
        working directory, workspace, Qdrant connection, the provider, model,
        endpoint and API key of every model slot (empty strings for absent
        optional slots), and the configured storage root.
    """

    def slot_identity(config: RuntimeModelConfig | None) -> tuple[str, str, str, str]:
        # Absent optional slots contribute four empty strings.
        if not config:
            return ("", "", "", "")
        return (config.provider, config.model, config.endpoint, config.api_key)

    configs = self._load_runtime_configs()
    settings = get_settings()
    working_dir = (self.storage_root / "knowledge" / ".lightrag").resolve()
    workspace = (
        os.environ.get("LIGHTRAG_WORKSPACE", DEFAULT_LIGHTRAG_WORKSPACE).strip()
        or DEFAULT_LIGHTRAG_WORKSPACE
    )
    qdrant_url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url()
    qdrant_api_key = os.environ.get("QDRANT_API_KEY", "").strip()

    main = configs["main"]
    embedding = configs["embedding"]
    signature = (
        str(working_dir),
        workspace,
        qdrant_url,
        qdrant_api_key,
        main.provider,
        main.model,
        main.endpoint,
        main.api_key,
        *slot_identity(configs["backup"]),
        embedding.provider,
        embedding.model,
        embedding.endpoint,
        embedding.api_key,
        *slot_identity(configs["reranker"]),
        str(settings.resolved_storage_root_dir),
    )
    runtime_kwargs = {
        "working_dir": working_dir,
        "workspace": workspace,
        "qdrant_url": qdrant_url,
        "qdrant_api_key": qdrant_api_key,
        "primary_chat": main,
        "backup_chat": configs["backup"],
        "embedding": embedding,
        "reranker": configs["reranker"],
    }
    return signature, runtime_kwargs
def _load_runtime_configs(self) -> dict[str, RuntimeModelConfig | None]:
    """Load and validate the model slots needed to build the runtime.

    "main" and "embedding" are required. "backup" and "reranker" are
    optional and resolve to ``None`` when they cannot be loaded or are
    incomplete. A session is opened, and closed afterwards, when the
    service was created without one.

    Returns:
        A dict with the keys ``main``, ``backup``, ``embedding`` and
        ``reranker``.

    Raises:
        KnowledgeRagError: If a required slot lacks an endpoint or model, or
            lacks an API key for a provider other than "Ollama".
    """

    def is_complete(config: RuntimeModelConfig) -> bool:
        # "Ollama" is the only provider allowed to run without an API key.
        return bool(
            config.endpoint
            and config.model
            and (config.provider == "Ollama" or config.api_key)
        )

    def load_optional(settings_service: SettingsService, slot: str) -> RuntimeModelConfig | None:
        # Optional slots never abort initialization, but the reason is logged
        # instead of being silently discarded.
        try:
            config = self._normalize_runtime_model(settings_service.get_runtime_model_config(slot))
        except Exception as exc:
            logger.debug("Optional model slot %s is unavailable: %s", slot, exc)
            return None
        return config if is_complete(config) else None

    owned_session = False
    session = self.db
    if session is None:
        session = get_session_factory()()
        owned_session = True
    try:
        settings_service = SettingsService(session)
        main = self._normalize_runtime_model(settings_service.get_runtime_model_config("main"))
        embedding = self._normalize_runtime_model(settings_service.get_runtime_model_config("embedding"))
        backup = load_optional(settings_service, "backup")
        reranker = load_optional(settings_service, "reranker")

        if not main.endpoint or not main.model:
            raise KnowledgeRagError("主对话模型未配置,无法初始化 LightRAG。")
        if main.provider != "Ollama" and not main.api_key:
            raise KnowledgeRagError("主对话模型缺少 API Key无法初始化 LightRAG。")
        if not embedding.endpoint or not embedding.model:
            raise KnowledgeRagError("Embedding 模型未配置,无法初始化 LightRAG。")
        if embedding.provider != "Ollama" and not embedding.api_key:
            raise KnowledgeRagError("Embedding 模型缺少 API Key无法初始化 LightRAG。")

        return {
            "main": main,
            "backup": backup,
            "embedding": embedding,
            "reranker": reranker,
        }
    finally:
        # Only close sessions this method opened itself.
        if owned_session and session is not None:
            session.close()
@staticmethod
def _normalize_runtime_model(payload: dict[str, str]) -> RuntimeModelConfig:
    """Convert a settings payload into a ``RuntimeModelConfig``.

    Every field is coerced to a stripped string; missing or falsy values
    become ``""``. The API key is read from the ``apiKey`` key.
    """

    def text(key: str) -> str:
        return str(payload.get(key) or "").strip()

    return RuntimeModelConfig(
        slot=text("slot"),
        provider=text("provider"),
        model=text("model"),
        endpoint=text("endpoint"),
        api_key=text("apiKey"),
        capability=text("capability"),
    )
@staticmethod
def _build_hits_from_query_data(
    *,
    query: str,
    chunks: list[dict[str, Any]],
    entities: list[dict[str, Any]],
    limit: int,
) -> list[dict[str, Any]]:
    """Turn LightRAG chunks and entities into ranked search hits.

    Each usable chunk (a dict with a non-empty ``file_path`` and
    ``content``) becomes one hit. Entity names are attached as tags to
    hits from the same ``file_path``. Hits are ordered by
    ``_score_knowledge_hit`` descending, with ties going to the chunk
    LightRAG ranked earlier.

    Args:
        query: The normalized user query.
        chunks: Raw chunk dicts from LightRAG.
        entities: Raw entity dicts from LightRAG.
        limit: Maximum number of hits; values below 1 are treated as 1.

    Returns:
        The top hits as plain dicts.
    """
    # Collect distinct entity names per file, preserving first-seen order.
    entity_tags_by_path: dict[str, list[str]] = {}
    for entity in entities:
        if not isinstance(entity, dict):
            continue
        entity_path = str(entity.get("file_path") or "").strip()
        entity_name = str(entity.get("entity_name") or "").strip()
        if not (entity_path and entity_name):
            continue
        tags = entity_tags_by_path.setdefault(entity_path, [])
        if entity_name not in tags:
            tags.append(entity_name)

    query_terms = _extract_query_terms(query)
    prefers_tabular_evidence = any(hint in query for hint in TABLE_OR_STANDARD_QUERY_HINTS)

    candidates: list[dict[str, Any]] = []
    for rank, chunk in enumerate(chunks, start=1):
        if not isinstance(chunk, dict):
            continue
        file_path = str(chunk.get("file_path") or "").strip()
        content = str(chunk.get("content") or "").strip()
        if not (file_path and content):
            continue
        document_id, document_name = _parse_document_identity(file_path)
        # Chunks without an id get a positional placeholder.
        normalized_chunk_id = str(chunk.get("chunk_id") or "").strip() or f"path-{rank}"
        normalized_content = _truncate_text(content, max_length=MAX_KNOWLEDGE_HIT_CONTENT_LENGTH)
        candidates.append(
            {
                "code": f"knowledge.{document_id or 'unknown'}.{normalized_chunk_id}",
                "candidate_id": normalized_chunk_id,
                "title": document_name or "知识库文档",
                "content": normalized_content,
                "excerpt": _build_query_focused_excerpt(
                    normalized_content,
                    query_terms=query_terms,
                    max_length=MAX_KNOWLEDGE_HIT_EXCERPT_LENGTH,
                ),
                "document_id": document_id,
                "document_name": document_name or Path(file_path).name,
                "version": None,
                "updated_at": None,
                "score": max(1, 100 - rank),
                "tags": entity_tags_by_path.get(file_path, [])[:5],
                "evidence": [normalized_chunk_id],
                "file_path": file_path,
                "_rank": rank,
            }
        )

    def ordering(item: dict[str, Any]) -> tuple[Any, int]:
        # The negated rank makes earlier chunks win ties under reverse sort.
        relevance = _score_knowledge_hit(
            item,
            query_terms=query_terms,
            prefers_tabular_evidence=prefers_tabular_evidence,
        )
        return relevance, -int(item.get("_rank") or 0)

    ranked = sorted(candidates, key=ordering, reverse=True)
    # "_rank" is an internal sort aid and is not exposed to callers.
    return [
        {key: value for key, value in item.items() if key != "_rank"}
        for item in ranked[: max(1, limit)]
    ]
@staticmethod
def _serialize_status(status_obj: Any) -> dict[str, Any]:
    """Convert a LightRAG status object or dict into a plain dict.

    The result is a copy of the object's ``__dict__`` (or of the dict
    itself), with ``status``, ``error_msg`` and ``query_ready`` overwritten
    by their normalized values. ``None`` yields an empty dict.
    """
    if status_obj is None:
        return {}
    if hasattr(status_obj, "__dict__"):
        payload = dict(vars(status_obj))
    elif isinstance(status_obj, dict):
        payload = dict(status_obj)
    else:
        payload = {}
    payload.update(
        status=KnowledgeRagService._status_value(status_obj),
        error_msg=KnowledgeRagService._status_error(status_obj),
        query_ready=KnowledgeRagService.is_query_ready_status(status_obj),
    )
    return payload
@staticmethod
def _status_value(status_obj: Any) -> str:
raw_status = getattr(status_obj, "status", None)
if raw_status is None and isinstance(status_obj, dict):
raw_status = status_obj.get("status")
normalized = str(raw_status or "").strip().lower()
if "." in normalized:
normalized = normalized.split(".")[-1].strip()
if ":" in normalized and normalized.endswith(">"):
normalized = normalized.split(":")[0].strip("<> '\"")
return normalized
@staticmethod
def _status_error(status_obj: Any) -> str:
value = getattr(status_obj, "error_msg", None)
if value is None and isinstance(status_obj, dict):
value = status_obj.get("error_msg")
return str(value or "").strip()
@staticmethod
def is_query_ready_status(status_obj: Any) -> bool:
    """Decide whether a document status means the document can be queried.

    ``processed`` is ready. Failure and in-progress states are not ready.
    For any other status (including a missing one) the document counts as
    ready when it reports a positive ``chunks_count`` or a non-empty
    ``chunks_list``.
    """

    def read(field: str) -> Any:
        # Works for both attribute-style objects and plain dicts.
        value = getattr(status_obj, field, None)
        if value is None and isinstance(status_obj, dict):
            value = status_obj.get(field)
        return value

    status_text = KnowledgeRagService._status_value(status_obj)
    if status_text == "processed":
        return True
    if status_text in {"failed", "error", "aborted", "pending", "processing", "preprocessed"}:
        return False

    try:
        if int(read("chunks_count") or 0) > 0:
            return True
    except (TypeError, ValueError):
        # A non-numeric count falls through to the chunk-list check.
        pass
    return bool(read("chunks_list"))
def shutdown_knowledge_rag_runtime() -> None:
    """Finalize and discard the process-wide LightRAG runtime, if one exists.

    A finalization failure is logged as a warning; the cached instance and
    its signature are cleared either way.
    """
    global _runtime_instance, _runtime_signature
    with _runtime_lock:
        active_runtime = _runtime_instance
        if active_runtime is None:
            return
        try:
            active_runtime.finalize()
        except Exception as exc:  # pragma: no cover - best effort cleanup
            logger.warning("Finalize LightRAG runtime failed during shutdown: %s", exc)
        _runtime_instance = None
        _runtime_signature = None
def _normalize_endpoint(endpoint: str) -> str:
    """Return ``endpoint`` stripped of surrounding whitespace and trailing slashes.

    Raises:
        KnowledgeRagError: If the endpoint is empty or blank.
    """
    cleaned = str(endpoint or "").strip()
    if cleaned:
        return cleaned.rstrip("/")
    raise KnowledgeRagError("模型 endpoint 不能为空。")
def _ensure_path(endpoint: str, suffix: str) -> str:
    """Return ``endpoint`` guaranteed to end with the path ``suffix``.

    The suffix is matched on a path-segment boundary, so an endpoint such as
    ``.../text-embeddings`` is not mistaken for one that already ends in
    ``embeddings``. Trailing slashes on the endpoint are ignored.

    Args:
        endpoint: Base URL, with or without the suffix already present.
        suffix: Path to ensure, e.g. ``"chat/completions"``; surrounding
            slashes are ignored.

    Returns:
        The endpoint unchanged (minus trailing slashes) when it already ends
        with the suffix, otherwise ``endpoint + "/" + suffix``.
    """
    base = endpoint.rstrip("/")
    suffix = suffix.strip("/")
    if not suffix:
        return base
    if base == suffix or base.endswith(f"/{suffix}"):
        return base
    return f"{base}/{suffix}"
def _build_azure_deployment_base(endpoint: str, model: str) -> str:
    """Return the Azure OpenAI deployment base URL for ``model``.

    An endpoint that already contains ``/openai/deployments/`` is returned
    as-is. Otherwise ``/openai/deployments/<model>`` is derived from the
    resource root, with the model name percent-encoded.

    Raises:
        KnowledgeRagError: If the endpoint is empty.
    """
    base = _normalize_endpoint(endpoint)
    if "/openai/deployments/" in base:
        return base
    deployment = quote(model, safe="")
    if "/openai/v1" in base:
        # Discard the v1 path and rebuild from the resource root.
        resource_root = base.partition("/openai/v1")[0]
        return f"{resource_root}/openai/deployments/{deployment}"
    if base.endswith("/openai"):
        return f"{base}/deployments/{deployment}"
    return f"{base}/openai/deployments/{deployment}"
def _build_headers(
    api_key: str,
    *,
    use_bearer: bool,
    use_api_key: bool = False,
) -> dict[str, str]:
    """Build JSON request headers with an optional credential.

    A non-blank key is sent as ``api-key`` when ``use_api_key`` is set,
    otherwise as ``Authorization: Bearer ...`` when ``use_bearer`` is set.
    ``use_api_key`` takes precedence when both are set.
    """
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    credential = str(api_key or "").strip()
    if not credential:
        return headers
    if use_api_key:
        headers["api-key"] = credential
    elif use_bearer:
        headers["Authorization"] = f"Bearer {credential}"
    return headers
def _send_json_request(
    method: str,
    url: str,
    *,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_seconds: int,
) -> tuple[int, Any]:
    """Send a JSON request with ``urllib`` and return ``(status, parsed_body)``.

    Args:
        method: HTTP method.
        url: Target URL.
        headers: Request headers.
        payload: JSON-serializable request body.
        timeout_seconds: Timeout passed to ``urlopen``.

    Raises:
        KnowledgeRagError: On an HTTP error response (using the server's
            error message when one can be extracted), a connection failure,
            or a timeout.
    """
    encoded = json.dumps(payload).encode("utf-8")
    request = Request(url=url, data=encoded, headers=headers, method=method)
    try:
        with urlopen(request, timeout=timeout_seconds) as response:  # noqa: S310
            # Skip the read when the response reports a zero-length body.
            text = "" if response.length == 0 else response.read().decode("utf-8")
            return response.status, _parse_json_body(text)
    except HTTPError as exc:  # pragma: no cover - runtime path
        error_text = exc.read().decode("utf-8", errors="ignore")
        detail = _extract_error_message(_parse_json_body(error_text)) or f"接口返回 {exc.code}"
        raise KnowledgeRagError(detail) from exc
    except URLError as exc:  # pragma: no cover - runtime path
        raise KnowledgeRagError(f"无法连接模型接口:{getattr(exc, 'reason', exc)}") from exc
    except TimeoutError as exc:  # pragma: no cover - runtime path
        raise KnowledgeRagError("模型接口调用超时。") from exc
def _parse_json_body(body: str) -> Any:
    """Parse a response body as JSON.

    Returns ``None`` for an empty body. Text that is not valid JSON is
    wrapped as ``{"message": body}``.
    """
    if body:
        try:
            return json.loads(body)
        except json.JSONDecodeError:
            return {"message": body}
    return None
def _extract_error_message(payload: Any) -> str | None:
    """Pull a human-readable error message out of a parsed error body.

    For dicts, the first string found among ``detail``, ``message`` and
    ``error.message`` is returned. A string payload is returned as-is.
    Anything else yields ``None``.
    """
    if isinstance(payload, str):
        return payload
    if not isinstance(payload, dict):
        return None
    for key in ("detail", "message"):
        candidate = payload.get(key)
        if isinstance(candidate, str):
            return candidate
    nested = payload.get("error")
    if isinstance(nested, dict):
        nested_message = nested.get("message")
        if isinstance(nested_message, str):
            return nested_message
    return None
def _extract_chat_text(payload: Any, *, provider: str) -> str:
    """Extract the assistant text from a chat-completion response body.

    "Ollama" responses are read from ``message.content``. Every other
    provider is read from the first entry of ``choices``: its
    ``message.content`` (a string, or a list of typed parts whose ``text``
    parts are joined with newlines), then its ``text`` field. Unrecognized
    shapes yield ``""``.
    """
    if not isinstance(payload, dict):
        return ""

    if provider == "Ollama":
        ollama_message = payload.get("message")
        if not isinstance(ollama_message, dict):
            return ""
        return str(ollama_message.get("content") or "").strip()

    choices = payload.get("choices")
    first_choice = choices[0] if isinstance(choices, list) and choices else None
    if not isinstance(first_choice, dict):
        return ""

    message = first_choice.get("message")
    content = message.get("content") if isinstance(message, dict) else None
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        # Multi-part content: keep only the non-empty text parts.
        text_parts = (
            str(part.get("text") or "").strip()
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
        return "\n".join(part for part in text_parts if part).strip()

    # Fallback: a ``text`` field directly on the choice.
    fallback_text = first_choice.get("text")
    return fallback_text.strip() if isinstance(fallback_text, str) else ""
def _extract_embedding_vectors(payload: Any, *, provider: str) -> list[list[float]]:
    """Extract embedding vectors from an embedding response body.

    "Ollama" responses are read from ``embeddings`` (a batch) or
    ``embedding`` (a single vector). Every other provider is read from
    ``data[*].embedding``. Entries that are not list-shaped are skipped.

    Raises:
        KnowledgeRagError: If the response shape is not recognized or holds
            no vectors.
    """
    body = payload if isinstance(payload, dict) else None

    if provider == "Ollama":
        batch = body.get("embeddings") if body is not None else None
        if isinstance(batch, list):
            return [[float(value) for value in row] for row in batch if isinstance(row, list)]
        single = body.get("embedding") if body is not None else None
        if isinstance(single, list):
            return [[float(value) for value in single]]
        raise KnowledgeRagError("Ollama embedding 返回格式无法识别。")

    if body is None:
        raise KnowledgeRagError("embedding 接口返回格式无效。")
    data = body.get("data")
    if not isinstance(data, list) or not data:
        raise KnowledgeRagError("embedding 接口没有返回 data。")

    vectors: list[list[float]] = [
        [float(value) for value in item["embedding"]]
        for item in data
        if isinstance(item, dict) and isinstance(item.get("embedding"), list)
    ]
    if not vectors:
        raise KnowledgeRagError("embedding 接口返回中未找到向量数据。")
    return vectors
def _build_ali_rerank_request(
    model: str,
    *,
    query: str,
    documents: list[str],
    top_n: int | None,
) -> tuple[str, dict[str, Any]]:
    """Build the DashScope rerank URL and payload for ``model``.

    ``qwen3-rerank`` uses the compatible-API endpoint with a flat payload.
    Every other model uses the text-rerank service endpoint with nested
    ``input`` and ``parameters`` sections.

    Returns:
        ``(url, payload)``.
    """
    normalized_model = str(model or "").strip()

    if normalized_model != "qwen3-rerank":
        parameters: dict[str, Any] = {"return_documents": False}
        if top_n is not None:
            parameters["top_n"] = top_n
        nested_payload: dict[str, Any] = {
            "model": normalized_model,
            "input": {
                "query": query,
                "documents": documents,
            },
            "parameters": parameters,
        }
        return (
            "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
            nested_payload,
        )

    flat_payload: dict[str, Any] = {
        "model": normalized_model,
        "query": query,
        "documents": documents,
    }
    if top_n is not None:
        flat_payload["top_n"] = top_n
    return "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", flat_payload
def _extract_rerank_results(payload: Any, *, provider: str) -> list[dict[str, Any]]:
if not isinstance(payload, dict):
return []
if provider == "Ali" and isinstance(payload.get("output"), dict):
results = payload["output"].get("results")
else:
results = payload.get("results")
if not isinstance(results, list):
return []
normalized: list[dict[str, Any]] = []
for item in results:
if not isinstance(item, dict):
continue
try:
normalized.append(
{
"index": int(item["index"]),
"relevance_score": float(item["relevance_score"]),
}
)
except (KeyError, TypeError, ValueError):
continue
return normalized
def _parse_document_identity(file_path: str) -> tuple[str, str]:
path = Path(str(file_path or "").strip())
name = path.name
if "__" not in name:
return "", name
document_id, document_name = name.split("__", maxsplit=1)
return document_id.strip(), document_name.strip()
def _build_excerpt(text: str, *, max_length: int = 180) -> str:
normalized = " ".join(str(text or "").split()).strip()
if len(normalized) <= max_length:
return normalized
return f"{normalized[: max_length - 3].rstrip()}..."
def _build_query_focused_excerpt(
text: str,
*,
query_terms: list[str],
max_length: int = 180,
) -> str:
normalized = " ".join(str(text or "").split()).strip()
if not normalized:
return ""
lowered = normalized.lower()
match_positions = [
lowered.find(term)
for term in query_terms
if term and lowered.find(term) >= 0
]
if not match_positions:
return _build_excerpt(normalized, max_length=max_length)
start = max(0, min(match_positions) - max_length // 3)
end = min(len(normalized), start + max_length)
snippet = normalized[start:end].strip()
if start > 0:
snippet = f"...{snippet.lstrip()}"
if end < len(normalized):
snippet = f"{snippet.rstrip()}..."
return snippet
def _truncate_text(text: str, *, max_length: int) -> str:
normalized = str(text or "").strip()
if len(normalized) <= max_length:
return normalized
return f"{normalized[: max_length - 3].rstrip()}..."
def _resolve_default_qdrant_url() -> str:
    """Pick the default Qdrant URL based on whether host ``qdrant`` resolves.

    Returns ``CONTAINER_QDRANT_URL`` when the ``qdrant`` hostname can be
    resolved, and ``DEFAULT_QDRANT_URL`` otherwise.
    """
    return CONTAINER_QDRANT_URL if _hostname_resolves("qdrant") else DEFAULT_QDRANT_URL
def _hostname_resolves(hostname: str) -> bool:
try:
socket.getaddrinfo(hostname, None)
except OSError:
return False
return True
def _extract_query_terms(query: str) -> list[str]:
normalized_query = str(query or "").strip().lower()
if not normalized_query:
return []
terms: list[str] = []
seen: set[str] = set()
def remember(term: str) -> None:
normalized_term = str(term or "").strip().lower()
if (
not normalized_term
or normalized_term in seen
or normalized_term in QUERY_TERM_STOPWORDS
or len(normalized_term) < 2
):
return
seen.add(normalized_term)
terms.append(normalized_term)
for item in re.findall(r"[a-z0-9][a-z0-9_\-]{1,}", normalized_query):
remember(item)
for block in re.findall(r"[\u4e00-\u9fff]{2,20}", normalized_query):
if len(block) <= 4:
remember(block)
continue
for size in (4, 3, 2):
for start in range(0, len(block) - size + 1):
remember(block[start : start + size])
if len(terms) >= MAX_QUERY_TERMS:
return terms
return terms[:MAX_QUERY_TERMS]
def _score_knowledge_hit(
    item: dict[str, Any],
    *,
    query_terms: list[str],
    prefers_tabular_evidence: bool,
) -> int:
    """Compute a heuristic integer relevance score for one knowledge hit.

    The score starts from a rank-based baseline and is adjusted by query-term
    matches, by which structured-appendix marker (if any) leads the content,
    and by a handful of content-shape signals.

    Args:
        item: Hit dict; the keys read here are ``_rank``, ``title``,
            ``document_name``, ``content``, ``excerpt`` and ``tags``.
        query_terms: Terms compared as-is against lowercased hit text, so they
            presumably need to be lowercase already (verify against caller).
        prefers_tabular_evidence: Whether table-like evidence should be
            favoured for this query.

    Returns:
        The score; higher means more relevant. It may be negative.
    """
    rank = max(1, int(item.get("_rank") or 1))
    title = str(item.get("title") or item.get("document_name") or "").lower()
    content = str(item.get("content") or "").lower()
    excerpt = str(item.get("excerpt") or "").lower()
    # Only the first five tags take part in matching.
    tags = " ".join(str(value).lower() for value in list(item.get("tags") or [])[:5])
    # Term matching looks at the title, excerpt, tags and the first 1200
    # characters of the content.
    haystack = "\n".join([title, excerpt, tags, content[:1200]])
    # Rank baseline: 116 for rank 1, minus 4 per further rank, floored at 1.
    score = max(1, 120 - rank * 4)
    matched_terms = [term for term in query_terms if term in haystack]
    score += len(matched_terms) * 8
    # Extra weight for every matched term that also occurs in the title.
    score += sum(1 for term in matched_terms if term in title) * 6
    leading_appendix_marker = _leading_structured_appendix_marker(content)
    # NOTE(review): the nesting below was reconstructed from a copy without
    # indentation; the trailing `else` is assumed to belong to the inner
    # if/elif chain of the table-appendix branch — confirm against the repo.
    if leading_appendix_marker == "# 章节导航":
        score -= 24
    elif leading_appendix_marker == "# 重点章节摘录":
        score += 4 if matched_terms else -12
    elif leading_appendix_marker == "# 问答线索补充":
        # +8 with matches and no tabular preference, +2 with matches and a
        # tabular preference, -20 without any match.
        score += 8 if matched_terms and not prefers_tabular_evidence else 2 if matched_terms else -20
    elif leading_appendix_marker == "# 结构化表格补充":
        if prefers_tabular_evidence and matched_terms:
            score += 16
        elif matched_terms:
            score += 6
        else:
            score -= 18
    # NOTE(review): several marker literals below are empty strings, and
    # `"" in content` is always True, so each marker check that contains one
    # always passes. They look like characters lost in extraction/encoding —
    # confirm the intended markers against the repository copy.
    if prefers_tabular_evidence and matched_terms and ("|" in content or "" in content):
        score += 10
    if matched_terms and any(marker in content for marker in ("", ":")):
        score += 10
    if matched_terms and "\n" in content:
        score += 4
    if matched_terms and any(marker in content for marker in ("附表", "", "")):
        score += 4
    if not prefers_tabular_evidence and matched_terms and any(marker in content for marker in ("", "", "", "-", "")):
        score += 4
    # Flat bonus when any query term occurs in the title; this is added on top
    # of the per-term title bonus above.
    if title and any(term in title for term in query_terms):
        score += 6
    # Penalize content matching a "没有…信息/规定/说明/依据" phrase.
    if re.search(r"没有.{0,8}(信息|规定|说明|依据)", content):
        score -= 12
    return score
def _leading_structured_appendix_marker(content: str) -> str:
    """Return the structured-appendix marker found near the start of ``content``.

    Leading whitespace is ignored. Markers are tried in the order of
    ``STRUCTURED_APPENDIX_LEADING_MARKERS``; the first one whose first
    occurrence starts at or before ``STRUCTURED_APPENDIX_LEADING_WINDOW`` is
    returned. An empty string means no marker qualified.
    """
    body = str(content or "").lstrip()
    return next(
        (
            candidate
            for candidate in STRUCTURED_APPENDIX_LEADING_MARKERS
            if 0 <= body.find(candidate) <= STRUCTURED_APPENDIX_LEADING_WINDOW
        ),
        "",
    )