# Knowledge-base retrieval and indexing service built on top of a LightRAG runtime.
#
# The module keeps one runtime per thread (see ``_runtime_instances``) and rebuilds
# it whenever the configuration signature changes.
from __future__ import annotations

import os
import re
import socket
import threading
from pathlib import Path
from typing import Any

from sqlalchemy.orm import Session

from app.core.config import get_settings
from app.core.logging import get_logger
from app.db.session import get_session_factory
from app.services.knowledge_ingest_log import (
    build_document_graph_summary,
    build_ingest_document_summary,
    build_ingest_status_summary,
)
from app.services.knowledge_rag_runtime import (
    KnowledgeRagError,
    RuntimeModelConfig,
    _LightRagRuntime,
)
from app.services.settings import SettingsService

logger = get_logger("app.services.knowledge_rag")

# Qdrant endpoints: local default vs. the hostname used when a "qdrant" host resolves.
DEFAULT_QDRANT_URL = "http://127.0.0.1:6333"
CONTAINER_QDRANT_URL = "http://qdrant:6333"
DEFAULT_LIGHTRAG_WORKSPACE = "x_financial_knowledge"

# Size limits applied to the hits returned by ``query_knowledge``.
MAX_KNOWLEDGE_HIT_CONTENT_LENGTH = 2200
MAX_KNOWLEDGE_HIT_EXCERPT_LENGTH = 220
MAX_QUERY_TERMS = 12

# Terms that are never kept as query terms.
QUERY_TERM_STOPWORDS = {
    "什么",
    "多少",
    "哪些",
    "怎么",
    "如何",
    "请问",
    "一下",
    "关于",
    "规定",
    "标准",
    "可以",
    "是否",
    "一个",
    "哪些人",
}

# A query containing any of these substrings is scored with a preference for
# tabular evidence.
TABLE_OR_STANDARD_QUERY_HINTS = (
    "标准",
    "金额",
    "限额",
    "补贴",
    "住宿",
    "餐费",
    "交通",
    "报销",
    "档位",
    "额度",
)

# Heading markers looked for near the start of a chunk when scoring it.
STRUCTURED_APPENDIX_LEADING_MARKERS = (
    "# 章节导航",
    "# 重点章节摘录",
    "# 问答线索补充",
    "# 结构化表格补充",
)
STRUCTURED_APPENDIX_LEADING_WINDOW = 220

# Runtime cache keyed by ``threading.get_ident()``; guarded by ``_runtime_lock``.
_runtime_lock = threading.RLock()
_runtime_instances: dict[int, _LightRagRuntime] = {}
_runtime_signatures: dict[int, tuple[Any, ...]] = {}


