from __future__ import annotations import asyncio import json import os import re import socket import threading from dataclasses import dataclass from datetime import UTC, datetime from functools import partial from http import HTTPStatus from pathlib import Path from time import perf_counter from typing import Any from urllib.error import HTTPError, URLError from urllib.parse import quote from urllib.request import Request, urlopen from sqlalchemy.orm import Session from app.core.config import get_settings from app.core.logging import get_logger from app.db.session import get_session_factory from app.services.settings import SettingsService logger = get_logger("app.services.knowledge_rag") DEFAULT_QDRANT_URL = "http://127.0.0.1:6333" CONTAINER_QDRANT_URL = "http://qdrant:6333" DEFAULT_LIGHTRAG_WORKSPACE = "x_financial_knowledge" DEFAULT_LIGHTRAG_QUERY_MODE = "naive" DEFAULT_LLM_TIMEOUT_SECONDS = 180 DEFAULT_EMBEDDING_TIMEOUT_SECONDS = 120 MAX_KNOWLEDGE_HIT_CONTENT_LENGTH = 2200 MAX_QUERY_TERMS = 12 QUERY_TERM_STOPWORDS = { "什么", "多少", "哪些", "怎么", "如何", "请问", "一下", "关于", "规定", "标准", "可以", "是否", "一个", "哪些人", } TABLE_OR_STANDARD_QUERY_HINTS = ( "标准", "金额", "限额", "补贴", "住宿", "餐费", "交通", "报销", "档位", "额度", ) _runtime_lock = threading.RLock() _runtime_instance: _LightRagRuntime | None = None _runtime_signature: tuple[Any, ...] | None = None class KnowledgeRagError(RuntimeError): pass @dataclass(frozen=True, slots=True) class RuntimeModelConfig: slot: str provider: str model: str endpoint: str api_key: str capability: str class _LightRagRuntime: def __init__( self, *, working_dir: Path, workspace: str, qdrant_url: str, qdrant_api_key: str, primary_chat: RuntimeModelConfig, backup_chat: RuntimeModelConfig | None, embedding: RuntimeModelConfig, reranker: RuntimeModelConfig | None, ) -> None: self.working_dir = working_dir self.workspace = workspace self.qdrant_url = qdrant_url self.qdrant_api_key = qdrant_api_key self.primary_chat = primary_chat self.backup_chat = backup_chat self.embedding = embedding self.reranker = reranker self._rag = self._build_rag() self._initialize() self._graph_has_content_cache: bool | None = None @property def rag(self): return self._rag def _build_rag(self): try: from lightrag import LightRAG from lightrag.utils import EmbeddingFunc except ImportError as exc: # pragma: no cover - exercised in runtime env raise KnowledgeRagError( "LightRAG 依赖未安装,请先在 server 环境执行依赖安装。" ) from exc self.working_dir.mkdir(parents=True, exist_ok=True) if self.qdrant_url: os.environ["QDRANT_URL"] = self.qdrant_url if self.qdrant_api_key: os.environ["QDRANT_API_KEY"] = self.qdrant_api_key embedding_dim = self._probe_embedding_dimension(self.embedding) logger.info( "Initialize LightRAG runtime workspace=%s qdrant=%s embedding_model=%s dim=%s", self.workspace, self.qdrant_url, self.embedding.model, embedding_dim, ) async def embedding_func(texts: list[str]) -> Any: return await asyncio.to_thread(self._embed_sync, texts) async def llm_model_func( prompt: str, system_prompt: str | None = None, history_messages: list[dict[str, Any]] | None = None, keyword_extraction: bool = False, **kwargs: Any, ) -> str: return await asyncio.to_thread( self._complete_sync, prompt, system_prompt, history_messages or [], keyword_extraction, kwargs, ) async def rerank_model_func( query: str, documents: list[str], top_n: int | None = None, **_kwargs: Any, ) -> list[dict[str, Any]]: return await asyncio.to_thread( self._rerank_sync, query, documents, top_n, ) return LightRAG( working_dir=str(self.working_dir), workspace=self.workspace, kv_storage="JsonKVStorage", graph_storage="NetworkXStorage", vector_storage="QdrantVectorDBStorage", doc_status_storage="JsonDocStatusStorage", llm_model_name=self.primary_chat.model, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( embedding_dim=embedding_dim, func=embedding_func, max_token_size=8192, model_name=self.embedding.model, supports_asymmetric=False, ), rerank_model_func=rerank_model_func if self.reranker is not None else None, enable_llm_cache=False, enable_llm_cache_for_entity_extract=False, ) def _initialize(self) -> None: from lightrag.utils import always_get_an_event_loop loop = always_get_an_event_loop() loop.run_until_complete(self._rag.initialize_storages()) def finalize(self) -> None: from lightrag.utils import always_get_an_event_loop loop = always_get_an_event_loop() loop.run_until_complete(self._rag.finalize_storages()) def query_data(self, query: str, *, conversation_history: list[dict[str, str]] | None = None) -> dict[str, Any]: from lightrag import QueryParam configured_mode = os.environ.get("LIGHTRAG_QUERY_MODE", DEFAULT_LIGHTRAG_QUERY_MODE).strip() or DEFAULT_LIGHTRAG_QUERY_MODE mode = "naive" if configured_mode != "naive" and not self._graph_has_content() else configured_mode started_at = perf_counter() param = QueryParam( mode=mode, top_k=8, chunk_top_k=10, only_need_context=True, response_type="Multiple Paragraphs", conversation_history=conversation_history or [], include_references=True, ) try: result = self._rag.query_data(query, param) logger.info("LightRAG query completed mode=%s elapsed=%.2fs", mode, perf_counter() - started_at) return result except Exception: if mode == "naive": raise logger.warning("LightRAG query mode=%s failed, retry with naive mode", mode) fallback_param = QueryParam( mode="naive", top_k=8, chunk_top_k=10, only_need_context=True, response_type="Multiple Paragraphs", conversation_history=conversation_history or [], include_references=True, ) result = self._rag.query_data(query, fallback_param) logger.info("LightRAG query completed mode=naive elapsed=%.2fs", perf_counter() - started_at) return result def _graph_has_content(self) -> bool: if self._graph_has_content_cache is not None: return self._graph_has_content_cache graph_path = self.working_dir / self.workspace / "graph_chunk_entity_relation.graphml" try: graph_text = graph_path.read_text(encoding="utf-8") except OSError: self._graph_has_content_cache = False return False self._graph_has_content_cache = " str: return self._rag.insert(texts, ids=document_ids, file_paths=file_paths) def get_document_statuses(self, document_ids: list[str]) -> dict[str, Any]: from lightrag.utils import always_get_an_event_loop loop = always_get_an_event_loop() return loop.run_until_complete(self._rag.aget_docs_by_ids(document_ids)) def delete_document(self, document_id: str) -> None: from lightrag.utils import always_get_an_event_loop loop = always_get_an_event_loop() result = loop.run_until_complete(self._rag.adelete_by_doc_id(document_id)) status = str(getattr(result, "status", "") or "") if status not in {"success", "not_found"}: raise KnowledgeRagError(str(getattr(result, "message", "") or "LightRAG 删除文档失败。")) def _probe_embedding_dimension(self, config: RuntimeModelConfig) -> int: vectors = self._request_embeddings(config, ["dimension probe"]) if not vectors or not isinstance(vectors[0], list): raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。") dimension = len(vectors[0]) if dimension <= 0: raise KnowledgeRagError("embedding 模型返回了无效的向量维度。") return dimension def _embed_sync(self, texts: list[str]) -> Any: import numpy as np vectors = self._request_embeddings(self.embedding, texts) return np.array(vectors, dtype=float) def _rerank_sync( self, query: str, documents: list[str], top_n: int | None, ) -> list[dict[str, Any]]: if self.reranker is None: return [] status_code, body = self._request_rerank( self.reranker, query=query, documents=documents, top_n=top_n, ) if status_code >= HTTPStatus.BAD_REQUEST: raise KnowledgeRagError(f"reranker 模型返回异常状态码 {status_code}。") return _extract_rerank_results(body, provider=self.reranker.provider) def _complete_sync( self, prompt: str, system_prompt: str | None, history_messages: list[dict[str, Any]], keyword_extraction: bool, kwargs: dict[str, Any], ) -> str: del keyword_extraction last_error: Exception | None = None for config in [self.primary_chat, self.backup_chat]: if config is None: continue try: return self._request_chat_completion( config, prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, max_tokens=int(kwargs.get("max_tokens") or 1200), temperature=float(kwargs.get("temperature") or 0.1), ) except Exception as exc: # pragma: no cover - runtime fallback last_error = exc logger.warning( "LightRAG LLM request failed slot=%s provider=%s model=%s: %s", config.slot, config.provider, config.model, exc, ) continue raise KnowledgeRagError(f"LightRAG 调用知识模型失败:{last_error or '没有可用模型配置'}") def _request_chat_completion( self, config: RuntimeModelConfig, *, prompt: str, system_prompt: str | None, history_messages: list[dict[str, Any]], max_tokens: int, temperature: float, ) -> str: messages: list[dict[str, Any]] = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) if config.provider == "Azure OpenAI": url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/chat/completions?api-version={AZURE_API_VERSION}" payload = { "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } status_code, body = _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) elif config.provider == "Ollama": url = _ensure_path(_normalize_endpoint(config.endpoint), "api/chat") payload = { "model": config.model, "messages": messages, "stream": False, "options": { "num_predict": max_tokens, "temperature": temperature, }, } status_code, body = _send_json_request( "POST", url, headers={"Content-Type": "application/json", "Accept": "application/json"}, payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) else: url = _ensure_path(_normalize_endpoint(config.endpoint), "chat/completions") payload = { "model": config.model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } status_code, body = _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) if status_code >= HTTPStatus.BAD_REQUEST: raise KnowledgeRagError(f"知识模型返回异常状态码 {status_code}。") return _extract_chat_text(body, provider=config.provider) def _request_embeddings(self, config: RuntimeModelConfig, texts: list[str]) -> list[list[float]]: if config.provider == "Azure OpenAI": url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}" payload = {"input": texts} status_code, body = _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True), payload=payload, timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, ) elif config.provider == "Ollama": url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed") payload = {"model": config.model, "input": texts} status_code, body = _send_json_request( "POST", url, headers={"Content-Type": "application/json", "Accept": "application/json"}, payload=payload, timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, ) else: url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings") payload = {"model": config.model, "input": texts} status_code, body = _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=True), payload=payload, timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, ) if status_code >= HTTPStatus.BAD_REQUEST: raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。") return _extract_embedding_vectors(body, provider=config.provider) def _request_rerank( self, config: RuntimeModelConfig, *, query: str, documents: list[str], top_n: int | None, ) -> tuple[int, Any]: if config.provider == "Azure OpenAI": url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/rerank?api-version={AZURE_API_VERSION}" payload: dict[str, Any] = { "query": query, "documents": documents, } if top_n is not None: payload["top_n"] = top_n return _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) if config.provider == "Ali": url, payload = _build_ali_rerank_request( config.model, query=query, documents=documents, top_n=top_n, ) return _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) url = _ensure_path(_normalize_endpoint(config.endpoint), "rerank") payload = { "model": config.model, "query": query, "documents": documents, } if top_n is not None: payload["top_n"] = top_n return _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) class KnowledgeRagService: def __init__(self, db: Session | None = None, storage_root: Path | None = None) -> None: self.db = db self.storage_root = Path(storage_root or get_settings().resolved_storage_root_dir) def query_knowledge( self, query: str, *, conversation_history: list[dict[str, str]] | None = None, limit: int = 5, ) -> dict[str, Any]: normalized_query = str(query or "").strip() if not normalized_query: return { "result_type": "knowledge_search", "query": "", "record_count": 0, "hits": [], "references": [], "message": "请先输入要检索的知识库问题。", } try: runtime = self._get_runtime() raw = runtime.query_data(normalized_query, conversation_history=conversation_history) except Exception as exc: logger.warning("Knowledge query failed: %s", exc) return { "result_type": "knowledge_search", "query": normalized_query, "record_count": 0, "hits": [], "references": [], "message": f"知识库检索暂不可用:{exc}", } data = raw.get("data") if isinstance(raw, dict) else {} chunks = list(data.get("chunks") or []) if isinstance(data, dict) else [] entities = list(data.get("entities") or []) if isinstance(data, dict) else [] references = list(data.get("references") or []) if isinstance(data, dict) else [] hits = self._build_hits_from_query_data( query=normalized_query, chunks=chunks, entities=entities, limit=limit, ) if not hits: return { "result_type": "knowledge_search", "query": normalized_query, "record_count": 0, "hits": [], "references": [], "raw_references": references, "message": "当前知识库中没有检索到与本次问题直接匹配的内容。", } return { "result_type": "knowledge_search", "query": normalized_query, "record_count": len(hits), "hits": hits, "references": [str(item.get("code") or "").strip() for item in hits if str(item.get("code") or "").strip()], "raw_references": references, "metadata": raw.get("metadata") if isinstance(raw, dict) else {}, "message": f"已从知识库中检索到 {len(hits)} 条相关内容。", } def index_documents( self, *, document_ids: list[str], force: bool = False, ) -> dict[str, Any]: normalized_ids = [str(item).strip() for item in document_ids if str(item).strip()] if not normalized_ids: raise ValueError("没有可供索引的知识文档。") from app.services.knowledge import KnowledgeService from app.services.knowledge_normalizer import KnowledgeNormalizationService knowledge_service = KnowledgeService(storage_root=self.storage_root, db=self.db) normalization_service = ( KnowledgeNormalizationService(self.db) if self.db is not None else None ) texts: list[str] = [] file_paths: list[str] = [] runtime = self._get_runtime() existing_statuses = runtime.get_document_statuses(normalized_ids) for document_id in normalized_ids: entry = knowledge_service.get_document_entry(document_id) if force and document_id in existing_statuses: try: runtime.delete_document(document_id) except Exception as exc: logger.warning("Delete existing LightRAG document failed doc_id=%s: %s", document_id, exc) text = knowledge_service.extract_document_text(document_id) if normalization_service is not None: text = normalization_service.build_enriched_text(text) texts.append(text) file_paths.append(str((knowledge_service.library_root / entry["folder"] / entry["stored_name"]).resolve())) track_id = runtime.insert_documents( texts=texts, document_ids=normalized_ids, file_paths=file_paths, ) statuses = runtime.get_document_statuses(normalized_ids) succeeded_document_ids: list[str] = [] failed_documents: list[dict[str, str]] = [] for document_id in normalized_ids: status_obj = statuses.get(document_id) status_text = self._status_value(status_obj) if self.is_query_ready_status(status_obj): succeeded_document_ids.append(document_id) continue failed_documents.append( { "document_id": document_id, "status": status_text or "unknown", "error": self._status_error(status_obj), } ) return { "track_id": track_id, "requested_document_ids": normalized_ids, "succeeded_document_ids": succeeded_document_ids, "failed_documents": failed_documents, "status_snapshot": { document_id: self._serialize_status(status_obj) for document_id, status_obj in statuses.items() }, } def get_document_status_map(self, document_ids: list[str] | None = None) -> dict[str, dict[str, Any]]: target_ids = [str(item).strip() for item in document_ids or [] if str(item).strip()] if not target_ids: return {} try: statuses = self._get_runtime().get_document_statuses(target_ids) except Exception as exc: logger.warning("Load LightRAG document statuses failed: %s", exc) return {} return { document_id: self._serialize_status(status_obj) for document_id, status_obj in statuses.items() } def delete_document(self, document_id: str) -> None: normalized_id = str(document_id or "").strip() if not normalized_id: return try: self._get_runtime().delete_document(normalized_id) except Exception as exc: logger.warning("Delete LightRAG document ignored doc_id=%s: %s", normalized_id, exc) def _get_runtime(self) -> _LightRagRuntime: global _runtime_instance, _runtime_signature signature, runtime_kwargs = self._build_runtime_signature() with _runtime_lock: if _runtime_instance is not None and _runtime_signature == signature: return _runtime_instance if _runtime_instance is not None: try: _runtime_instance.finalize() except Exception as exc: # pragma: no cover - best effort cleanup logger.warning("Finalize previous LightRAG runtime failed: %s", exc) _runtime_instance = _LightRagRuntime(**runtime_kwargs) _runtime_signature = signature return _runtime_instance def _build_runtime_signature(self) -> tuple[tuple[Any, ...], dict[str, Any]]: configs = self._load_runtime_configs() settings = get_settings() working_dir = (self.storage_root / "knowledge" / ".lightrag").resolve() workspace = os.environ.get("LIGHTRAG_WORKSPACE", DEFAULT_LIGHTRAG_WORKSPACE).strip() or DEFAULT_LIGHTRAG_WORKSPACE qdrant_url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url() qdrant_api_key = os.environ.get("QDRANT_API_KEY", "").strip() signature = ( str(working_dir), workspace, qdrant_url, qdrant_api_key, configs["main"].provider, configs["main"].model, configs["main"].endpoint, configs["main"].api_key, configs["backup"].provider if configs["backup"] else "", configs["backup"].model if configs["backup"] else "", configs["backup"].endpoint if configs["backup"] else "", configs["backup"].api_key if configs["backup"] else "", configs["embedding"].provider, configs["embedding"].model, configs["embedding"].endpoint, configs["embedding"].api_key, configs["reranker"].provider if configs["reranker"] else "", configs["reranker"].model if configs["reranker"] else "", configs["reranker"].endpoint if configs["reranker"] else "", configs["reranker"].api_key if configs["reranker"] else "", str(settings.resolved_storage_root_dir), ) return signature, { "working_dir": working_dir, "workspace": workspace, "qdrant_url": qdrant_url, "qdrant_api_key": qdrant_api_key, "primary_chat": configs["main"], "backup_chat": configs["backup"], "embedding": configs["embedding"], "reranker": configs["reranker"], } def _load_runtime_configs(self) -> dict[str, RuntimeModelConfig | None]: owned_session = False session = self.db if session is None: session = get_session_factory()() owned_session = True try: settings_service = SettingsService(session) main = self._normalize_runtime_model(settings_service.get_runtime_model_config("main")) embedding = self._normalize_runtime_model(settings_service.get_runtime_model_config("embedding")) try: backup_raw = settings_service.get_runtime_model_config("backup") backup = self._normalize_runtime_model(backup_raw) except Exception: backup = None try: reranker_raw = settings_service.get_runtime_model_config("reranker") reranker = self._normalize_runtime_model(reranker_raw) except Exception: reranker = None if backup is not None and ( not backup.endpoint or not backup.model or (backup.provider != "Ollama" and not backup.api_key) ): backup = None if reranker is not None and ( not reranker.endpoint or not reranker.model or (reranker.provider != "Ollama" and not reranker.api_key) ): reranker = None if not main.endpoint or not main.model: raise KnowledgeRagError("主对话模型未配置,无法初始化 LightRAG。") if main.provider != "Ollama" and not main.api_key: raise KnowledgeRagError("主对话模型缺少 API Key,无法初始化 LightRAG。") if not embedding.endpoint or not embedding.model: raise KnowledgeRagError("Embedding 模型未配置,无法初始化 LightRAG。") if embedding.provider != "Ollama" and not embedding.api_key: raise KnowledgeRagError("Embedding 模型缺少 API Key,无法初始化 LightRAG。") return { "main": main, "backup": backup, "embedding": embedding, "reranker": reranker, } finally: if owned_session and session is not None: session.close() @staticmethod def _normalize_runtime_model(payload: dict[str, str]) -> RuntimeModelConfig: return RuntimeModelConfig( slot=str(payload.get("slot") or "").strip(), provider=str(payload.get("provider") or "").strip(), model=str(payload.get("model") or "").strip(), endpoint=str(payload.get("endpoint") or "").strip(), api_key=str(payload.get("apiKey") or "").strip(), capability=str(payload.get("capability") or "").strip(), ) @staticmethod def _build_hits_from_query_data( *, query: str, chunks: list[dict[str, Any]], entities: list[dict[str, Any]], limit: int, ) -> list[dict[str, Any]]: entity_tags_by_path: dict[str, list[str]] = {} for entity in entities: if not isinstance(entity, dict): continue file_path = str(entity.get("file_path") or "").strip() entity_name = str(entity.get("entity_name") or "").strip() if not file_path or not entity_name: continue entity_tags_by_path.setdefault(file_path, []) if entity_name not in entity_tags_by_path[file_path]: entity_tags_by_path[file_path].append(entity_name) query_terms = _extract_query_terms(query) prefers_tabular_evidence = any(hint in query for hint in TABLE_OR_STANDARD_QUERY_HINTS) candidates: list[dict[str, Any]] = [] for rank, chunk in enumerate(chunks, start=1): if not isinstance(chunk, dict): continue file_path = str(chunk.get("file_path") or "").strip() chunk_id = str(chunk.get("chunk_id") or "").strip() content = str(chunk.get("content") or "").strip() if not file_path or not content: continue document_id, document_name = _parse_document_identity(file_path) normalized_chunk_id = chunk_id or f"path-{rank}" normalized_content = _truncate_text(content, max_length=MAX_KNOWLEDGE_HIT_CONTENT_LENGTH) excerpt = _build_excerpt(normalized_content, max_length=220) candidates.append( { "code": f"knowledge.{document_id or 'unknown'}.{normalized_chunk_id}", "candidate_id": normalized_chunk_id, "title": document_name or "知识库文档", "content": normalized_content, "excerpt": excerpt, "document_id": document_id, "document_name": document_name or Path(file_path).name, "version": None, "updated_at": None, "score": max(1, 100 - rank), "tags": entity_tags_by_path.get(file_path, [])[:5], "evidence": [normalized_chunk_id], "file_path": file_path, "_rank": rank, } ) ranked = sorted( candidates, key=lambda item: ( _score_knowledge_hit( item, query_terms=query_terms, prefers_tabular_evidence=prefers_tabular_evidence, ), -int(item.get("_rank") or 0), ), reverse=True, ) hits: list[dict[str, Any]] = [] for item in ranked[: max(1, limit)]: normalized = dict(item) normalized.pop("_rank", None) hits.append(normalized) return hits @staticmethod def _serialize_status(status_obj: Any) -> dict[str, Any]: if status_obj is None: return {} if hasattr(status_obj, "__dict__"): payload = dict(status_obj.__dict__) elif isinstance(status_obj, dict): payload = dict(status_obj) else: payload = {} payload["status"] = KnowledgeRagService._status_value(status_obj) payload["error_msg"] = KnowledgeRagService._status_error(status_obj) payload["query_ready"] = KnowledgeRagService.is_query_ready_status(status_obj) return payload @staticmethod def _status_value(status_obj: Any) -> str: raw_status = getattr(status_obj, "status", None) if raw_status is None and isinstance(status_obj, dict): raw_status = status_obj.get("status") normalized = str(raw_status or "").strip().lower() if "." in normalized: normalized = normalized.split(".")[-1].strip() if ":" in normalized and normalized.endswith(">"): normalized = normalized.split(":")[0].strip("<> '\"") return normalized @staticmethod def _status_error(status_obj: Any) -> str: value = getattr(status_obj, "error_msg", None) if value is None and isinstance(status_obj, dict): value = status_obj.get("error_msg") return str(value or "").strip() @staticmethod def is_query_ready_status(status_obj: Any) -> bool: status_text = KnowledgeRagService._status_value(status_obj) if status_text == "processed": return True chunks_count = getattr(status_obj, "chunks_count", None) if chunks_count is None and isinstance(status_obj, dict): chunks_count = status_obj.get("chunks_count") try: if int(chunks_count or 0) > 0: return True except (TypeError, ValueError): pass chunks_list = getattr(status_obj, "chunks_list", None) if chunks_list is None and isinstance(status_obj, dict): chunks_list = status_obj.get("chunks_list") return bool(chunks_list) def shutdown_knowledge_rag_runtime() -> None: global _runtime_instance, _runtime_signature with _runtime_lock: if _runtime_instance is None: return try: _runtime_instance.finalize() except Exception as exc: # pragma: no cover - best effort cleanup logger.warning("Finalize LightRAG runtime failed during shutdown: %s", exc) _runtime_instance = None _runtime_signature = None def _normalize_endpoint(endpoint: str) -> str: normalized = str(endpoint or "").strip() if not normalized: raise KnowledgeRagError("模型 endpoint 不能为空。") return normalized.rstrip("/") def _ensure_path(endpoint: str, suffix: str) -> str: suffix = suffix.lstrip("/") if endpoint.endswith(suffix): return endpoint return f"{endpoint}/{suffix}" def _build_azure_deployment_base(endpoint: str, model: str) -> str: normalized_endpoint = _normalize_endpoint(endpoint) quoted_model = quote(model, safe="") if "/openai/deployments/" in normalized_endpoint: return normalized_endpoint if "/openai/v1" in normalized_endpoint: resource_root = normalized_endpoint.split("/openai/v1", maxsplit=1)[0] return f"{resource_root}/openai/deployments/{quoted_model}" if normalized_endpoint.endswith("/openai"): return f"{normalized_endpoint}/deployments/{quoted_model}" return f"{normalized_endpoint}/openai/deployments/{quoted_model}" def _build_headers( api_key: str, *, use_bearer: bool, use_api_key: bool = False, ) -> dict[str, str]: headers = { "Content-Type": "application/json", "Accept": "application/json", } normalized_key = str(api_key or "").strip() if normalized_key: if use_api_key: headers["api-key"] = normalized_key elif use_bearer: headers["Authorization"] = f"Bearer {normalized_key}" return headers def _send_json_request( method: str, url: str, *, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: int, ) -> tuple[int, Any]: data = json.dumps(payload).encode("utf-8") request = Request(url=url, data=data, headers=headers, method=method) try: with urlopen(request, timeout=timeout_seconds) as response: # noqa: S310 body = response.read().decode("utf-8") if response.length != 0 else "" return response.status, _parse_json_body(body) except HTTPError as exc: # pragma: no cover - runtime path body = exc.read().decode("utf-8", errors="ignore") detail = _extract_error_message(_parse_json_body(body)) or f"接口返回 {exc.code}" raise KnowledgeRagError(detail) from exc except URLError as exc: # pragma: no cover - runtime path raise KnowledgeRagError(f"无法连接模型接口:{getattr(exc, 'reason', exc)}") from exc except TimeoutError as exc: # pragma: no cover - runtime path raise KnowledgeRagError("模型接口调用超时。") from exc def _parse_json_body(body: str) -> Any: if not body: return None try: return json.loads(body) except json.JSONDecodeError: return {"message": body} def _extract_error_message(payload: Any) -> str | None: if payload is None: return None if isinstance(payload, dict): if isinstance(payload.get("detail"), str): return payload["detail"] if isinstance(payload.get("message"), str): return payload["message"] error_payload = payload.get("error") if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str): return error_payload["message"] if isinstance(payload, str): return payload return None def _extract_chat_text(payload: Any, *, provider: str) -> str: if provider == "Ollama": message = payload.get("message") if isinstance(payload, dict) else None if isinstance(message, dict): return str(message.get("content") or "").strip() return "" if not isinstance(payload, dict): return "" choices = payload.get("choices") if not isinstance(choices, list) or not choices: return "" first_choice = choices[0] if not isinstance(first_choice, dict): return "" message = first_choice.get("message") if isinstance(message, dict): content = message.get("content") if isinstance(content, str): return content.strip() if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, dict) and item.get("type") == "text": parts.append(str(item.get("text") or "").strip()) return "\n".join(part for part in parts if part).strip() text = first_choice.get("text") if isinstance(text, str): return text.strip() return "" def _extract_embedding_vectors(payload: Any, *, provider: str) -> list[list[float]]: if provider == "Ollama": embeddings = payload.get("embeddings") if isinstance(payload, dict) else None if isinstance(embeddings, list): return [[float(value) for value in item] for item in embeddings if isinstance(item, list)] embedding = payload.get("embedding") if isinstance(payload, dict) else None if isinstance(embedding, list): return [[float(value) for value in embedding]] raise KnowledgeRagError("Ollama embedding 返回格式无法识别。") if not isinstance(payload, dict): raise KnowledgeRagError("embedding 接口返回格式无效。") data = payload.get("data") if not isinstance(data, list) or not data: raise KnowledgeRagError("embedding 接口没有返回 data。") vectors: list[list[float]] = [] for item in data: if not isinstance(item, dict): continue embedding = item.get("embedding") if isinstance(embedding, list): vectors.append([float(value) for value in embedding]) if not vectors: raise KnowledgeRagError("embedding 接口返回中未找到向量数据。") return vectors def _build_ali_rerank_request( model: str, *, query: str, documents: list[str], top_n: int | None, ) -> tuple[str, dict[str, Any]]: normalized_model = str(model or "").strip() if normalized_model == "qwen3-rerank": payload: dict[str, Any] = { "model": normalized_model, "query": query, "documents": documents, } if top_n is not None: payload["top_n"] = top_n return "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", payload payload = { "model": normalized_model, "input": { "query": query, "documents": documents, }, "parameters": { "return_documents": False, }, } if top_n is not None: payload["parameters"]["top_n"] = top_n return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank", payload def _extract_rerank_results(payload: Any, *, provider: str) -> list[dict[str, Any]]: if not isinstance(payload, dict): return [] if provider == "Ali" and isinstance(payload.get("output"), dict): results = payload["output"].get("results") else: results = payload.get("results") if not isinstance(results, list): return [] normalized: list[dict[str, Any]] = [] for item in results: if not isinstance(item, dict): continue try: normalized.append( { "index": int(item["index"]), "relevance_score": float(item["relevance_score"]), } ) except (KeyError, TypeError, ValueError): continue return normalized def _parse_document_identity(file_path: str) -> tuple[str, str]: path = Path(str(file_path or "").strip()) name = path.name if "__" not in name: return "", name document_id, document_name = name.split("__", maxsplit=1) return document_id.strip(), document_name.strip() def _build_excerpt(text: str, *, max_length: int = 180) -> str: normalized = " ".join(str(text or "").split()).strip() if len(normalized) <= max_length: return normalized return f"{normalized[: max_length - 3].rstrip()}..." def _truncate_text(text: str, *, max_length: int) -> str: normalized = str(text or "").strip() if len(normalized) <= max_length: return normalized return f"{normalized[: max_length - 3].rstrip()}..." def _resolve_default_qdrant_url() -> str: if _hostname_resolves("qdrant"): return CONTAINER_QDRANT_URL return DEFAULT_QDRANT_URL def _hostname_resolves(hostname: str) -> bool: try: socket.getaddrinfo(hostname, None) except OSError: return False return True def _extract_query_terms(query: str) -> list[str]: normalized_query = str(query or "").strip().lower() if not normalized_query: return [] terms: list[str] = [] seen: set[str] = set() def remember(term: str) -> None: normalized_term = str(term or "").strip().lower() if ( not normalized_term or normalized_term in seen or normalized_term in QUERY_TERM_STOPWORDS or len(normalized_term) < 2 ): return seen.add(normalized_term) terms.append(normalized_term) for item in re.findall(r"[a-z0-9][a-z0-9_\-]{1,}", normalized_query): remember(item) for block in re.findall(r"[\u4e00-\u9fff]{2,20}", normalized_query): if len(block) <= 4: remember(block) continue for size in (4, 3, 2): for start in range(0, len(block) - size + 1): remember(block[start : start + size]) if len(terms) >= MAX_QUERY_TERMS: return terms return terms[:MAX_QUERY_TERMS] def _score_knowledge_hit( item: dict[str, Any], *, query_terms: list[str], prefers_tabular_evidence: bool, ) -> int: rank = max(1, int(item.get("_rank") or 1)) title = str(item.get("title") or item.get("document_name") or "").lower() content = str(item.get("content") or "").lower() excerpt = str(item.get("excerpt") or "").lower() tags = " ".join(str(value).lower() for value in list(item.get("tags") or [])[:5]) haystack = "\n".join([title, excerpt, tags, content[:1200]]) score = max(1, 120 - rank * 4) matched_terms = [term for term in query_terms if term in haystack] score += len(matched_terms) * 8 score += sum(1 for term in matched_terms if term in title) * 6 if "结构化表格补充" in content: score += 18 if "问答线索补充" in content: score += 16 if not prefers_tabular_evidence else 8 if "重点章节摘录" in content: score += 10 if "章节导航" in content: score += 4 if prefers_tabular_evidence and ("|" in content or "表" in content or "结构化表格补充" in content): score += 12 if not prefers_tabular_evidence and any(marker in content for marker in ("第", "条", ":", "-", "•")): score += 4 if title and any(term in title for term in query_terms): score += 6 return score