from __future__ import annotations import asyncio import json import os from dataclasses import dataclass from http import HTTPStatus from pathlib import Path from time import perf_counter from typing import Any from urllib.error import HTTPError, URLError from urllib.parse import quote from urllib.request import Request, urlopen from app.core.logging import get_logger from app.services.model_connectivity import AZURE_API_VERSION logger = get_logger("app.services.knowledge_rag") DEFAULT_LIGHTRAG_QUERY_MODE = "naive" DEFAULT_LLM_TIMEOUT_SECONDS = 180 DEFAULT_EMBEDDING_TIMEOUT_SECONDS = 120 class KnowledgeRagError(RuntimeError): pass @dataclass(frozen=True, slots=True) class RuntimeModelConfig: slot: str provider: str model: str endpoint: str api_key: str capability: str class _LightRagRuntime: def __init__( self, *, working_dir: Path, workspace: str, qdrant_url: str, qdrant_api_key: str, primary_chat: RuntimeModelConfig, backup_chat: RuntimeModelConfig | None, embedding: RuntimeModelConfig, reranker: RuntimeModelConfig | None, ) -> None: self.working_dir = working_dir self.workspace = workspace self.qdrant_url = qdrant_url self.qdrant_api_key = qdrant_api_key self.primary_chat = primary_chat self.backup_chat = backup_chat self.embedding = embedding self.reranker = reranker self._rag = self._build_rag() self._initialize() self._graph_has_content_cache: bool | None = None @property def rag(self): return self._rag def _build_rag(self): try: from lightrag import LightRAG from lightrag.utils import EmbeddingFunc except ImportError as exc: # pragma: no cover - exercised in runtime env raise KnowledgeRagError( "LightRAG 依赖未安装,请先在 server 环境执行依赖安装。" ) from exc self.working_dir.mkdir(parents=True, exist_ok=True) if self.qdrant_url: os.environ["QDRANT_URL"] = self.qdrant_url if self.qdrant_api_key: os.environ["QDRANT_API_KEY"] = self.qdrant_api_key embedding_dim = self._probe_embedding_dimension(self.embedding) logger.info( "Initialize LightRAG runtime workspace=%s qdrant=%s embedding_model=%s dim=%s", self.workspace, self.qdrant_url, self.embedding.model, embedding_dim, ) async def embedding_func(texts: list[str]) -> Any: return await asyncio.to_thread(self._embed_sync, texts) async def llm_model_func( prompt: str, system_prompt: str | None = None, history_messages: list[dict[str, Any]] | None = None, keyword_extraction: bool = False, **kwargs: Any, ) -> str: return await asyncio.to_thread( self._complete_sync, prompt, system_prompt, history_messages or [], keyword_extraction, kwargs, ) async def rerank_model_func( query: str, documents: list[str], top_n: int | None = None, **_kwargs: Any, ) -> list[dict[str, Any]]: return await asyncio.to_thread( self._rerank_sync, query, documents, top_n, ) return LightRAG( working_dir=str(self.working_dir), workspace=self.workspace, kv_storage="JsonKVStorage", graph_storage="NetworkXStorage", vector_storage="QdrantVectorDBStorage", doc_status_storage="JsonDocStatusStorage", llm_model_name=self.primary_chat.model, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( embedding_dim=embedding_dim, func=embedding_func, max_token_size=8192, model_name=self.embedding.model, supports_asymmetric=False, ), rerank_model_func=rerank_model_func if self.reranker is not None else None, enable_llm_cache=False, enable_llm_cache_for_entity_extract=False, ) def _initialize(self) -> None: from lightrag.utils import always_get_an_event_loop loop = always_get_an_event_loop() loop.run_until_complete(self._rag.initialize_storages()) def finalize(self) -> None: from lightrag.utils import always_get_an_event_loop loop = always_get_an_event_loop() loop.run_until_complete(self._rag.finalize_storages()) def query_data(self, query: str, *, conversation_history: list[dict[str, str]] | None = None) -> dict[str, Any]: from lightrag import QueryParam configured_mode = os.environ.get("LIGHTRAG_QUERY_MODE", DEFAULT_LIGHTRAG_QUERY_MODE).strip() or DEFAULT_LIGHTRAG_QUERY_MODE mode = "naive" if configured_mode != "naive" and not self._graph_has_content() else configured_mode started_at = perf_counter() param = QueryParam( mode=mode, top_k=8, chunk_top_k=10, only_need_context=True, response_type="Multiple Paragraphs", conversation_history=conversation_history or [], include_references=True, ) try: result = self._rag.query_data(query, param) logger.info("LightRAG query completed mode=%s elapsed=%.2fs", mode, perf_counter() - started_at) return result except Exception: if mode == "naive": raise logger.warning("LightRAG query mode=%s failed, retry with naive mode", mode) fallback_param = QueryParam( mode="naive", top_k=8, chunk_top_k=10, only_need_context=True, response_type="Multiple Paragraphs", conversation_history=conversation_history or [], include_references=True, ) result = self._rag.query_data(query, fallback_param) logger.info("LightRAG query completed mode=naive elapsed=%.2fs", perf_counter() - started_at) return result def _graph_has_content(self) -> bool: if self._graph_has_content_cache is not None: return self._graph_has_content_cache graph_path = self.working_dir / self.workspace / "graph_chunk_entity_relation.graphml" try: graph_text = graph_path.read_text(encoding="utf-8") except OSError: self._graph_has_content_cache = False return False self._graph_has_content_cache = " str: return self._rag.insert(texts, ids=document_ids, file_paths=file_paths) def get_document_statuses(self, document_ids: list[str]) -> dict[str, Any]: from lightrag.utils import always_get_an_event_loop loop = always_get_an_event_loop() return loop.run_until_complete(self._rag.aget_docs_by_ids(document_ids)) def delete_document(self, document_id: str) -> None: from lightrag.utils import always_get_an_event_loop loop = always_get_an_event_loop() result = loop.run_until_complete(self._rag.adelete_by_doc_id(document_id)) status = str(getattr(result, "status", "") or "") if status not in {"success", "not_found"}: raise KnowledgeRagError(str(getattr(result, "message", "") or "LightRAG 删除文档失败。")) def _probe_embedding_dimension(self, config: RuntimeModelConfig) -> int: try: vectors = self._request_embeddings(config, ["dimension probe"]) except Exception as exc: raise KnowledgeRagError( "Embedding model probe failed " f"(slot={config.slot}, provider={config.provider}, model={config.model}): {exc}" ) from exc if not vectors or not isinstance(vectors[0], list): raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。") dimension = len(vectors[0]) if dimension <= 0: raise KnowledgeRagError("embedding 模型返回了无效的向量维度。") return dimension def _embed_sync(self, texts: list[str]) -> Any: import numpy as np vectors = self._request_embeddings(self.embedding, texts) return np.array(vectors, dtype=float) def _rerank_sync( self, query: str, documents: list[str], top_n: int | None, ) -> list[dict[str, Any]]: if self.reranker is None: return [] status_code, body = self._request_rerank( self.reranker, query=query, documents=documents, top_n=top_n, ) if status_code >= HTTPStatus.BAD_REQUEST: raise KnowledgeRagError(f"reranker 模型返回异常状态码 {status_code}。") return _extract_rerank_results(body, provider=self.reranker.provider) def _complete_sync( self, prompt: str, system_prompt: str | None, history_messages: list[dict[str, Any]], keyword_extraction: bool, kwargs: dict[str, Any], ) -> str: del keyword_extraction last_error: Exception | None = None for config in [self.primary_chat, self.backup_chat]: if config is None: continue try: return self._request_chat_completion( config, prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, max_tokens=int(kwargs.get("max_tokens") or 1200), temperature=float(kwargs.get("temperature") or 0.1), ) except Exception as exc: # pragma: no cover - runtime fallback last_error = exc logger.warning( "LightRAG LLM request failed slot=%s provider=%s model=%s: %s", config.slot, config.provider, config.model, exc, ) continue raise KnowledgeRagError(f"LightRAG 调用知识模型失败:{last_error or '没有可用模型配置'}") def _request_chat_completion( self, config: RuntimeModelConfig, *, prompt: str, system_prompt: str | None, history_messages: list[dict[str, Any]], max_tokens: int, temperature: float, ) -> str: messages: list[dict[str, Any]] = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) if config.provider == "Azure OpenAI": url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/chat/completions?api-version={AZURE_API_VERSION}" payload = { "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } status_code, body = _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) elif config.provider == "Ollama": url = _ensure_path(_normalize_endpoint(config.endpoint), "api/chat") payload = { "model": config.model, "messages": messages, "stream": False, "options": { "num_predict": max_tokens, "temperature": temperature, }, } status_code, body = _send_json_request( "POST", url, headers={"Content-Type": "application/json", "Accept": "application/json"}, payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) else: url = _ensure_path(_normalize_endpoint(config.endpoint), "chat/completions") payload = { "model": config.model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } status_code, body = _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) if status_code >= HTTPStatus.BAD_REQUEST: raise KnowledgeRagError(f"知识模型返回异常状态码 {status_code}。") return _extract_chat_text(body, provider=config.provider) def _request_embeddings(self, config: RuntimeModelConfig, texts: list[str]) -> list[list[float]]: if config.provider == "Azure OpenAI": url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}" payload = {"input": texts} status_code, body = _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True), payload=payload, timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, ) elif config.provider == "Ollama": url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed") payload = {"model": config.model, "input": texts} status_code, body = _send_json_request( "POST", url, headers={"Content-Type": "application/json", "Accept": "application/json"}, payload=payload, timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, ) else: url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings") payload = {"model": config.model, "input": texts} status_code, body = _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=True), payload=payload, timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS, ) if status_code >= HTTPStatus.BAD_REQUEST: raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。") return _extract_embedding_vectors(body, provider=config.provider) def _request_rerank( self, config: RuntimeModelConfig, *, query: str, documents: list[str], top_n: int | None, ) -> tuple[int, Any]: if config.provider == "Azure OpenAI": url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/rerank?api-version={AZURE_API_VERSION}" payload: dict[str, Any] = { "query": query, "documents": documents, } if top_n is not None: payload["top_n"] = top_n return _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=False, use_api_key=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) if config.provider == "Ali": url, payload = _build_ali_rerank_request( config.model, query=query, documents=documents, top_n=top_n, ) return _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) url = _ensure_path(_normalize_endpoint(config.endpoint), "rerank") payload = { "model": config.model, "query": query, "documents": documents, } if top_n is not None: payload["top_n"] = top_n return _send_json_request( "POST", url, headers=_build_headers(config.api_key, use_bearer=True), payload=payload, timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS, ) def _normalize_endpoint(endpoint: str) -> str: normalized = str(endpoint or "").strip() if not normalized: raise KnowledgeRagError("模型 endpoint 不能为空。") return normalized.rstrip("/") def _ensure_path(endpoint: str, suffix: str) -> str: suffix = suffix.lstrip("/") if endpoint.endswith(suffix): return endpoint return f"{endpoint}/{suffix}" def _build_azure_deployment_base(endpoint: str, model: str) -> str: normalized_endpoint = _normalize_endpoint(endpoint) quoted_model = quote(model, safe="") if "/openai/deployments/" in normalized_endpoint: return normalized_endpoint if "/openai/v1" in normalized_endpoint: resource_root = normalized_endpoint.split("/openai/v1", maxsplit=1)[0] return f"{resource_root}/openai/deployments/{quoted_model}" if normalized_endpoint.endswith("/openai"): return f"{normalized_endpoint}/deployments/{quoted_model}" return f"{normalized_endpoint}/openai/deployments/{quoted_model}" def _build_headers( api_key: str, *, use_bearer: bool, use_api_key: bool = False, ) -> dict[str, str]: headers = { "Content-Type": "application/json", "Accept": "application/json", } normalized_key = str(api_key or "").strip() if normalized_key: if use_api_key: headers["api-key"] = normalized_key elif use_bearer: headers["Authorization"] = f"Bearer {normalized_key}" return headers def _send_json_request( method: str, url: str, *, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: int, ) -> tuple[int, Any]: data = json.dumps(payload).encode("utf-8") request = Request(url=url, data=data, headers=headers, method=method) try: with urlopen(request, timeout=timeout_seconds) as response: # noqa: S310 body = response.read().decode("utf-8") if response.length != 0 else "" return response.status, _parse_json_body(body) except HTTPError as exc: # pragma: no cover - runtime path body = exc.read().decode("utf-8", errors="ignore") detail = _extract_error_message(_parse_json_body(body)) or f"接口返回 {exc.code}" raise KnowledgeRagError(detail) from exc except URLError as exc: # pragma: no cover - runtime path raise KnowledgeRagError(f"无法连接模型接口:{getattr(exc, 'reason', exc)}") from exc except TimeoutError as exc: # pragma: no cover - runtime path raise KnowledgeRagError("模型接口调用超时。") from exc def _parse_json_body(body: str) -> Any: if not body: return None try: return json.loads(body) except json.JSONDecodeError: return {"message": body} def _extract_error_message(payload: Any) -> str | None: if payload is None: return None if isinstance(payload, dict): if isinstance(payload.get("detail"), str): return payload["detail"] if isinstance(payload.get("message"), str): return payload["message"] error_payload = payload.get("error") if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str): return error_payload["message"] if isinstance(payload, str): return payload return None def _extract_chat_text(payload: Any, *, provider: str) -> str: if provider == "Ollama": message = payload.get("message") if isinstance(payload, dict) else None if isinstance(message, dict): return str(message.get("content") or "").strip() return "" if not isinstance(payload, dict): return "" choices = payload.get("choices") if not isinstance(choices, list) or not choices: return "" first_choice = choices[0] if not isinstance(first_choice, dict): return "" message = first_choice.get("message") if isinstance(message, dict): content = message.get("content") if isinstance(content, str): return content.strip() if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, dict) and item.get("type") == "text": parts.append(str(item.get("text") or "").strip()) return "\n".join(part for part in parts if part).strip() text = first_choice.get("text") if isinstance(text, str): return text.strip() return "" def _extract_embedding_vectors(payload: Any, *, provider: str) -> list[list[float]]: if provider == "Ollama": embeddings = payload.get("embeddings") if isinstance(payload, dict) else None if isinstance(embeddings, list): return [[float(value) for value in item] for item in embeddings if isinstance(item, list)] embedding = payload.get("embedding") if isinstance(payload, dict) else None if isinstance(embedding, list): return [[float(value) for value in embedding]] raise KnowledgeRagError("Ollama embedding 返回格式无法识别。") if not isinstance(payload, dict): raise KnowledgeRagError("embedding 接口返回格式无效。") data = payload.get("data") if not isinstance(data, list) or not data: raise KnowledgeRagError("embedding 接口没有返回 data。") vectors: list[list[float]] = [] for item in data: if not isinstance(item, dict): continue embedding = item.get("embedding") if isinstance(embedding, list): vectors.append([float(value) for value in embedding]) if not vectors: raise KnowledgeRagError("embedding 接口返回中未找到向量数据。") return vectors def _build_ali_rerank_request( model: str, *, query: str, documents: list[str], top_n: int | None, ) -> tuple[str, dict[str, Any]]: normalized_model = str(model or "").strip() if normalized_model == "qwen3-rerank": payload: dict[str, Any] = { "model": normalized_model, "query": query, "documents": documents, } if top_n is not None: payload["top_n"] = top_n return "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", payload payload = { "model": normalized_model, "input": { "query": query, "documents": documents, }, "parameters": { "return_documents": False, }, } if top_n is not None: payload["parameters"]["top_n"] = top_n return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank", payload def _extract_rerank_results(payload: Any, *, provider: str) -> list[dict[str, Any]]: if not isinstance(payload, dict): return [] if provider == "Ali" and isinstance(payload.get("output"), dict): results = payload["output"].get("results") else: results = payload.get("results") if not isinstance(results, list): return [] normalized: list[dict[str, Any]] = [] for item in results: if not isinstance(item, dict): continue try: normalized.append( { "index": int(item["index"]), "relevance_score": float(item["relevance_score"]), } ) except (KeyError, TypeError, ValueError): continue return normalized