1262 lines
46 KiB
Python
1262 lines
46 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
import socket
|
|||
|
|
import threading
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from datetime import UTC, datetime
|
|||
|
|
from functools import partial
|
|||
|
|
from http import HTTPStatus
|
|||
|
|
from pathlib import Path
|
|||
|
|
from time import perf_counter
|
|||
|
|
from typing import Any
|
|||
|
|
from urllib.error import HTTPError, URLError
|
|||
|
|
from urllib.parse import quote
|
|||
|
|
from urllib.request import Request, urlopen
|
|||
|
|
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.core.config import get_settings
|
|||
|
|
from app.core.logging import get_logger
|
|||
|
|
from app.db.session import get_session_factory
|
|||
|
|
from app.services.settings import SettingsService
|
|||
|
|
|
|||
|
|
# Module-level logger shared by every class and helper in this file.
logger = get_logger("app.services.knowledge_rag")

# Two Qdrant base URLs: a loopback address and a "qdrant" hostname variant.
# NOTE(review): presumably selected by _resolve_default_qdrant_url() when the
# QDRANT_URL environment variable is unset — confirm against that helper.
DEFAULT_QDRANT_URL = "http://127.0.0.1:6333"
CONTAINER_QDRANT_URL = "http://qdrant:6333"

# Fallbacks for the LIGHTRAG_WORKSPACE / LIGHTRAG_QUERY_MODE environment variables.
DEFAULT_LIGHTRAG_WORKSPACE = "x_financial_knowledge"
DEFAULT_LIGHTRAG_QUERY_MODE = "naive"

# Timeouts passed as ``timeout_seconds`` to the JSON request helper:
# the LLM value is used for chat and rerank calls, the other for embedding calls.
DEFAULT_LLM_TIMEOUT_SECONDS = 180
DEFAULT_EMBEDDING_TIMEOUT_SECONDS = 120

# Maximum length handed to _truncate_text() for the content of one knowledge hit.
MAX_KNOWLEDGE_HIT_CONTENT_LENGTH = 2200

# NOTE(review): not referenced in the visible part of the file; presumably caps
# the number of terms produced by _extract_query_terms() — confirm.
MAX_QUERY_TERMS = 12

# Chinese filler / question words.
# NOTE(review): presumably filtered out by _extract_query_terms() — confirm.
QUERY_TERM_STOPWORDS = {
    "什么",  # what
    "多少",  # how much / how many
    "哪些",  # which (plural)
    "怎么",  # how
    "如何",  # how
    "请问",  # "may I ask"
    "一下",  # "a bit" (softener)
    "关于",  # about
    "规定",  # regulation
    "标准",  # standard
    "可以",  # can / may
    "是否",  # whether
    "一个",  # one / a
    "哪些人",  # which people
}

# Substrings that, when present in a query, set ``prefers_tabular_evidence``
# during hit ranking (see KnowledgeRagService._build_hits_from_query_data).
TABLE_OR_STANDARD_QUERY_HINTS = (
    "标准",  # standard
    "金额",  # amount
    "限额",  # limit
    "补贴",  # subsidy / allowance
    "住宿",  # lodging
    "餐费",  # meal cost
    "交通",  # transport
    "报销",  # reimbursement
    "档位",  # tier
    "额度",  # quota
)