class KnowledgeRagService:
    """Query, index and delete knowledge documents through a LightRAG runtime."""

    def __init__(self, db: Session | None = None, storage_root: Path | None = None) -> None:
        """Store the optional DB session and resolve the storage root.

        Args:
            db: Optional session; when ``None`` a short-lived session is opened
                on demand to read model settings.
            storage_root: Overrides the storage root taken from settings.
        """
        self.db = db
        self.storage_root = Path(storage_root or get_settings().resolved_storage_root_dir)

    def query_knowledge(
        self,
        query: str,
        *,
        conversation_history: list[dict[str, str]] | None = None,
        limit: int = 5,
    ) -> dict[str, Any]:
        """Search the knowledge base and return a ``knowledge_search`` payload.

        Never raises for runtime failures: any exception raised while obtaining
        the runtime or querying it is logged and reported in ``message``.

        Args:
            query: User question; blank input yields an empty result.
            conversation_history: Passed through to the runtime unchanged.
            limit: Maximum number of hits (at least one is kept when available).

        Returns:
            A dict with ``result_type``, ``query``, ``record_count``, ``hits``,
            ``references`` and ``message``; ``raw_references`` is added once the
            runtime answered, and ``metadata`` only when hits were found.
        """
        normalized_query = str(query or "").strip()
        if not normalized_query:
            return {
                "result_type": "knowledge_search",
                "query": "",
                "record_count": 0,
                "hits": [],
                "references": [],
                "message": "请先输入要检索的知识库问题。",
            }

        try:
            runtime = self._get_runtime()
            raw = runtime.query_data(
                normalized_query, conversation_history=conversation_history
            )
        except Exception as exc:
            logger.warning("Knowledge query failed: %s", exc)
            return {
                "result_type": "knowledge_search",
                "query": normalized_query,
                "record_count": 0,
                "hits": [],
                "references": [],
                "message": f"知识库检索暂不可用:{exc}",
            }

        # Anything that is not a dict is treated as "no data".
        data = raw.get("data") if isinstance(raw, dict) else {}
        if not isinstance(data, dict):
            data = {}
        chunks = list(data.get("chunks") or [])
        entities = list(data.get("entities") or [])
        references = list(data.get("references") or [])

        hits = self._build_hits_from_query_data(
            query=normalized_query,
            chunks=chunks,
            entities=entities,
            limit=limit,
        )
        if not hits:
            return {
                "result_type": "knowledge_search",
                "query": normalized_query,
                "record_count": 0,
                "hits": [],
                "references": [],
                "raw_references": references,
                "message": "当前知识库中没有检索到与本次问题直接匹配的内容。",
            }

        hit_codes = [str(item.get("code") or "").strip() for item in hits]
        return {
            "result_type": "knowledge_search",
            "query": normalized_query,
            "record_count": len(hits),
            "hits": hits,
            "references": [code for code in hit_codes if code],
            "raw_references": references,
            "metadata": raw.get("metadata") if isinstance(raw, dict) else {},
            "message": f"已从知识库中检索到 {len(hits)} 条相关内容。",
        }

    def index_documents(
        self,
        *,
        document_ids: list[str],
        force: bool = False,
    ) -> dict[str, Any]:
        """Extract, optionally enrich, and insert documents into the runtime.

        Args:
            document_ids: Ids of the documents to index; blank ids are dropped.
            force: When true, a document that already has a status in the
                runtime is deleted first (delete failures are only logged).

        Returns:
            A dict with ``track_id``, ``requested_document_ids``,
            ``succeeded_document_ids``, ``failed_documents``,
            ``document_summaries`` and ``status_snapshot``.

        Raises:
            ValueError: If no usable document id remains after normalization.
        """
        normalized_ids = [str(item).strip() for item in document_ids if str(item).strip()]
        if not normalized_ids:
            raise ValueError("没有可供索引的知识文档。")

        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle - TODO confirm).
        from app.services.knowledge import KnowledgeService
        from app.services.knowledge_normalizer import KnowledgeNormalizationService

        knowledge_service = KnowledgeService(storage_root=self.storage_root, db=self.db)
        normalization_service = (
            KnowledgeNormalizationService(self.db) if self.db is not None else None
        )

        texts: list[str] = []
        file_paths: list[str] = []
        document_summaries: list[dict[str, Any]] = []

        runtime = self._get_runtime()
        existing_statuses = runtime.get_document_statuses(normalized_ids)

        for document_id in normalized_ids:
            entry = knowledge_service.get_document_entry(document_id)
            if force and document_id in existing_statuses:
                try:
                    runtime.delete_document(document_id)
                except Exception as exc:
                    logger.warning(
                        "Delete existing LightRAG document failed doc_id=%s: %s",
                        document_id,
                        exc,
                    )

            raw_text = knowledge_service.extract_document_text(document_id)
            text = raw_text
            if normalization_service is not None:
                text = normalization_service.build_enriched_text(raw_text)

            stored_path = (
                knowledge_service.library_root / entry["folder"] / entry["stored_name"]
            ).resolve()
            texts.append(text)
            file_paths.append(str(stored_path))
            document_summaries.append(
                build_ingest_document_summary(
                    document_id=document_id,
                    entry=entry,
                    raw_text=raw_text,
                    indexed_text=text,
                )
            )

        track_id = runtime.insert_documents(
            texts=texts,
            document_ids=normalized_ids,
            file_paths=file_paths,
        )
        statuses = runtime.get_document_statuses(normalized_ids)

        succeeded_document_ids: list[str] = []
        failed_documents: list[dict[str, str]] = []
        summary_by_id: dict[str, dict[str, Any]] = {}
        for summary in document_summaries:
            summary_id = str(summary.get("document_id") or "").strip()
            if summary_id:
                summary_by_id[summary_id] = summary

        # The workspace does not depend on the document, so resolve it once.
        workspace = self._resolve_workspace()

        for document_id in normalized_ids:
            status_obj = statuses.get(document_id)
            status_text = self._status_value(status_obj)
            status_payload = self._serialize_status(status_obj)
            graph_summary = build_document_graph_summary(
                self.storage_root,
                workspace=workspace,
                document_id=document_id,
            )
            if document_id in summary_by_id:
                summary_by_id[document_id].update(
                    build_ingest_status_summary(
                        status_payload=status_payload,
                        graph_summary=graph_summary,
                    )
                )
            if self.is_query_ready_status(status_obj):
                succeeded_document_ids.append(document_id)
                continue
            failed_documents.append(
                {
                    "document_id": document_id,
                    "status": status_text or "unknown",
                    "error": self._status_error(status_obj),
                }
            )

        return {
            "track_id": track_id,
            "requested_document_ids": normalized_ids,
            "succeeded_document_ids": succeeded_document_ids,
            "failed_documents": failed_documents,
            "document_summaries": [
                summary_by_id.get(document_id, {}) for document_id in normalized_ids
            ],
            "status_snapshot": {
                document_id: self._serialize_status(status_obj)
                for document_id, status_obj in statuses.items()
            },
        }

    def get_document_status_map(
        self, document_ids: list[str] | None = None
    ) -> dict[str, dict[str, Any]]:
        """Return serialized runtime statuses for the given document ids.

        Returns an empty dict when no usable id is given or when loading the
        statuses fails (the failure is logged).
        """
        target_ids = [str(item).strip() for item in document_ids or [] if str(item).strip()]
        if not target_ids:
            return {}
        try:
            statuses = self._get_runtime().get_document_statuses(target_ids)
        except Exception as exc:
            logger.warning("Load LightRAG document statuses failed: %s", exc)
            return {}
        return {
            document_id: self._serialize_status(status_obj)
            for document_id, status_obj in statuses.items()
        }

    def delete_document(self, document_id: str) -> None:
        """Delete a document from the runtime; failures are logged and ignored."""
        normalized_id = str(document_id or "").strip()
        if not normalized_id:
            return
        try:
            self._get_runtime().delete_document(normalized_id)
        except Exception as exc:
            logger.warning("Delete LightRAG document ignored doc_id=%s: %s", normalized_id, exc)

    def _get_runtime(self) -> _LightRagRuntime:
        """Return the calling thread's runtime, rebuilding it if the config changed.

        A cached runtime is reused only while its stored signature equals the
        current one; otherwise it is finalized (best effort) and replaced.
        """
        signature, runtime_kwargs = self._build_runtime_signature()
        thread_id = threading.get_ident()
        with _runtime_lock:
            runtime = _runtime_instances.get(thread_id)
            if runtime is not None and _runtime_signatures.get(thread_id) == signature:
                return runtime
            if runtime is not None:
                try:
                    runtime.finalize()
                except Exception as exc:  # pragma: no cover - best effort cleanup
                    logger.warning("Finalize previous LightRAG runtime failed: %s", exc)
            runtime = _LightRagRuntime(**runtime_kwargs)
            _runtime_instances[thread_id] = runtime
            _runtime_signatures[thread_id] = signature
            return runtime

    @staticmethod
    def _resolve_workspace() -> str:
        """Return ``LIGHTRAG_WORKSPACE`` from the environment, or the default if blank."""
        return (
            os.environ.get("LIGHTRAG_WORKSPACE", DEFAULT_LIGHTRAG_WORKSPACE).strip()
            or DEFAULT_LIGHTRAG_WORKSPACE
        )

    def _build_runtime_signature(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Build the cache signature and the keyword arguments for a new runtime.

        Returns:
            ``(signature, runtime_kwargs)`` where ``signature`` is a tuple of
            every value that must match for a cached runtime to be reused.
        """
        configs = self._load_runtime_configs()
        settings = get_settings()
        working_dir = (self.storage_root / "knowledge" / ".lightrag").resolve()
        workspace = self._resolve_workspace()
        qdrant_url = os.environ.get("QDRANT_URL", "").strip() or _resolve_default_qdrant_url()
        qdrant_api_key = os.environ.get("QDRANT_API_KEY", "").strip()

        main = configs["main"]
        backup = configs["backup"]
        embedding = configs["embedding"]
        reranker = configs["reranker"]

        signature = (
            str(working_dir),
            workspace,
            qdrant_url,
            qdrant_api_key,
            main.provider,
            main.model,
            main.endpoint,
            main.api_key,
            backup.provider if backup else "",
            backup.model if backup else "",
            backup.endpoint if backup else "",
            backup.api_key if backup else "",
            embedding.provider,
            embedding.model,
            embedding.endpoint,
            embedding.api_key,
            reranker.provider if reranker else "",
            reranker.model if reranker else "",
            reranker.endpoint if reranker else "",
            reranker.api_key if reranker else "",
            str(settings.resolved_storage_root_dir),
        )
        return signature, {
            "working_dir": working_dir,
            "workspace": workspace,
            "qdrant_url": qdrant_url,
            "qdrant_api_key": qdrant_api_key,
            "primary_chat": main,
            "backup_chat": backup,
            "embedding": embedding,
            "reranker": reranker,
        }

    def _load_runtime_configs(self) -> dict[str, RuntimeModelConfig | None]:
        """Load the model configs for the ``main``/``backup``/``embedding``/``reranker`` slots.

        ``backup`` and ``reranker`` are optional: a load failure or an
        incomplete config turns them into ``None``.

        Raises:
            KnowledgeRagError: If the main or embedding model lacks an endpoint
                or model name, or lacks an API key while the provider is not
                ``"Ollama"``.
        """
        owned_session = False
        session = self.db
        if session is None:
            # No injected session: open one here and close it in ``finally``.
            session = get_session_factory()()
            owned_session = True
        try:
            settings_service = SettingsService(session)
            main = self._normalize_runtime_model(settings_service.get_runtime_model_config("main"))
            embedding = self._normalize_runtime_model(
                settings_service.get_runtime_model_config("embedding")
            )

            try:
                backup_raw = settings_service.get_runtime_model_config("backup")
                backup = self._normalize_runtime_model(backup_raw)
            except Exception as exc:
                # Optional slot: keep going without it, but leave a trace.
                logger.debug("Backup model config unavailable, ignoring: %s", exc)
                backup = None

            try:
                reranker_raw = settings_service.get_runtime_model_config("reranker")
                reranker = self._normalize_runtime_model(reranker_raw)
            except Exception as exc:
                logger.debug("Reranker model config unavailable, ignoring: %s", exc)
                reranker = None

            # Drop optional slots that are present but not usable.
            if backup is not None and (
                not backup.endpoint
                or not backup.model
                or (backup.provider != "Ollama" and not backup.api_key)
            ):
                backup = None
            if reranker is not None and (
                not reranker.endpoint
                or not reranker.model
                or (reranker.provider != "Ollama" and not reranker.api_key)
            ):
                reranker = None

            if not main.endpoint or not main.model:
                raise KnowledgeRagError("主对话模型未配置,无法初始化 LightRAG。")
            if main.provider != "Ollama" and not main.api_key:
                raise KnowledgeRagError("主对话模型缺少 API Key,无法初始化 LightRAG。")
            if not embedding.endpoint or not embedding.model:
                raise KnowledgeRagError("Embedding 模型未配置,无法初始化 LightRAG。")
            if embedding.provider != "Ollama" and not embedding.api_key:
                raise KnowledgeRagError("Embedding 模型缺少 API Key,无法初始化 LightRAG。")

            return {
                "main": main,
                "backup": backup,
                "embedding": embedding,
                "reranker": reranker,
            }
        finally:
            if owned_session and session is not None:
                session.close()

    @staticmethod
    def _normalize_runtime_model(payload: dict[str, str]) -> RuntimeModelConfig:
        """Convert a settings payload into a ``RuntimeModelConfig`` of stripped strings.

        Note that the API key is read from the ``apiKey`` key of the payload.
        """

        def field(key: str) -> str:
            return str(payload.get(key) or "").strip()

        return RuntimeModelConfig(
            slot=field("slot"),
            provider=field("provider"),
            model=field("model"),
            endpoint=field("endpoint"),
            api_key=field("apiKey"),
            capability=field("capability"),
        )

    @staticmethod
    def _build_hits_from_query_data(
        *,
        query: str,
        chunks: list[dict[str, Any]],
        entities: list[dict[str, Any]],
        limit: int,
    ) -> list[dict[str, Any]]:
        """Turn raw chunks and entities into ranked hit dicts.

        Chunks without a file path or content are skipped. Entity names are
        attached as tags to chunks sharing the same ``file_path`` (first five).
        Candidates are ordered by ``_score_knowledge_hit`` (descending), ties
        broken by the original chunk position, and cut to ``max(1, limit)``.
        """
        # file_path -> distinct entity names, in first-seen order.
        entity_tags_by_path: dict[str, list[str]] = {}
        for entity in entities:
            if not isinstance(entity, dict):
                continue
            file_path = str(entity.get("file_path") or "").strip()
            entity_name = str(entity.get("entity_name") or "").strip()
            if not file_path or not entity_name:
                continue
            tags = entity_tags_by_path.setdefault(file_path, [])
            if entity_name not in tags:
                tags.append(entity_name)

        query_terms = _extract_query_terms(query)
        prefers_tabular_evidence = any(hint in query for hint in TABLE_OR_STANDARD_QUERY_HINTS)

        candidates: list[dict[str, Any]] = []
        for rank, chunk in enumerate(chunks, start=1):
            if not isinstance(chunk, dict):
                continue
            file_path = str(chunk.get("file_path") or "").strip()
            chunk_id = str(chunk.get("chunk_id") or "").strip()
            content = str(chunk.get("content") or "").strip()
            if not file_path or not content:
                continue

            document_id, document_name = _parse_document_identity(file_path)
            normalized_chunk_id = chunk_id or f"path-{rank}"
            normalized_content = _truncate_text(
                content, max_length=MAX_KNOWLEDGE_HIT_CONTENT_LENGTH
            )
            excerpt = _build_query_focused_excerpt(
                normalized_content,
                query_terms=query_terms,
                max_length=MAX_KNOWLEDGE_HIT_EXCERPT_LENGTH,
            )
            candidates.append(
                {
                    "code": f"knowledge.{document_id or 'unknown'}.{normalized_chunk_id}",
                    "candidate_id": normalized_chunk_id,
                    "title": document_name or "知识库文档",
                    "content": normalized_content,
                    "excerpt": excerpt,
                    "document_id": document_id,
                    "document_name": document_name or Path(file_path).name,
                    "version": None,
                    "updated_at": None,
                    # Position-based score exposed to callers; the ordering
                    # below uses ``_score_knowledge_hit`` instead.
                    "score": max(1, 100 - rank),
                    "tags": entity_tags_by_path.get(file_path, [])[:5],
                    "evidence": [normalized_chunk_id],
                    "file_path": file_path,
                    "_rank": rank,
                }
            )

        def sort_key(item: dict[str, Any]) -> tuple[int, int]:
            relevance = _score_knowledge_hit(
                item,
                query_terms=query_terms,
                prefers_tabular_evidence=prefers_tabular_evidence,
            )
            # Negated rank so that, with reverse=True, earlier chunks win ties.
            return relevance, -int(item.get("_rank") or 0)

        ranked = sorted(candidates, key=sort_key, reverse=True)

        hits: list[dict[str, Any]] = []
        for item in ranked[: max(1, limit)]:
            normalized = dict(item)
            normalized.pop("_rank", None)  # internal helper field, not part of the payload
            hits.append(normalized)
        return hits

    @staticmethod
    def _serialize_status(status_obj: Any) -> dict[str, Any]:
        """Return a dict copy of a status object plus normalized status fields.

        ``status``, ``error_msg`` and ``query_ready`` are always (re)computed.
        ``None`` yields an empty dict.
        """
        if status_obj is None:
            return {}
        if hasattr(status_obj, "__dict__"):
            payload = dict(status_obj.__dict__)
        elif isinstance(status_obj, dict):
            payload = dict(status_obj)
        else:
            payload = {}
        payload["status"] = KnowledgeRagService._status_value(status_obj)
        payload["error_msg"] = KnowledgeRagService._status_error(status_obj)
        payload["query_ready"] = KnowledgeRagService.is_query_ready_status(status_obj)
        return payload

    @staticmethod
    def _status_value(status_obj: Any) -> str:
        """Return the lower-cased status name of an object or dict (``""`` if absent).

        Handles string forms such as ``Name.MEMBER`` (keeps the part after the
        last dot) and ``<... member: 'value'>`` style reprs (keeps the part
        before the colon).
        """
        raw_status = getattr(status_obj, "status", None)
        if raw_status is None and isinstance(status_obj, dict):
            raw_status = status_obj.get("status")
        normalized = str(raw_status or "").strip().lower()
        if "." in normalized:
            normalized = normalized.split(".")[-1].strip()
        if ":" in normalized and normalized.endswith(">"):
            normalized = normalized.split(":")[0].strip("<> '\"")
        return normalized

    @staticmethod
    def _status_error(status_obj: Any) -> str:
        """Return the stripped ``error_msg`` of an object or dict (``""`` if absent)."""
        value = getattr(status_obj, "error_msg", None)
        if value is None and isinstance(status_obj, dict):
            value = status_obj.get("error_msg")
        return str(value or "").strip()

    @staticmethod
    def is_query_ready_status(status_obj: Any) -> bool:
        """Tell whether a document status means the document can be queried.

        ``processed`` is ready; failed/error/aborted and pending/processing/
        preprocessed are not. For any other status the document counts as ready
        when it reports a positive ``chunks_count`` or a non-empty ``chunks_list``.
        """
        status_text = KnowledgeRagService._status_value(status_obj)
        if status_text in {"failed", "error", "aborted"}:
            return False
        if status_text == "processed":
            return True
        if status_text in {"pending", "processing", "preprocessed"}:
            return False

        chunks_count = getattr(status_obj, "chunks_count", None)
        if chunks_count is None and isinstance(status_obj, dict):
            chunks_count = status_obj.get("chunks_count")
        try:
            if int(chunks_count or 0) > 0:
                return True
        except (TypeError, ValueError):
            # Non-numeric count: fall through to the chunk list check.
            pass

        chunks_list = getattr(status_obj, "chunks_list", None)
        if chunks_list is None and isinstance(status_obj, dict):
            chunks_list = status_obj.get("chunks_list")
        return bool(chunks_list)


def shutdown_knowledge_rag_runtime() -> None:
    """Finalize every cached runtime (best effort) and clear the caches."""
    with _runtime_lock:
        for runtime in list(_runtime_instances.values()):
            try:
                runtime.finalize()
            except Exception as exc:  # pragma: no cover - best effort cleanup
                logger.warning("Finalize LightRAG runtime failed during shutdown: %s", exc)
        _runtime_instances.clear()
        _runtime_signatures.clear()


def _parse_document_identity(file_path: str) -> tuple[str, str]:
    """Split a ``<document_id>__<document_name>`` file name into its two parts.

    Returns ``("", name)`` when the file name contains no ``__`` separator.
    """
    path = Path(str(file_path or "").strip())
    name = path.name
    if "__" not in name:
        return "", name
    document_id, document_name = name.split("__", maxsplit=1)
    return document_id.strip(), document_name.strip()


def _ellipsize(normalized: str, max_length: int) -> str:
    """Cut ``normalized`` to ``max_length`` characters, ending in ``...`` when cut.

    For ``max_length < 3`` there is no room for the ellipsis, so the text is
    simply cut; a plain ``[: max_length - 3]`` slice would go negative there
    and keep almost the whole text.
    """
    if len(normalized) <= max_length:
        return normalized
    if max_length < 3:
        return normalized[: max(0, max_length)]
    return f"{normalized[: max_length - 3].rstrip()}..."


def _build_excerpt(text: str, *, max_length: int = 180) -> str:
    """Collapse all whitespace runs to single spaces and cut to ``max_length``."""
    normalized = " ".join(str(text or "").split()).strip()
    return _ellipsize(normalized, max_length)


def _build_query_focused_excerpt(
    text: str,
    *,
    query_terms: list[str],
    max_length: int = 180,
) -> str:
    """Build an excerpt that starts shortly before the earliest query-term match.

    Whitespace is collapsed first. Without any match this falls back to
    ``_build_excerpt``. With a match, the window starts ``max_length // 3``
    characters before it and spans ``max_length`` characters; ``...`` is added
    on each side that was cut, so the result can exceed ``max_length`` by the
    length of those ellipses.
    """
    normalized = " ".join(str(text or "").split()).strip()
    if not normalized:
        return ""

    lowered = normalized.lower()
    # First occurrence of every non-empty term that appears at all.
    match_positions = [
        position
        for position in (lowered.find(term) for term in query_terms if term)
        if position >= 0
    ]
    if not match_positions:
        return _build_excerpt(normalized, max_length=max_length)

    start = max(0, min(match_positions) - max_length // 3)
    end = min(len(normalized), start + max_length)
    snippet = normalized[start:end].strip()
    if start > 0:
        snippet = f"...{snippet.lstrip()}"
    if end < len(normalized):
        snippet = f"{snippet.rstrip()}..."
    return snippet


def _truncate_text(text: str, *, max_length: int) -> str:
    """Strip ``text`` and cut it to ``max_length`` characters (inner whitespace kept)."""
    normalized = str(text or "").strip()
    return _ellipsize(normalized, max_length)
def _resolve_default_qdrant_url() -> str: if _hostname_resolves("qdrant"): return CONTAINER_QDRANT_URL return DEFAULT_QDRANT_URL def _hostname_resolves(hostname: str) -> bool: try: socket.getaddrinfo(hostname, None) except OSError: return False return True def _extract_query_terms(query: str) -> list[str]: normalized_query = str(query or "").strip().lower() if not normalized_query: return [] terms: list[str] = [] seen: set[str] = set() def remember(term: str) -> None: normalized_term = str(term or "").strip().lower() if ( not normalized_term or normalized_term in seen or normalized_term in QUERY_TERM_STOPWORDS or len(normalized_term) < 2 ): return seen.add(normalized_term) terms.append(normalized_term) for item in re.findall(r"[a-z0-9][a-z0-9_\-]{1,}", normalized_query): remember(item) for block in re.findall(r"[\u4e00-\u9fff]{2,20}", normalized_query): if len(block) <= 4: remember(block) continue for size in (4, 3, 2): for start in range(0, len(block) - size + 1): remember(block[start : start + size]) if len(terms) >= MAX_QUERY_TERMS: return terms return terms[:MAX_QUERY_TERMS] def _score_knowledge_hit( item: dict[str, Any], *, query_terms: list[str], prefers_tabular_evidence: bool, ) -> int: rank = max(1, int(item.get("_rank") or 1)) title = str(item.get("title") or item.get("document_name") or "").lower() content = str(item.get("content") or "").lower() excerpt = str(item.get("excerpt") or "").lower() tags = " ".join(str(value).lower() for value in list(item.get("tags") or [])[:5]) haystack = "\n".join([title, excerpt, tags, content[:1200]]) score = max(1, 120 - rank * 4) matched_terms = [term for term in query_terms if term in haystack] score += len(matched_terms) * 8 score += sum(1 for term in matched_terms if term in title) * 6 leading_appendix_marker = _leading_structured_appendix_marker(content) if leading_appendix_marker == "# 章节导航": score -= 24 elif leading_appendix_marker == "# 重点章节摘录": score += 4 if matched_terms else -12 elif leading_appendix_marker == 
"# 问答线索补充": score += ( 8 if matched_terms and not prefers_tabular_evidence else 2 if matched_terms else -20 ) elif leading_appendix_marker == "# 结构化表格补充": if prefers_tabular_evidence and matched_terms: score += 16 elif matched_terms: score += 6 else: score -= 18 if prefers_tabular_evidence and matched_terms and ("|" in content or "表" in content): score += 10 if matched_terms and any(marker in content for marker in (":", ":")): score += 10 if matched_terms and "\n" in content: score += 4 if matched_terms and any(marker in content for marker in ("附表", "第", "条")): score += 4 if ( not prefers_tabular_evidence and matched_terms and any(marker in content for marker in ("第", "条", ":", "-", "•")) ): score += 4 if title and any(term in title for term in query_terms): score += 6 if re.search(r"没有.{0,8}(信息|规定|说明|依据)", content): score -= 12 return score def _leading_structured_appendix_marker(content: str) -> str: normalized = str(content or "").lstrip() for marker in STRUCTURED_APPENDIX_LEADING_MARKERS: index = normalized.find(marker) if 0 <= index <= STRUCTURED_APPENDIX_LEADING_WINDOW: return marker return ""