# Process-wide runtime cache used by KnowledgeRagService._get_runtime():
# the lock guards both the cached instance and the signature it was built from.
_runtime_lock = threading.RLock()
_runtime_instance: _LightRagRuntime | None = None
_runtime_signature: tuple[Any, ...] | None = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class KnowledgeRagError(RuntimeError):
    """Error raised when the LightRAG knowledge runtime cannot be built or used."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(frozen=True, slots=True)
|
|||
|
|
class RuntimeModelConfig:
|
|||
|
|
slot: str
|
|||
|
|
provider: str
|
|||
|
|
model: str
|
|||
|
|
endpoint: str
|
|||
|
|
api_key: str
|
|||
|
|
capability: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _LightRagRuntime:
    """Synchronous facade over one configured LightRAG instance.

    Construction builds the LightRAG object (JSON KV / doc-status storage,
    NetworkX graph storage, Qdrant vector storage) and initializes its
    storages. Chat, embedding and rerank traffic is sent through this
    module's JSON request helpers; LightRAG's async callbacks reach them via
    ``asyncio.to_thread``.
    """

    def __init__(
        self,
        *,
        working_dir: Path,
        workspace: str,
        qdrant_url: str,
        qdrant_api_key: str,
        primary_chat: RuntimeModelConfig,
        backup_chat: RuntimeModelConfig | None,
        embedding: RuntimeModelConfig,
        reranker: RuntimeModelConfig | None,
    ) -> None:
        """Store the configuration, then build and initialize LightRAG.

        Args:
            working_dir: Directory LightRAG persists its files in.
            workspace: LightRAG workspace name.
            qdrant_url: Qdrant base URL (exported to ``QDRANT_URL`` when set).
            qdrant_api_key: Qdrant API key (exported to ``QDRANT_API_KEY`` when set).
            primary_chat: Chat model tried first for completions.
            backup_chat: Optional chat model tried when the primary fails.
            embedding: Embedding model configuration.
            reranker: Optional reranker; when ``None`` no rerank callback is registered.

        Raises:
            KnowledgeRagError: If LightRAG is not installed or the embedding
                dimension probe fails.
        """
        self.working_dir = working_dir
        self.workspace = workspace
        self.qdrant_url = qdrant_url
        self.qdrant_api_key = qdrant_api_key
        self.primary_chat = primary_chat
        self.backup_chat = backup_chat
        self.embedding = embedding
        self.reranker = reranker
        # Tri-state cache: None = unknown, True/False = last observed graph state.
        self._graph_has_content_cache: bool | None = None
        self._rag = self._build_rag()
        self._initialize()

    @property
    def rag(self):
        """Return the underlying LightRAG instance."""
        return self._rag

    @staticmethod
    def _run_sync(awaitable):
        """Run *awaitable* to completion on LightRAG's event loop and return its result."""
        from lightrag.utils import always_get_an_event_loop

        return always_get_an_event_loop().run_until_complete(awaitable)

    def _build_rag(self):
        """Create the LightRAG object wired to this runtime's model callbacks.

        Side effects: creates ``working_dir``, writes ``QDRANT_URL`` /
        ``QDRANT_API_KEY`` into ``os.environ`` when they are non-empty, and
        issues one embedding request to discover the vector dimension.

        Raises:
            KnowledgeRagError: If ``lightrag`` cannot be imported or the
                embedding probe does not yield a usable vector.
        """
        try:
            from lightrag import LightRAG
            from lightrag.utils import EmbeddingFunc
        except ImportError as exc:  # pragma: no cover - exercised in runtime env
            raise KnowledgeRagError(
                "LightRAG 依赖未安装,请先在 server 环境执行依赖安装。"
            ) from exc

        self.working_dir.mkdir(parents=True, exist_ok=True)

        # The connection settings are handed over through the process environment.
        if self.qdrant_url:
            os.environ["QDRANT_URL"] = self.qdrant_url
        if self.qdrant_api_key:
            os.environ["QDRANT_API_KEY"] = self.qdrant_api_key

        embedding_dim = self._probe_embedding_dimension(self.embedding)
        logger.info(
            "Initialize LightRAG runtime workspace=%s qdrant=%s embedding_model=%s dim=%s",
            self.workspace,
            self.qdrant_url,
            self.embedding.model,
            embedding_dim,
        )

        async def embedding_func(texts: list[str]) -> Any:
            return await asyncio.to_thread(self._embed_sync, texts)

        async def llm_model_func(
            prompt: str,
            system_prompt: str | None = None,
            history_messages: list[dict[str, Any]] | None = None,
            keyword_extraction: bool = False,
            **kwargs: Any,
        ) -> str:
            return await asyncio.to_thread(
                self._complete_sync,
                prompt,
                system_prompt,
                history_messages or [],
                keyword_extraction,
                kwargs,
            )

        async def rerank_model_func(
            query: str,
            documents: list[str],
            top_n: int | None = None,
            **_kwargs: Any,
        ) -> list[dict[str, Any]]:
            return await asyncio.to_thread(
                self._rerank_sync,
                query,
                documents,
                top_n,
            )

        return LightRAG(
            working_dir=str(self.working_dir),
            workspace=self.workspace,
            kv_storage="JsonKVStorage",
            graph_storage="NetworkXStorage",
            vector_storage="QdrantVectorDBStorage",
            doc_status_storage="JsonDocStatusStorage",
            llm_model_name=self.primary_chat.model,
            llm_model_func=llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dim,
                func=embedding_func,
                max_token_size=8192,
                model_name=self.embedding.model,
                supports_asymmetric=False,
            ),
            rerank_model_func=rerank_model_func if self.reranker is not None else None,
            enable_llm_cache=False,
            enable_llm_cache_for_entity_extract=False,
        )

    def _initialize(self) -> None:
        """Initialize LightRAG's storages (blocking)."""
        self._run_sync(self._rag.initialize_storages())

    def finalize(self) -> None:
        """Finalize LightRAG's storages (blocking)."""
        self._run_sync(self._rag.finalize_storages())

    @staticmethod
    def _build_query_param(mode: str, conversation_history: list[dict[str, str]] | None):
        """Build the context-only ``QueryParam`` used for every retrieval call."""
        from lightrag import QueryParam

        return QueryParam(
            mode=mode,
            top_k=8,
            chunk_top_k=10,
            only_need_context=True,
            response_type="Multiple Paragraphs",
            conversation_history=conversation_history or [],
            include_references=True,
        )

    def query_data(self, query: str, *, conversation_history: list[dict[str, str]] | None = None) -> dict[str, Any]:
        """Retrieve context for *query* and return LightRAG's raw result.

        The mode comes from ``LIGHTRAG_QUERY_MODE`` (default
        ``DEFAULT_LIGHTRAG_QUERY_MODE``). A non-naive mode is downgraded to
        ``"naive"`` when the graph file has no nodes or edges, and a failing
        non-naive query is retried once in naive mode.

        Raises:
            Exception: Whatever LightRAG raises when the naive query itself fails.
        """
        configured_mode = (
            os.environ.get("LIGHTRAG_QUERY_MODE", DEFAULT_LIGHTRAG_QUERY_MODE).strip()
            or DEFAULT_LIGHTRAG_QUERY_MODE
        )
        mode = configured_mode
        if configured_mode != "naive" and not self._graph_has_content():
            mode = "naive"

        started_at = perf_counter()
        param = self._build_query_param(mode, conversation_history)
        try:
            result = self._rag.query_data(query, param)
        except Exception as exc:
            if mode == "naive":
                raise
            # Include the cause so a silent fallback can still be diagnosed.
            logger.warning("LightRAG query mode=%s failed, retry with naive mode: %s", mode, exc)
            mode = "naive"
            result = self._rag.query_data(query, self._build_query_param(mode, conversation_history))

        logger.info("LightRAG query completed mode=%s elapsed=%.2fs", mode, perf_counter() - started_at)
        return result

    def _graph_has_content(self) -> bool:
        """Report whether the workspace graph file contains any node or edge.

        The answer is cached until ``insert_documents`` or ``delete_document``
        resets it. A missing, unreadable or undecodable file counts as empty.
        The file is scanned line by line so a large graph is never loaded
        into memory in one piece.
        """
        if self._graph_has_content_cache is not None:
            return self._graph_has_content_cache

        graph_path = self.working_dir / self.workspace / "graph_chunk_entity_relation.graphml"
        try:
            with graph_path.open(encoding="utf-8") as handle:
                has_content = any("<node " in line or "<edge " in line for line in handle)
        except (OSError, UnicodeDecodeError):
            has_content = False

        self._graph_has_content_cache = has_content
        return has_content

    def insert_documents(
        self,
        *,
        texts: list[str],
        document_ids: list[str],
        file_paths: list[str],
    ) -> str:
        """Insert documents into LightRAG and return its tracking id.

        The graph-content cache is reset afterwards (even on failure, since a
        partial insert may already have written graph data) so the next query
        re-reads the graph file instead of trusting a stale "empty" answer.
        """
        try:
            return self._rag.insert(texts, ids=document_ids, file_paths=file_paths)
        finally:
            self._graph_has_content_cache = None

    def get_document_statuses(self, document_ids: list[str]) -> dict[str, Any]:
        """Return LightRAG's status objects for *document_ids* (blocking)."""
        return self._run_sync(self._rag.aget_docs_by_ids(document_ids))

    def delete_document(self, document_id: str) -> None:
        """Delete one document from LightRAG.

        Raises:
            KnowledgeRagError: If the reported status is neither ``"success"``
                nor ``"not_found"``.
        """
        try:
            result = self._run_sync(self._rag.adelete_by_doc_id(document_id))
        finally:
            # Deletion may change the graph; force a re-check on the next query.
            self._graph_has_content_cache = None

        status = str(getattr(result, "status", "") or "")
        if status not in {"success", "not_found"}:
            raise KnowledgeRagError(str(getattr(result, "message", "") or "LightRAG 删除文档失败。"))

    def _probe_embedding_dimension(self, config: RuntimeModelConfig) -> int:
        """Embed a fixed probe string and return the length of the resulting vector.

        Raises:
            KnowledgeRagError: If no list-shaped vector comes back or it is empty.
        """
        vectors = self._request_embeddings(config, ["dimension probe"])
        if not vectors or not isinstance(vectors[0], list):
            raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。")
        dimension = len(vectors[0])
        if dimension <= 0:
            raise KnowledgeRagError("embedding 模型返回了无效的向量维度。")
        return dimension

    def _embed_sync(self, texts: list[str]) -> Any:
        """Embed *texts* with the configured embedding model and return a float ndarray."""
        import numpy as np

        vectors = self._request_embeddings(self.embedding, texts)
        return np.array(vectors, dtype=float)

    def _rerank_sync(
        self,
        query: str,
        documents: list[str],
        top_n: int | None,
    ) -> list[dict[str, Any]]:
        """Rerank *documents* for *query*; returns ``[]`` when no reranker is configured.

        Raises:
            KnowledgeRagError: If the reranker responds with an HTTP status >= 400.
        """
        if self.reranker is None:
            return []

        status_code, body = self._request_rerank(
            self.reranker,
            query=query,
            documents=documents,
            top_n=top_n,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise KnowledgeRagError(f"reranker 模型返回异常状态码 {status_code}。")
        return _extract_rerank_results(body, provider=self.reranker.provider)

    def _complete_sync(
        self,
        prompt: str,
        system_prompt: str | None,
        history_messages: list[dict[str, Any]],
        keyword_extraction: bool,
        kwargs: dict[str, Any],
    ) -> str:
        """Run a chat completion, trying the primary model and then the backup.

        ``kwargs`` may carry ``max_tokens`` (default 1200; a falsy value also
        selects the default) and ``temperature`` (default 0.1). An explicit
        temperature of ``0`` is forwarded as-is; only a missing or empty value
        selects the default.

        Raises:
            KnowledgeRagError: If every configured model fails; chained to the
                last underlying error.
        """
        del keyword_extraction  # part of the callback signature; not used here

        raw_temperature = kwargs.get("temperature")
        temperature_unset = raw_temperature is None or raw_temperature == ""

        last_error: Exception | None = None
        for config in (self.primary_chat, self.backup_chat):
            if config is None:
                continue
            try:
                # Conversions stay inside the try so a malformed value is
                # reported the same way as any other request failure.
                return self._request_chat_completion(
                    config,
                    prompt=prompt,
                    system_prompt=system_prompt,
                    history_messages=history_messages,
                    max_tokens=int(kwargs.get("max_tokens") or 1200),
                    temperature=0.1 if temperature_unset else float(raw_temperature),
                )
            except Exception as exc:  # pragma: no cover - runtime fallback
                last_error = exc
                logger.warning(
                    "LightRAG LLM request failed slot=%s provider=%s model=%s: %s",
                    config.slot,
                    config.provider,
                    config.model,
                    exc,
                )

        raise KnowledgeRagError(
            f"LightRAG 调用知识模型失败:{last_error or '没有可用模型配置'}"
        ) from last_error

    def _request_chat_completion(
        self,
        config: RuntimeModelConfig,
        *,
        prompt: str,
        system_prompt: str | None,
        history_messages: list[dict[str, Any]],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """Send one chat request in the provider's request shape and return the reply text.

        "Azure OpenAI" uses a deployment URL with an api-key header, "Ollama"
        uses ``api/chat`` without credentials, and every other provider uses
        ``chat/completions`` with a bearer token.

        Raises:
            KnowledgeRagError: If the provider responds with an HTTP status >= 400.
        """
        messages: list[dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.extend(history_messages)
        messages.append({"role": "user", "content": prompt})

        if config.provider == "Azure OpenAI":
            url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/chat/completions?api-version={AZURE_API_VERSION}"
            payload = {
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            }
            headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
        elif config.provider == "Ollama":
            url = _ensure_path(_normalize_endpoint(config.endpoint), "api/chat")
            payload = {
                "model": config.model,
                "messages": messages,
                "stream": False,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature,
                },
            }
            headers = {"Content-Type": "application/json", "Accept": "application/json"}
        else:
            url = _ensure_path(_normalize_endpoint(config.endpoint), "chat/completions")
            payload = {
                "model": config.model,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            }
            headers = _build_headers(config.api_key, use_bearer=True)

        status_code, body = _send_json_request(
            "POST",
            url,
            headers=headers,
            payload=payload,
            timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise KnowledgeRagError(f"知识模型返回异常状态码 {status_code}。")

        return _extract_chat_text(body, provider=config.provider)

    def _request_embeddings(self, config: RuntimeModelConfig, texts: list[str]) -> list[list[float]]:
        """Request embeddings for *texts* in the provider's request shape.

        Raises:
            KnowledgeRagError: If the provider responds with an HTTP status >= 400.
        """
        if config.provider == "Azure OpenAI":
            url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}"
            payload = {"input": texts}
            headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
        elif config.provider == "Ollama":
            url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed")
            payload = {"model": config.model, "input": texts}
            headers = {"Content-Type": "application/json", "Accept": "application/json"}
        else:
            url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings")
            payload = {"model": config.model, "input": texts}
            headers = _build_headers(config.api_key, use_bearer=True)

        status_code, body = _send_json_request(
            "POST",
            url,
            headers=headers,
            payload=payload,
            timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。")

        return _extract_embedding_vectors(body, provider=config.provider)

    def _request_rerank(
        self,
        config: RuntimeModelConfig,
        *,
        query: str,
        documents: list[str],
        top_n: int | None,
    ) -> tuple[int, Any]:
        """Send one rerank request and return ``(status_code, body)`` unchanged.

        "Azure OpenAI" and "Ali" have dedicated request shapes; every other
        provider posts to ``<endpoint>/rerank`` with a bearer token. ``top_n``
        is only included in the payload when it is not ``None`` (for "Ali" the
        request builder decides).
        """
        if config.provider == "Azure OpenAI":
            url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/rerank?api-version={AZURE_API_VERSION}"
            payload: dict[str, Any] = {
                "query": query,
                "documents": documents,
            }
            if top_n is not None:
                payload["top_n"] = top_n
            headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
        elif config.provider == "Ali":
            url, payload = _build_ali_rerank_request(
                config.model,
                query=query,
                documents=documents,
                top_n=top_n,
            )
            headers = _build_headers(config.api_key, use_bearer=True)
        else:
            url = _ensure_path(_normalize_endpoint(config.endpoint), "rerank")
            payload = {
                "model": config.model,
                "query": query,
                "documents": documents,
            }
            if top_n is not None:
                payload["top_n"] = top_n
            headers = _build_headers(config.api_key, use_bearer=True)

        return _send_json_request(
            "POST",
            url,
            headers=headers,
            payload=payload,
            timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS,
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class KnowledgeRagService:
|
|||
|
|
def __init__(self, db: Session | None = None, storage_root: Path | None = None) -> None:
|
|||
|
|
self.db = db
|
|||
|
|
self.storage_root = Path(storage_root or get_settings().resolved_storage_root_dir)
|
|||
|
|
|
|||
|
|
def query_knowledge(
    self,
    query: str,
    *,
    conversation_history: list[dict[str, str]] | None = None,
    limit: int = 5,
) -> dict[str, Any]:
    """Search the knowledge base and return a ``knowledge_search`` result dict.

    Never raises for retrieval problems: a blank query, a runtime failure or
    an empty hit list each produce a result with ``record_count`` 0 and an
    explanatory ``message``.

    Args:
        query: User question; stripped before use.
        conversation_history: Optional prior turns forwarded to the runtime.
        limit: Maximum number of hits to return.
    """

    def _no_results(query_text: str, message: str, **extra: Any) -> dict[str, Any]:
        # Shared shape of every zero-hit answer; extras sit before "message".
        return {
            "result_type": "knowledge_search",
            "query": query_text,
            "record_count": 0,
            "hits": [],
            "references": [],
            **extra,
            "message": message,
        }

    normalized_query = str(query or "").strip()
    if not normalized_query:
        return _no_results("", "请先输入要检索的知识库问题。")

    try:
        raw = self._get_runtime().query_data(normalized_query, conversation_history=conversation_history)
    except Exception as exc:
        logger.warning("Knowledge query failed: %s", exc)
        return _no_results(normalized_query, f"知识库检索暂不可用:{exc}")

    raw_is_dict = isinstance(raw, dict)
    payload = raw.get("data") if raw_is_dict else {}
    if not isinstance(payload, dict):
        payload = {}
    chunks = list(payload.get("chunks") or [])
    entities = list(payload.get("entities") or [])
    references = list(payload.get("references") or [])

    hits = self._build_hits_from_query_data(
        query=normalized_query,
        chunks=chunks,
        entities=entities,
        limit=limit,
    )
    if not hits:
        return _no_results(
            normalized_query,
            "当前知识库中没有检索到与本次问题直接匹配的内容。",
            raw_references=references,
        )

    codes = (str(item.get("code") or "").strip() for item in hits)
    return {
        "result_type": "knowledge_search",
        "query": normalized_query,
        "record_count": len(hits),
        "hits": hits,
        "references": [code for code in codes if code],
        "raw_references": references,
        "metadata": raw.get("metadata") if raw_is_dict else {},
        "message": f"已从知识库中检索到 {len(hits)} 条相关内容。",
    }
|
|||
|
|
|
|||
|
|
def index_documents(
    self,
    *,
    document_ids: list[str],
    force: bool = False,
) -> dict[str, Any]:
    """Extract, optionally enrich, and index knowledge documents into LightRAG.

    Ids are stripped, blanks dropped and duplicates removed (first occurrence
    wins) so a repeated id cannot make the batch insert fail on LightRAG's
    unique-id requirement.

    Args:
        document_ids: Knowledge document ids to index.
        force: When true, a document that already has a LightRAG status is
            deleted first; a failed delete is logged and indexing continues.

    Returns:
        Dict with ``track_id``, ``requested_document_ids``,
        ``succeeded_document_ids``, ``failed_documents`` and ``status_snapshot``.

    Raises:
        ValueError: If no usable document id remains after normalization.
    """
    stripped_ids = (str(item).strip() for item in document_ids)
    normalized_ids = list(dict.fromkeys(doc_id for doc_id in stripped_ids if doc_id))
    if not normalized_ids:
        raise ValueError("没有可供索引的知识文档。")

    # Imported lazily, as in the rest of this module's service wiring.
    from app.services.knowledge import KnowledgeService
    from app.services.knowledge_normalizer import KnowledgeNormalizationService

    knowledge_service = KnowledgeService(storage_root=self.storage_root, db=self.db)
    normalization_service = (
        KnowledgeNormalizationService(self.db) if self.db is not None else None
    )
    texts: list[str] = []
    file_paths: list[str] = []

    runtime = self._get_runtime()
    existing_statuses = runtime.get_document_statuses(normalized_ids) or {}

    for document_id in normalized_ids:
        entry = knowledge_service.get_document_entry(document_id)
        if force and document_id in existing_statuses:
            try:
                runtime.delete_document(document_id)
            except Exception as exc:
                # Best effort: a stale entry must not block re-indexing.
                logger.warning("Delete existing LightRAG document failed doc_id=%s: %s", document_id, exc)
        text = knowledge_service.extract_document_text(document_id)
        if normalization_service is not None:
            text = normalization_service.build_enriched_text(text)
        texts.append(text)
        file_paths.append(str((knowledge_service.library_root / entry["folder"] / entry["stored_name"]).resolve()))

    track_id = runtime.insert_documents(
        texts=texts,
        document_ids=normalized_ids,
        file_paths=file_paths,
    )

    statuses = runtime.get_document_statuses(normalized_ids) or {}
    succeeded_document_ids: list[str] = []
    failed_documents: list[dict[str, str]] = []

    for document_id in normalized_ids:
        status_obj = statuses.get(document_id)
        if self.is_query_ready_status(status_obj):
            succeeded_document_ids.append(document_id)
            continue
        failed_documents.append(
            {
                "document_id": document_id,
                "status": self._status_value(status_obj) or "unknown",
                "error": self._status_error(status_obj),
            }
        )

    return {
        "track_id": track_id,
        "requested_document_ids": normalized_ids,
        "succeeded_document_ids": succeeded_document_ids,
        "failed_documents": failed_documents,
        "status_snapshot": {
            document_id: self._serialize_status(status_obj)
            for document_id, status_obj in statuses.items()
        },
    }
|
|||
|
|
|
|||
|
|
def get_document_status_map(self, document_ids: list[str] | None = None) -> dict[str, dict[str, Any]]:
    """Return serialized LightRAG statuses keyed by document id.

    Blank ids are ignored. An empty request, or any failure while loading
    statuses (logged as a warning), yields ``{}``.
    """
    target_ids = []
    for item in document_ids or []:
        cleaned = str(item).strip()
        if cleaned:
            target_ids.append(cleaned)
    if not target_ids:
        return {}

    try:
        statuses = self._get_runtime().get_document_statuses(target_ids)
    except Exception as exc:
        logger.warning("Load LightRAG document statuses failed: %s", exc)
        return {}

    serialized: dict[str, dict[str, Any]] = {}
    for document_id, status_obj in statuses.items():
        serialized[document_id] = self._serialize_status(status_obj)
    return serialized
|
|||
|
|
|
|||
|
|
def delete_document(self, document_id: str) -> None:
    """Delete a document from LightRAG on a best-effort basis.

    A blank id is a no-op; any failure is logged as a warning and swallowed.
    """
    normalized_id = str(document_id or "").strip()
    if normalized_id:
        try:
            runtime = self._get_runtime()
            runtime.delete_document(normalized_id)
        except Exception as exc:
            logger.warning("Delete LightRAG document ignored doc_id=%s: %s", normalized_id, exc)
|
|||
|
|
|
|||
|
|
def _get_runtime(self) -> _LightRagRuntime:
    """Return the process-wide LightRAG runtime, rebuilding it when config changes.

    The runtime is cached in module globals together with the signature it
    was built from. When the current signature differs, the old runtime is
    finalized (best effort) and a new one is constructed.

    The cache is cleared *before* the replacement is built: if construction
    raises, later calls retry from scratch instead of handing out a runtime
    that has already been finalized.

    Raises:
        KnowledgeRagError: Propagated from configuration loading or runtime
            construction.
    """
    global _runtime_instance, _runtime_signature

    signature, runtime_kwargs = self._build_runtime_signature()
    with _runtime_lock:
        if _runtime_instance is not None and _runtime_signature == signature:
            return _runtime_instance

        if _runtime_instance is not None:
            stale_runtime = _runtime_instance
            _runtime_instance = None
            _runtime_signature = None
            try:
                stale_runtime.finalize()
            except Exception as exc:  # pragma: no cover - best effort cleanup
                logger.warning("Finalize previous LightRAG runtime failed: %s", exc)

        runtime = _LightRagRuntime(**runtime_kwargs)
        _runtime_instance = runtime
        _runtime_signature = signature
        return runtime
|
|||
|
|
|
|||
|
|
def _build_runtime_signature(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
    """Compute the runtime cache key and the matching constructor kwargs.

    Returns:
        ``(signature, kwargs)`` where ``signature`` is a flat tuple of every
        value that influences the runtime (paths, workspace, Qdrant settings
        and the provider/model/endpoint/api_key of each model slot) and
        ``kwargs`` can be passed to ``_LightRagRuntime``.
    """

    def _fields(config: Any) -> tuple[Any, ...]:
        return (config.provider, config.model, config.endpoint, config.api_key)

    def _optional_fields(config: Any) -> tuple[Any, ...]:
        # Absent optional slots contribute four empty strings.
        return _fields(config) if config else ("", "", "", "")

    configs = self._load_runtime_configs()
    settings = get_settings()
    working_dir = (self.storage_root / "knowledge" / ".lightrag").resolve()
    workspace = (
        os.environ.get("LIGHTRAG_WORKSPACE", DEFAULT_LIGHTRAG_WORKSPACE).strip()
        or DEFAULT_LIGHTRAG_WORKSPACE
    )
    qdrant_url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url()
    qdrant_api_key = os.environ.get("QDRANT_API_KEY", "").strip()

    main_config = configs["main"]
    backup_config = configs["backup"]
    embedding_config = configs["embedding"]
    reranker_config = configs["reranker"]

    signature = (
        str(working_dir),
        workspace,
        qdrant_url,
        qdrant_api_key,
        *_fields(main_config),
        *_optional_fields(backup_config),
        *_fields(embedding_config),
        *_optional_fields(reranker_config),
        str(settings.resolved_storage_root_dir),
    )

    runtime_kwargs = {
        "working_dir": working_dir,
        "workspace": workspace,
        "qdrant_url": qdrant_url,
        "qdrant_api_key": qdrant_api_key,
        "primary_chat": main_config,
        "backup_chat": backup_config,
        "embedding": embedding_config,
        "reranker": reranker_config,
    }
    return signature, runtime_kwargs
|
|||
|
|
|
|||
|
|
def _load_runtime_configs(self) -> dict[str, RuntimeModelConfig | None]:
    """Load and validate the model configuration for every slot.

    ``main`` and ``embedding`` are mandatory. ``backup`` and ``reranker`` are
    optional: a load error or an incomplete configuration turns them into
    ``None``. A slot is complete when it has an endpoint, a model and — unless
    the provider is "Ollama" — an API key.

    A session is opened (and closed again) only when the service was created
    without one.

    Raises:
        KnowledgeRagError: If the main or embedding slot is incomplete.
    """
    session = self.db
    owned_session = session is None
    if owned_session:
        session = get_session_factory()()

    try:
        settings_service = SettingsService(session)

        def _load(slot: str) -> RuntimeModelConfig:
            return self._normalize_runtime_model(settings_service.get_runtime_model_config(slot))

        def _is_complete(config: RuntimeModelConfig) -> bool:
            has_credentials = config.provider == "Ollama" or bool(config.api_key)
            return bool(config.endpoint) and bool(config.model) and has_credentials

        def _load_optional(slot: str) -> RuntimeModelConfig | None:
            try:
                config = _load(slot)
            except Exception:
                return None
            return config if _is_complete(config) else None

        main = _load("main")
        embedding = _load("embedding")
        backup = _load_optional("backup")
        reranker = _load_optional("reranker")

        if not (main.endpoint and main.model):
            raise KnowledgeRagError("主对话模型未配置,无法初始化 LightRAG。")
        if not main.api_key and main.provider != "Ollama":
            raise KnowledgeRagError("主对话模型缺少 API Key,无法初始化 LightRAG。")
        if not (embedding.endpoint and embedding.model):
            raise KnowledgeRagError("Embedding 模型未配置,无法初始化 LightRAG。")
        if not embedding.api_key and embedding.provider != "Ollama":
            raise KnowledgeRagError("Embedding 模型缺少 API Key,无法初始化 LightRAG。")

        return {
            "main": main,
            "backup": backup,
            "embedding": embedding,
            "reranker": reranker,
        }
    finally:
        if owned_session and session is not None:
            session.close()
|
|||
|
|
|
|||
|
|
@staticmethod
def _normalize_runtime_model(payload: dict[str, str]) -> RuntimeModelConfig:
    """Convert a settings payload into a ``RuntimeModelConfig``.

    Every field is read with ``payload.get``, coerced to ``str`` and stripped;
    missing or falsy values become ``""``. The API key is read from the
    ``"apiKey"`` key.
    """

    def _text(key: str) -> str:
        return str(payload.get(key) or "").strip()

    return RuntimeModelConfig(
        slot=_text("slot"),
        provider=_text("provider"),
        model=_text("model"),
        endpoint=_text("endpoint"),
        api_key=_text("apiKey"),
        capability=_text("capability"),
    )
|
|||
|
|
|
|||
|
|
@staticmethod
def _build_hits_from_query_data(
    *,
    query: str,
    chunks: list[dict[str, Any]],
    entities: list[dict[str, Any]],
    limit: int,
) -> list[dict[str, Any]]:
    """Turn LightRAG chunks and entities into ranked knowledge hits.

    Entities contribute up to five de-duplicated tags per file path. Chunks
    without a file path or content are dropped. The remaining candidates are
    ordered by ``_score_knowledge_hit`` (descending), with the original
    retrieval rank as tie-breaker, and the first ``max(1, limit)`` are returned.
    """
    entity_tags_by_path: dict[str, list[str]] = {}
    for entity in entities:
        if not isinstance(entity, dict):
            continue
        entity_path = str(entity.get("file_path") or "").strip()
        entity_name = str(entity.get("entity_name") or "").strip()
        if not (entity_path and entity_name):
            continue
        tags = entity_tags_by_path.setdefault(entity_path, [])
        if entity_name not in tags:
            tags.append(entity_name)

    query_terms = _extract_query_terms(query)
    prefers_tabular_evidence = any(hint in query for hint in TABLE_OR_STANDARD_QUERY_HINTS)

    candidates: list[dict[str, Any]] = []
    for rank, chunk in enumerate(chunks, start=1):
        if not isinstance(chunk, dict):
            continue
        file_path = str(chunk.get("file_path") or "").strip()
        chunk_id = str(chunk.get("chunk_id") or "").strip()
        content = str(chunk.get("content") or "").strip()
        if not (file_path and content):
            continue

        document_id, document_name = _parse_document_identity(file_path)
        # Chunks without an id get a positional placeholder.
        normalized_chunk_id = chunk_id or f"path-{rank}"
        normalized_content = _truncate_text(content, max_length=MAX_KNOWLEDGE_HIT_CONTENT_LENGTH)
        excerpt = _build_excerpt(normalized_content, max_length=220)

        candidate: dict[str, Any] = {
            "code": f"knowledge.{document_id or 'unknown'}.{normalized_chunk_id}",
            "candidate_id": normalized_chunk_id,
            "title": document_name or "知识库文档",
            "content": normalized_content,
            "excerpt": excerpt,
            "document_id": document_id,
            "document_name": document_name or Path(file_path).name,
            "version": None,
            "updated_at": None,
            "score": max(1, 100 - rank),
            "tags": entity_tags_by_path.get(file_path, [])[:5],
            "evidence": [normalized_chunk_id],
            "file_path": file_path,
            "_rank": rank,  # internal sort helper, stripped before returning
        }
        candidates.append(candidate)

    def _sort_key(candidate: dict[str, Any]) -> tuple[Any, int]:
        relevance = _score_knowledge_hit(
            candidate,
            query_terms=query_terms,
            prefers_tabular_evidence=prefers_tabular_evidence,
        )
        return (relevance, -int(candidate.get("_rank") or 0))

    ranked = sorted(candidates, key=_sort_key, reverse=True)
    return [
        {key: value for key, value in candidate.items() if key != "_rank"}
        for candidate in ranked[: max(1, limit)]
    ]
|
|||
|
|
|
|||
|
|
@staticmethod
def _serialize_status(status_obj: Any) -> dict[str, Any]:
    """Convert a document status object or mapping into a plain dict.

    The result starts as a shallow copy of the object's ``__dict__`` (or of
    the mapping itself) and is then overlaid with three derived keys:
    ``status``, ``error_msg`` and ``query_ready``.

    Args:
        status_obj: Status object, mapping, or ``None``.

    Returns:
        An empty dict for ``None``; otherwise the copied fields plus the
        derived keys.
    """
    if status_obj is None:
        return {}

    serialized: dict[str, Any] = {}
    if hasattr(status_obj, "__dict__"):
        serialized = dict(status_obj.__dict__)
    elif isinstance(status_obj, dict):
        serialized = dict(status_obj)

    # Derived fields always win over whatever the raw object carried.
    serialized.update(
        status=KnowledgeRagService._status_value(status_obj),
        error_msg=KnowledgeRagService._status_error(status_obj),
        query_ready=KnowledgeRagService.is_query_ready_status(status_obj),
    )
    return serialized
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _status_value(status_obj: Any) -> str:
|
|||
|
|
raw_status = getattr(status_obj, "status", None)
|
|||
|
|
if raw_status is None and isinstance(status_obj, dict):
|
|||
|
|
raw_status = status_obj.get("status")
|
|||
|
|
normalized = str(raw_status or "").strip().lower()
|
|||
|
|
if "." in normalized:
|
|||
|
|
normalized = normalized.split(".")[-1].strip()
|
|||
|
|
if ":" in normalized and normalized.endswith(">"):
|
|||
|
|
normalized = normalized.split(":")[0].strip("<> '\"")
|
|||
|
|
return normalized
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _status_error(status_obj: Any) -> str:
|
|||
|
|
value = getattr(status_obj, "error_msg", None)
|
|||
|
|
if value is None and isinstance(status_obj, dict):
|
|||
|
|
value = status_obj.get("error_msg")
|
|||
|
|
return str(value or "").strip()
|
|||
|
|
|
|||
|
|
@staticmethod
def is_query_ready_status(status_obj: Any) -> bool:
    """Tell whether a document status indicates that it can be queried.

    A document counts as ready when any of these hold, checked in order:
    its normalized status is ``"processed"``; its ``chunks_count`` converts
    to a positive integer; its ``chunks_list`` is truthy. Each field is read
    as an attribute first and as a dict key second.

    Args:
        status_obj: Status object, mapping, or ``None``.

    Returns:
        ``True`` when the document is considered query-ready.
    """
    if KnowledgeRagService._status_value(status_obj) == "processed":
        return True

    is_mapping = isinstance(status_obj, dict)

    count = getattr(status_obj, "chunks_count", None)
    if count is None and is_mapping:
        count = status_obj.get("chunks_count")
    try:
        has_chunks = int(count or 0) > 0
    except (TypeError, ValueError):
        # A count that is not number-like is treated as "no chunks".
        has_chunks = False
    if has_chunks:
        return True

    chunk_ids = getattr(status_obj, "chunks_list", None)
    if chunk_ids is None and is_mapping:
        chunk_ids = status_obj.get("chunks_list")
    return bool(chunk_ids)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def shutdown_knowledge_rag_runtime() -> None:
    """Finalize and drop the module-level LightRAG runtime, if one exists.

    Runs under ``_runtime_lock``. A failure inside ``finalize()`` is logged
    as a warning and does not prevent the cached instance and its signature
    from being cleared.
    """
    global _runtime_instance, _runtime_signature

    with _runtime_lock:
        runtime = _runtime_instance
        if runtime is not None:
            try:
                runtime.finalize()
            except Exception as exc:  # pragma: no cover - best effort cleanup
                logger.warning("Finalize LightRAG runtime failed during shutdown: %s", exc)
            _runtime_instance = None
            _runtime_signature = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _normalize_endpoint(endpoint: str) -> str:
|
|||
|
|
normalized = str(endpoint or "").strip()
|
|||
|
|
if not normalized:
|
|||
|
|
raise KnowledgeRagError("模型 endpoint 不能为空。")
|
|||
|
|
return normalized.rstrip("/")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _ensure_path(endpoint: str, suffix: str) -> str:
|
|||
|
|
suffix = suffix.lstrip("/")
|
|||
|
|
if endpoint.endswith(suffix):
|
|||
|
|
return endpoint
|
|||
|
|
return f"{endpoint}/{suffix}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_azure_deployment_base(endpoint: str, model: str) -> str:
    """Build the Azure OpenAI deployment base URL for ``model``.

    An endpoint that already contains ``/openai/deployments/`` is returned
    as normalized. Otherwise everything from ``/openai/v1`` onward, or a
    trailing ``/openai``, is dropped and
    ``/openai/deployments/<url-quoted model>`` is appended.

    Args:
        endpoint: Azure resource or API endpoint.
        model: Deployment name; fully percent-encoded in the result.

    Returns:
        The deployment base URL without a trailing slash.

    Raises:
        KnowledgeRagError: Propagated from ``_normalize_endpoint`` for an
            empty endpoint.
    """
    base = _normalize_endpoint(endpoint)
    deployment = quote(model, safe="")

    if "/openai/deployments/" in base:
        return base

    if "/openai/v1" in base:
        resource_root = base.partition("/openai/v1")[0]
    else:
        # Covers both ".../openai" and a bare resource root.
        resource_root = base.removesuffix("/openai")
    return f"{resource_root}/openai/deployments/{deployment}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_headers(
|
|||
|
|
api_key: str,
|
|||
|
|
*,
|
|||
|
|
use_bearer: bool,
|
|||
|
|
use_api_key: bool = False,
|
|||
|
|
) -> dict[str, str]:
|
|||
|
|
headers = {
|
|||
|
|
"Content-Type": "application/json",
|
|||
|
|
"Accept": "application/json",
|
|||
|
|
}
|
|||
|
|
normalized_key = str(api_key or "").strip()
|
|||
|
|
if normalized_key:
|
|||
|
|
if use_api_key:
|
|||
|
|
headers["api-key"] = normalized_key
|
|||
|
|
elif use_bearer:
|
|||
|
|
headers["Authorization"] = f"Bearer {normalized_key}"
|
|||
|
|
return headers
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _send_json_request(
    method: str,
    url: str,
    *,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_seconds: int,
) -> tuple[int, Any]:
    """Send a JSON-encoded request and return the status with the parsed body.

    Args:
        method: HTTP method.
        url: Target URL.
        headers: Request headers.
        payload: Object serialized with ``json.dumps`` as the request body.
        timeout_seconds: Timeout passed to ``urlopen``.

    Returns:
        ``(status, body)`` where ``body`` is the result of
        ``_parse_json_body`` on the response text.

    Raises:
        KnowledgeRagError: On an HTTP error response, a connection failure,
            or a timeout.
    """
    encoded = json.dumps(payload).encode("utf-8")
    request = Request(url=url, data=encoded, headers=headers, method=method)

    try:
        with urlopen(request, timeout=timeout_seconds) as response:  # noqa: S310
            # A declared zero-length body is not read at all.
            raw_text = "" if response.length == 0 else response.read().decode("utf-8")
            return response.status, _parse_json_body(raw_text)
    except HTTPError as exc:  # pragma: no cover - runtime path
        error_text = exc.read().decode("utf-8", errors="ignore")
        detail = _extract_error_message(_parse_json_body(error_text))
        if not detail:
            detail = f"接口返回 {exc.code}"
        raise KnowledgeRagError(detail) from exc
    except URLError as exc:  # pragma: no cover - runtime path
        raise KnowledgeRagError(f"无法连接模型接口:{getattr(exc, 'reason', exc)}") from exc
    except TimeoutError as exc:  # pragma: no cover - runtime path
        raise KnowledgeRagError("模型接口调用超时。") from exc
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_json_body(body: str) -> Any:
|
|||
|
|
if not body:
|
|||
|
|
return None
|
|||
|
|
try:
|
|||
|
|
return json.loads(body)
|
|||
|
|
except json.JSONDecodeError:
|
|||
|
|
return {"message": body}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_error_message(payload: Any) -> str | None:
|
|||
|
|
if payload is None:
|
|||
|
|
return None
|
|||
|
|
if isinstance(payload, dict):
|
|||
|
|
if isinstance(payload.get("detail"), str):
|
|||
|
|
return payload["detail"]
|
|||
|
|
if isinstance(payload.get("message"), str):
|
|||
|
|
return payload["message"]
|
|||
|
|
error_payload = payload.get("error")
|
|||
|
|
if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str):
|
|||
|
|
return error_payload["message"]
|
|||
|
|
if isinstance(payload, str):
|
|||
|
|
return payload
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_chat_text(payload: Any, *, provider: str) -> str:
|
|||
|
|
if provider == "Ollama":
|
|||
|
|
message = payload.get("message") if isinstance(payload, dict) else None
|
|||
|
|
if isinstance(message, dict):
|
|||
|
|
return str(message.get("content") or "").strip()
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
if not isinstance(payload, dict):
|
|||
|
|
return ""
|
|||
|
|
choices = payload.get("choices")
|
|||
|
|
if not isinstance(choices, list) or not choices:
|
|||
|
|
return ""
|
|||
|
|
first_choice = choices[0]
|
|||
|
|
if not isinstance(first_choice, dict):
|
|||
|
|
return ""
|
|||
|
|
message = first_choice.get("message")
|
|||
|
|
if isinstance(message, dict):
|
|||
|
|
content = message.get("content")
|
|||
|
|
if isinstance(content, str):
|
|||
|
|
return content.strip()
|
|||
|
|
if isinstance(content, list):
|
|||
|
|
parts: list[str] = []
|
|||
|
|
for item in content:
|
|||
|
|
if isinstance(item, dict) and item.get("type") == "text":
|
|||
|
|
parts.append(str(item.get("text") or "").strip())
|
|||
|
|
return "\n".join(part for part in parts if part).strip()
|
|||
|
|
text = first_choice.get("text")
|
|||
|
|
if isinstance(text, str):
|
|||
|
|
return text.strip()
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_embedding_vectors(payload: Any, *, provider: str) -> list[list[float]]:
|
|||
|
|
if provider == "Ollama":
|
|||
|
|
embeddings = payload.get("embeddings") if isinstance(payload, dict) else None
|
|||
|
|
if isinstance(embeddings, list):
|
|||
|
|
return [[float(value) for value in item] for item in embeddings if isinstance(item, list)]
|
|||
|
|
embedding = payload.get("embedding") if isinstance(payload, dict) else None
|
|||
|
|
if isinstance(embedding, list):
|
|||
|
|
return [[float(value) for value in embedding]]
|
|||
|
|
raise KnowledgeRagError("Ollama embedding 返回格式无法识别。")
|
|||
|
|
|
|||
|
|
if not isinstance(payload, dict):
|
|||
|
|
raise KnowledgeRagError("embedding 接口返回格式无效。")
|
|||
|
|
data = payload.get("data")
|
|||
|
|
if not isinstance(data, list) or not data:
|
|||
|
|
raise KnowledgeRagError("embedding 接口没有返回 data。")
|
|||
|
|
vectors: list[list[float]] = []
|
|||
|
|
for item in data:
|
|||
|
|
if not isinstance(item, dict):
|
|||
|
|
continue
|
|||
|
|
embedding = item.get("embedding")
|
|||
|
|
if isinstance(embedding, list):
|
|||
|
|
vectors.append([float(value) for value in embedding])
|
|||
|
|
if not vectors:
|
|||
|
|
raise KnowledgeRagError("embedding 接口返回中未找到向量数据。")
|
|||
|
|
return vectors
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_ali_rerank_request(
|
|||
|
|
model: str,
|
|||
|
|
*,
|
|||
|
|
query: str,
|
|||
|
|
documents: list[str],
|
|||
|
|
top_n: int | None,
|
|||
|
|
) -> tuple[str, dict[str, Any]]:
|
|||
|
|
normalized_model = str(model or "").strip()
|
|||
|
|
if normalized_model == "qwen3-rerank":
|
|||
|
|
payload: dict[str, Any] = {
|
|||
|
|
"model": normalized_model,
|
|||
|
|
"query": query,
|
|||
|
|
"documents": documents,
|
|||
|
|
}
|
|||
|
|
if top_n is not None:
|
|||
|
|
payload["top_n"] = top_n
|
|||
|
|
return "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", payload
|
|||
|
|
|
|||
|
|
payload = {
|
|||
|
|
"model": normalized_model,
|
|||
|
|
"input": {
|
|||
|
|
"query": query,
|
|||
|
|
"documents": documents,
|
|||
|
|
},
|
|||
|
|
"parameters": {
|
|||
|
|
"return_documents": False,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
if top_n is not None:
|
|||
|
|
payload["parameters"]["top_n"] = top_n
|
|||
|
|
return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank", payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_rerank_results(payload: Any, *, provider: str) -> list[dict[str, Any]]:
|
|||
|
|
if not isinstance(payload, dict):
|
|||
|
|
return []
|
|||
|
|
if provider == "Ali" and isinstance(payload.get("output"), dict):
|
|||
|
|
results = payload["output"].get("results")
|
|||
|
|
else:
|
|||
|
|
results = payload.get("results")
|
|||
|
|
if not isinstance(results, list):
|
|||
|
|
return []
|
|||
|
|
normalized: list[dict[str, Any]] = []
|
|||
|
|
for item in results:
|
|||
|
|
if not isinstance(item, dict):
|
|||
|
|
continue
|
|||
|
|
try:
|
|||
|
|
normalized.append(
|
|||
|
|
{
|
|||
|
|
"index": int(item["index"]),
|
|||
|
|
"relevance_score": float(item["relevance_score"]),
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
except (KeyError, TypeError, ValueError):
|
|||
|
|
continue
|
|||
|
|
return normalized
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_document_identity(file_path: str) -> tuple[str, str]:
|
|||
|
|
path = Path(str(file_path or "").strip())
|
|||
|
|
name = path.name
|
|||
|
|
if "__" not in name:
|
|||
|
|
return "", name
|
|||
|
|
document_id, document_name = name.split("__", maxsplit=1)
|
|||
|
|
return document_id.strip(), document_name.strip()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_excerpt(text: str, *, max_length: int = 180) -> str:
|
|||
|
|
normalized = " ".join(str(text or "").split()).strip()
|
|||
|
|
if len(normalized) <= max_length:
|
|||
|
|
return normalized
|
|||
|
|
return f"{normalized[: max_length - 3].rstrip()}..."
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _truncate_text(text: str, *, max_length: int) -> str:
|
|||
|
|
normalized = str(text or "").strip()
|
|||
|
|
if len(normalized) <= max_length:
|
|||
|
|
return normalized
|
|||
|
|
return f"{normalized[: max_length - 3].rstrip()}..."
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _resolve_default_qdrant_url() -> str:
    """Pick the default Qdrant URL for the current environment.

    Returns:
        ``CONTAINER_QDRANT_URL`` when the hostname ``qdrant`` resolves,
        otherwise ``DEFAULT_QDRANT_URL``.
    """
    return CONTAINER_QDRANT_URL if _hostname_resolves("qdrant") else DEFAULT_QDRANT_URL
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _hostname_resolves(hostname: str) -> bool:
|
|||
|
|
try:
|
|||
|
|
socket.getaddrinfo(hostname, None)
|
|||
|
|
except OSError:
|
|||
|
|
return False
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_query_terms(query: str) -> list[str]:
    """Derive up to ``MAX_QUERY_TERMS`` lowercase search terms from a query.

    ASCII alphanumeric tokens of two or more characters are collected
    first. Runs of CJK characters of length 2-4 are kept whole; longer runs
    are expanded into 4-, 3- and then 2-character n-grams. Duplicates,
    stopwords and terms shorter than two characters are skipped.

    The early exit inside the n-gram loop now applies the same
    ``MAX_QUERY_TERMS`` cap as the final return. Previously it returned the
    raw list, which could exceed the cap when the ASCII tokens and short
    blocks collected earlier had already passed it.

    Args:
        query: Raw user query; ``None`` is treated as empty.

    Returns:
        Unique terms in discovery order, at most ``MAX_QUERY_TERMS`` long.
    """
    normalized_query = str(query or "").strip().lower()
    if not normalized_query:
        return []

    terms: list[str] = []
    seen: set[str] = set()

    def remember(term: str) -> None:
        """Record ``term`` unless it is empty, short, a stopword or a duplicate."""
        normalized_term = str(term or "").strip().lower()
        if (
            not normalized_term
            or normalized_term in seen
            or normalized_term in QUERY_TERM_STOPWORDS
            or len(normalized_term) < 2
        ):
            return
        seen.add(normalized_term)
        terms.append(normalized_term)

    for item in re.findall(r"[a-z0-9][a-z0-9_\-]{1,}", normalized_query):
        remember(item)

    for block in re.findall(r"[\u4e00-\u9fff]{2,20}", normalized_query):
        if len(block) <= 4:
            remember(block)
            continue
        # Longer n-grams first, so the more specific terms claim the budget.
        for size in (4, 3, 2):
            for start in range(len(block) - size + 1):
                remember(block[start : start + size])
                if len(terms) >= MAX_QUERY_TERMS:
                    return terms[:MAX_QUERY_TERMS]

    return terms[:MAX_QUERY_TERMS]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _score_knowledge_hit(
|
|||
|
|
item: dict[str, Any],
|
|||
|
|
*,
|
|||
|
|
query_terms: list[str],
|
|||
|
|
prefers_tabular_evidence: bool,
|
|||
|
|
) -> int:
|
|||
|
|
rank = max(1, int(item.get("_rank") or 1))
|
|||
|
|
title = str(item.get("title") or item.get("document_name") or "").lower()
|
|||
|
|
content = str(item.get("content") or "").lower()
|
|||
|
|
excerpt = str(item.get("excerpt") or "").lower()
|
|||
|
|
tags = " ".join(str(value).lower() for value in list(item.get("tags") or [])[:5])
|
|||
|
|
haystack = "\n".join([title, excerpt, tags, content[:1200]])
|
|||
|
|
|
|||
|
|
score = max(1, 120 - rank * 4)
|
|||
|
|
matched_terms = [term for term in query_terms if term in haystack]
|
|||
|
|
score += len(matched_terms) * 8
|
|||
|
|
score += sum(1 for term in matched_terms if term in title) * 6
|
|||
|
|
|
|||
|
|
if "结构化表格补充" in content:
|
|||
|
|
score += 18
|
|||
|
|
if "问答线索补充" in content:
|
|||
|
|
score += 16 if not prefers_tabular_evidence else 8
|
|||
|
|
if "重点章节摘录" in content:
|
|||
|
|
score += 10
|
|||
|
|
if "章节导航" in content:
|
|||
|
|
score += 4
|
|||
|
|
if prefers_tabular_evidence and ("|" in content or "表" in content or "结构化表格补充" in content):
|
|||
|
|
score += 12
|
|||
|
|
if not prefers_tabular_evidence and any(marker in content for marker in ("第", "条", ":", "-", "•")):
|
|||
|
|
score += 4
|
|||
|
|
if title and any(term in title for term in query_terms):
|
|||
|
|
score += 6
|
|||
|
|
|
|||
|
|
return score
|