""" OpenSearch Storage Implementation for LightRAG This module provides OpenSearch-based storage backends for LightRAG, including KV storage, document status storage, graph storage, and vector storage. Requirements: - opensearch-py >= 3.0.0 - OpenSearch 3.x or higher with k-NN plugin enabled """ import os import re import ssl as ssl_module import time import asyncio from dataclasses import dataclass, field from typing import Any, AsyncIterator, Union, final import numpy as np import configparser from ..base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, DocProcessingStatus, DocStatus, DocStatusStorage, ) from ..utils import logger, compute_mdhash_id, _cooperative_yield from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP from ..kg.shared_storage import get_data_init_lock import pipmaster as pm if not pm.is_installed("opensearch-py"): pm.install("opensearch-py") from opensearchpy import AsyncOpenSearch, helpers # type: ignore from opensearchpy.exceptions import OpenSearchException, NotFoundError, RequestError # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") def _get_opensearch_env(key, fallback): cfg_key = key.replace("OPENSEARCH_", "").lower() return os.environ.get(key, config.get("opensearch", cfg_key, fallback=fallback)) def _get_index_number_of_shards() -> int: return int(_get_opensearch_env("OPENSEARCH_NUMBER_OF_SHARDS", "1")) def _get_index_number_of_replicas() -> int: return int(_get_opensearch_env("OPENSEARCH_NUMBER_OF_REPLICAS", "0")) def _sanitize_index_name(name: str) -> str: """Sanitize a string to be a valid OpenSearch index name.""" sanitized = re.sub(r"[^a-z0-9_-]", "_", name.lower()) if sanitized and sanitized[0] in "-_+": sanitized = "x" + sanitized return sanitized # Detected at first connection; True when OpenSearch >= 3.3.0. _shard_doc_supported: bool | None = None def _pit_sort_with_field(field: str) -> list[dict]: """Return PIT sort clause with a unique field as primary sort. Used purely as a pagination tiebreaker — order is fixed to asc since the business sort (when present) is applied separately by the caller. >= 3.3.0: _shard_doc only (most efficient, already unique within PIT). < 3.3.0: field + _doc (field is unique, _doc for efficiency). """ if _shard_doc_supported: return [{"_shard_doc": "asc"}] return [{field: {"order": "asc"}}, {"_doc": "asc"}] def _pit_sort_with_composite_key(*fields: str) -> list[dict]: """Return PIT sort clause with multiple fields forming a composite unique key. >= 3.3.0: _shard_doc (most efficient, ignores the fields). < 3.3.0: field1 + field2 + ... + _doc (composite is unique, _doc for efficiency). """ if _shard_doc_supported: return [{"_shard_doc": "asc"}] return [{f: {"order": "asc"}} for f in fields] + [{"_doc": "asc"}] async def _detect_shard_doc_support(client: AsyncOpenSearch) -> bool: """Check if the cluster supports _shard_doc (OpenSearch >= 3.3.0).""" try: info = await client.info() version_str = info.get("version", {}).get("number", "0.0.0") # Strip pre-release suffixes (e.g. "3.3.0-SNAPSHOT" → "3", "3", "0") parts = [p.split("-")[0] for p in version_str.split(".")] major = int(parts[0]) if parts[0].isdigit() else 0 minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0 supported = (major > 3) or (major == 3 and minor >= 3) logger.info( f"OpenSearch version {version_str}: " f"_shard_doc {'supported' if supported else 'not supported, using field+_doc fallback'}" ) return supported except Exception as e: logger.warning( f"Failed to detect OpenSearch version, assuming _shard_doc not supported: {e}" ) return False class ClientManager: """Singleton manager for OpenSearch client connections.""" _instances = {"client": None, "ref_count": 0} _lock = asyncio.Lock() @classmethod async def get_client(cls) -> AsyncOpenSearch: """Get or create a shared AsyncOpenSearch client with reference counting.""" global _shard_doc_supported async with cls._lock: if cls._instances["client"] is None: hosts_str = _get_opensearch_env("OPENSEARCH_HOSTS", "localhost:9200") hosts = [h.strip() for h in hosts_str.split(",") if h.strip()] username = _get_opensearch_env("OPENSEARCH_USER", "admin") password = _get_opensearch_env("OPENSEARCH_PASSWORD", "admin") use_ssl = _get_opensearch_env("OPENSEARCH_USE_SSL", "true").lower() in ( "true", "1", "yes", ) verify_certs = _get_opensearch_env( "OPENSEARCH_VERIFY_CERTS", "false" ).lower() in ("true", "1", "yes") timeout = int(_get_opensearch_env("OPENSEARCH_TIMEOUT", "30")) max_retries = int(_get_opensearch_env("OPENSEARCH_MAX_RETRIES", "3")) ssl_context = None if use_ssl and not verify_certs: ssl_context = ssl_module.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl_module.CERT_NONE client = AsyncOpenSearch( hosts=hosts, http_auth=(username, password) if username else None, use_ssl=use_ssl, verify_certs=verify_certs, ssl_context=ssl_context, ssl_show_warn=False, timeout=timeout, max_retries=max_retries, retry_on_timeout=True, ) cls._instances["client"] = client cls._instances["ref_count"] = 0 _shard_doc_supported = await _detect_shard_doc_support(client) logger.info(f"OpenSearch client connected to {hosts}") cls._instances["ref_count"] += 1 return cls._instances["client"] @classmethod async def release_client(cls, client: AsyncOpenSearch): """Release a client reference. Closes the connection when ref count reaches 0.""" global _shard_doc_supported async with cls._lock: if client is not None and client is cls._instances["client"]: cls._instances["ref_count"] -= 1 if cls._instances["ref_count"] <= 0: try: await cls._instances["client"].close() except Exception: pass cls._instances["client"] = None cls._instances["ref_count"] = 0 _shard_doc_supported = None logger.info("OpenSearch client connection closed") def _resolve_workspace(workspace: str, namespace: str): """Resolve effective workspace from env or parameter.""" opensearch_workspace = os.environ.get("OPENSEARCH_WORKSPACE") if opensearch_workspace and opensearch_workspace.strip(): effective = opensearch_workspace.strip() logger.info( f"Using OPENSEARCH_WORKSPACE: '{effective}' (overriding '{workspace}/{namespace}')" ) return effective return workspace def _build_index_name(workspace: str, namespace: str) -> tuple[str, str, str]: """Build index name and return (effective_workspace, final_namespace, index_name).""" effective = _resolve_workspace(workspace, namespace) if effective: final_ns = f"{effective}_{namespace}" else: final_ns = namespace effective = "" index_name = _sanitize_index_name(final_ns) return effective, final_ns, index_name async def _mget_optional_doc( client: AsyncOpenSearch, index_name: str, doc_id: str ) -> dict[str, Any] | None: """Fetch a single document via mget and return None when it is absent.""" response = await client.mget(index=index_name, body={"ids": [doc_id]}) docs = response.get("docs", []) if not docs: return None doc = docs[0] if not doc.get("found"): return None return doc def _is_missing_index_error(exc: Exception) -> bool: """Return True when an OpenSearch exception means the target index is missing.""" return "index_not_found_exception" in str(exc) async def _verify_mirrored_id_mapping(client: AsyncOpenSearch, index_name: str) -> None: """Fail-fast when an existing index lacks the __mirrored_id keyword mapping. Only enforced on OpenSearch < 3.3.0, where __mirrored_id serves as the cross-shard pagination tiebreaker. Indices created by older LightRAG releases will be missing this mapping; sorting by a missing field on a multi-shard index can drop or duplicate documents during PIT pagination. """ if _shard_doc_supported: return try: mapping = await client.indices.get_mapping(index=index_name) except OpenSearchException: return props = mapping.get(index_name, {}).get("mappings", {}).get("properties", {}) if "__mirrored_id" not in props: raise RuntimeError( f"Index '{index_name}' lacks the '__mirrored_id' keyword mapping " f"required for stable PIT pagination on OpenSearch < 3.3.0. " f"This index was likely created by an older LightRAG release. " f"Please reindex the data, or upgrade the cluster to OpenSearch >= 3.3.0." ) @final @dataclass class OpenSearchKVStorage(BaseKVStorage): """Key-Value storage using OpenSearch. Uses dynamic mapping to support varied schemas.""" client: AsyncOpenSearch = field(default=None) _index_name: str = field(default="", init=False) _index_ready: bool = field(default=False, init=False) def __init__(self, namespace, global_config, embedding_func, workspace=None): super().__init__( namespace=namespace, workspace=workspace or "", global_config=global_config, embedding_func=embedding_func, ) self.__post_init__() def __post_init__(self): self.workspace, self.final_namespace, self._index_name = _build_index_name( self.workspace, self.namespace ) async def initialize(self): """Initialize client connection and create index if needed.""" async with get_data_init_lock(): if self.client is None: self.client = await ClientManager.get_client() await self._create_index_if_not_exists() self._index_ready = True logger.debug( f"[{self.workspace}] OpenSearch KV storage initialized: {self._index_name}" ) async def _ensure_index_ready(self): """Recreate the KV index after drop before the next write.""" if self._index_ready: return async with get_data_init_lock(): if self.client is None: self.client = await ClientManager.get_client() if not self._index_ready: await self._create_index_if_not_exists() self._index_ready = True def _mark_index_missing(self): """Mark the KV index as unavailable for subsequent read short-circuiting.""" self._index_ready = False async def _create_index_if_not_exists(self): try: if not await self.client.indices.exists(index=self._index_name): # Use dynamic mapping so any namespace schema works body = { "mappings": { "dynamic": True, "properties": { "__mirrored_id": {"type": "keyword"}, }, }, "settings": { "index": { "number_of_shards": _get_index_number_of_shards(), "number_of_replicas": _get_index_number_of_replicas(), }, }, } await self.client.indices.create(index=self._index_name, body=body) logger.info(f"[{self.workspace}] Created index: {self._index_name}") else: await _verify_mirrored_id_mapping(self.client, self._index_name) except RequestError as e: if "resource_already_exists_exception" not in str(e): raise except OpenSearchException as e: logger.error(f"[{self.workspace}] Error creating index: {e}") raise async def finalize(self): """Release the OpenSearch client connection.""" if self.client is not None: await ClientManager.release_client(self.client) self.client = None async def _iter_raw_docs( self, batch_size: int = 1000 ) -> AsyncIterator[list[dict[str, Any]]]: """Yield raw OpenSearch hits using PIT + search_after pagination.""" if not self._index_ready: return try: pit = await self.client.create_pit( index=self._index_name, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while True: body = { "query": {"match_all": {}}, "size": batch_size, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_field("__mirrored_id"), } if search_after: body["search_after"] = search_after response = await self.client.search(body=body) hits = response["hits"]["hits"] if not hits: break yield hits search_after = hits[-1]["sort"] if len(hits) < batch_size: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return logger.error(f"[{self.workspace}] Error scanning documents: {e}") raise async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get a document by its ID, or None if not found.""" if not self._index_ready: return None try: response = await _mget_optional_doc(self.client, self._index_name, id) if response is None: return None doc = response["_source"] doc.pop("__mirrored_id", None) doc["_id"] = response["_id"] doc.setdefault("create_time", 0) doc.setdefault("update_time", 0) return doc except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return None logger.error(f"[{self.workspace}] Error getting document {id}: {e}") return None async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get multiple documents by IDs, preserving input order.""" if not self._index_ready: return [None] * len(ids) try: response = await self.client.mget(index=self._index_name, body={"ids": ids}) doc_map = {} for doc in response["docs"]: if doc.get("found"): data = doc["_source"] data.pop("__mirrored_id", None) data["_id"] = doc["_id"] data.setdefault("create_time", 0) data.setdefault("update_time", 0) doc_map[doc["_id"]] = data return [doc_map.get(id) for id in ids] except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return [None] * len(ids) logger.error(f"[{self.workspace}] Error getting documents: {e}") return [None] * len(ids) async def filter_keys(self, keys: set[str]) -> set[str]: """Return the subset of keys that do not exist in storage.""" if not self._index_ready: return keys try: response = await self.client.mget( index=self._index_name, body={"ids": list(keys)}, _source=False ) existing_ids = {doc["_id"] for doc in response["docs"] if doc.get("found")} return keys - existing_ids except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return keys logger.error(f"[{self.workspace}] Error filtering keys: {e}") return keys async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Insert or update documents with automatic timestamping.""" if not data: return await self._ensure_index_ready() logger.debug( f"[{self.workspace}] Upserting {len(data)} documents to {self.namespace}" ) current_time = int(time.time()) actions = [] for i, (doc_id, doc_data) in enumerate(data.items(), start=1): doc_data["update_time"] = current_time doc_data.setdefault("create_time", current_time) source = {k: v for k, v in doc_data.items() if k != "_id"} source["__mirrored_id"] = doc_id actions.append( { "_op_type": "index", "_index": self._index_name, "_id": doc_id, "_source": source, } ) await _cooperative_yield(i) try: # No per-operation refresh: immediate reads use ID-based mget (translog), # search visibility is guaranteed after index_done_callback() batch refresh. success, failed = await helpers.async_bulk( self.client, actions, raise_on_error=False ) if failed: logger.warning( f"[{self.workspace}] {len(failed)} documents failed to upsert" ) except OpenSearchException as e: logger.error(f"[{self.workspace}] Error upserting documents: {e}") raise async def index_done_callback(self) -> None: """Refresh index to make recently indexed documents searchable.""" if not self._index_ready: return try: await self.client.indices.refresh(index=self._index_name) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return except Exception: pass async def is_empty(self) -> bool: """Return True if the index contains no documents.""" if not self._index_ready: return True try: response = await self.client.count(index=self._index_name) return response["count"] == 0 except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return True async def delete(self, ids: list[str]) -> None: """Delete documents by their IDs.""" if not ids: return if not self._index_ready: return if isinstance(ids, set): ids = list(ids) try: # No per-operation refresh: immediate reads use ID-based mget (translog), # search visibility is guaranteed after index_done_callback() batch refresh. actions = [ {"_op_type": "delete", "_index": self._index_name, "_id": doc_id} for doc_id in ids ] success, _ = await helpers.async_bulk( self.client, actions, raise_on_error=False ) logger.info( f"[{self.workspace}] Deleted {success} documents from {self.namespace}" ) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return logger.error(f"[{self.workspace}] Error deleting documents: {e}") async def drop(self) -> dict[str, str]: """Delete the entire index.""" try: try: await self.client.indices.delete(index=self._index_name) logger.info(f"[{self.workspace}] Dropped index: {self._index_name}") except NotFoundError: logger.info( f"[{self.workspace}] Index already missing during drop: {self._index_name}" ) self._mark_index_missing() return {"status": "success", "message": f"Index {self._index_name} dropped"} except OpenSearchException as e: self._mark_index_missing() logger.error(f"[{self.workspace}] Error dropping index: {e}") return {"status": "error", "message": str(e)} except Exception as e: self._mark_index_missing() logger.error(f"[{self.workspace}] Unexpected error dropping index: {e}") return {"status": "error", "message": str(e)} @final @dataclass class OpenSearchDocStatusStorage(DocStatusStorage): """Document status storage using OpenSearch.""" client: AsyncOpenSearch = field(default=None) _index_name: str = field(default="", init=False) _index_ready: bool = field(default=False, init=False) def __init__(self, namespace, global_config, embedding_func, workspace=None): super().__init__( namespace=namespace, workspace=workspace or "", global_config=global_config, embedding_func=embedding_func, ) self.__post_init__() def __post_init__(self): self.workspace, self.final_namespace, self._index_name = _build_index_name( self.workspace, self.namespace ) def _prepare_doc_status_data(self, doc: dict[str, Any]) -> dict[str, Any]: """Normalize a raw OpenSearch document to DocProcessingStatus-compatible dict.""" data = doc.copy() data.pop("_id", None) data.pop("__mirrored_id", None) if "file_path" not in data: data["file_path"] = "no-file-path" data.setdefault("metadata", {}) data.setdefault("error_msg", None) if "error" in data: if not data.get("error_msg"): data["error_msg"] = data.pop("error") else: data.pop("error", None) return data async def initialize(self): """Initialize client connection and create doc status index.""" async with get_data_init_lock(): if self.client is None: self.client = await ClientManager.get_client() await self._create_index_if_not_exists() self._index_ready = True logger.debug( f"[{self.workspace}] OpenSearch DocStatus storage initialized: {self._index_name}" ) async def _ensure_index_ready(self): """Recreate the doc status index after drop before the next write.""" if self._index_ready: return async with get_data_init_lock(): if self.client is None: self.client = await ClientManager.get_client() if not self._index_ready: await self._create_index_if_not_exists() self._index_ready = True def _mark_index_missing(self): """Mark the doc status index as unavailable for subsequent read short-circuiting.""" self._index_ready = False async def _create_index_if_not_exists(self): try: if not await self.client.indices.exists(index=self._index_name): body = { "mappings": { "dynamic": True, "properties": { "__mirrored_id": {"type": "keyword"}, "status": {"type": "keyword"}, "file_path": {"type": "keyword"}, "track_id": {"type": "keyword"}, "created_at": {"type": "date"}, "updated_at": {"type": "date"}, }, }, "settings": { "index": { "number_of_shards": _get_index_number_of_shards(), "number_of_replicas": _get_index_number_of_replicas(), }, }, } await self.client.indices.create(index=self._index_name, body=body) logger.info( f"[{self.workspace}] Created doc status index: {self._index_name}" ) else: await _verify_mirrored_id_mapping(self.client, self._index_name) except RequestError as e: if "resource_already_exists_exception" not in str(e): raise except OpenSearchException as e: logger.error(f"[{self.workspace}] Error creating doc status index: {e}") raise async def finalize(self): """Release the OpenSearch client connection.""" if self.client is not None: await ClientManager.release_client(self.client) self.client = None async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: """Get a document status record by ID.""" if not self._index_ready: return None try: response = await _mget_optional_doc(self.client, self._index_name, id) if response is None: return None doc = response["_source"] doc["_id"] = response["_id"] return doc except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return None logger.error(f"[{self.workspace}] Error getting doc status {id}: {e}") return None async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get multiple document status records by IDs.""" if not self._index_ready: return [None] * len(ids) try: response = await self.client.mget(index=self._index_name, body={"ids": ids}) doc_map = {} for doc in response["docs"]: if doc.get("found"): data = doc["_source"] data["_id"] = doc["_id"] doc_map[doc["_id"]] = data return [doc_map.get(id) for id in ids] except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return [None] * len(ids) logger.error(f"[{self.workspace}] Error getting doc statuses: {e}") return [None] * len(ids) async def filter_keys(self, keys: set[str]) -> set[str]: """Return the subset of keys that do not exist in storage.""" if not self._index_ready: return keys try: response = await self.client.mget( index=self._index_name, body={"ids": list(keys)}, _source=False ) existing_ids = {doc["_id"] for doc in response["docs"] if doc.get("found")} return keys - existing_ids except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return keys logger.error(f"[{self.workspace}] Error filtering keys: {e}") return keys async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Insert or update document status records.""" if not data: return await self._ensure_index_ready() logger.debug(f"[{self.workspace}] Upserting {len(data)} doc statuses") actions = [] for i, (k, v) in enumerate(data.items(), start=1): v.setdefault("chunks_list", []) source = {fk: fv for fk, fv in v.items() if fk != "_id"} source["__mirrored_id"] = k actions.append( { "_op_type": "index", "_index": self._index_name, "_id": k, "_source": source, } ) await _cooperative_yield(i) try: # DocStatus needs refresh="wait_for" because get_docs_by_status # (search-based) is called immediately after enqueue upserts. await helpers.async_bulk( self.client, actions, raise_on_error=False, refresh="wait_for" ) except OpenSearchException as e: logger.error(f"[{self.workspace}] Error upserting doc statuses: {e}") async def get_status_counts(self) -> dict[str, int]: """Get document counts grouped by status.""" if not self._index_ready: return {} try: body = { "size": 0, "aggs": {"status_counts": {"terms": {"field": "status", "size": 100}}}, } response = await self.client.search(index=self._index_name, body=body) return { bucket["key"]: bucket["doc_count"] for bucket in response["aggregations"]["status_counts"]["buckets"] } except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return {} logger.error(f"[{self.workspace}] Error getting status counts: {e}") return {} async def _search_all_docs(self, query: dict) -> dict[str, DocProcessingStatus]: """Fetch all documents matching a query using PIT + search_after.""" if not self._index_ready: return {} result = {} batch_size = 10000 try: pit = await self.client.create_pit( index=self._index_name, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while True: body = { "query": query, "size": batch_size, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_field("__mirrored_id"), } if search_after: body["search_after"] = search_after response = await self.client.search(body=body) hits = response["hits"]["hits"] if not hits: break for hit in hits: try: data = self._prepare_doc_status_data(hit["_source"]) result[hit["_id"]] = DocProcessingStatus(**data) except (KeyError, TypeError) as e: logger.error( f"[{self.workspace}] Error parsing doc {hit['_id']}: {e}" ) search_after = hits[-1]["sort"] if len(hits) < batch_size: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return {} logger.error(f"[{self.workspace}] Error fetching docs: {e}") return result async def get_docs_by_status( self, status: DocStatus ) -> dict[str, DocProcessingStatus]: """Get all documents matching a specific processing status.""" return await self.get_docs_by_statuses([status]) async def get_docs_by_statuses( self, statuses: list[DocStatus] ) -> dict[str, DocProcessingStatus]: """Get all documents matching any of the given statuses in a single query. Uses OpenSearch's terms query (multi-value equivalent of term) to fetch all matching statuses in one PIT + search_after pass instead of one full scan per status. """ if not statuses: return {} status_values = [s.value for s in statuses] return await self._search_all_docs({"terms": {"status": status_values}}) async def get_docs_by_track_id( self, track_id: str ) -> dict[str, DocProcessingStatus]: """Get all documents matching a specific track ID.""" return await self._search_all_docs({"term": {"track_id": track_id}}) async def get_docs_paginated( self, status_filter: DocStatus | None = None, page: int = 1, page_size: int = 50, sort_field: str = "updated_at", sort_direction: str = "desc", ) -> tuple[list[tuple[str, DocProcessingStatus]], int]: """Get documents with pagination using PIT + search_after.""" if not self._index_ready: return [], 0 page = max(1, page) page_size = max(10, min(200, page_size)) if sort_field == "id": sort_field = "_id" if sort_field not in ("created_at", "updated_at", "_id", "file_path"): sort_field = "updated_at" sort_order = "asc" if sort_direction.lower() == "asc" else "desc" query = {"match_all": {}} if status_filter is not None: query = {"term": {"status": status_filter.value}} skip_count = (page - 1) * page_size try: count_resp = await self.client.count( index=self._index_name, body={"query": query} ) total_count = count_resp.get("count", 0) if total_count == 0 or skip_count >= total_count: return [], total_count sort_clause = [{sort_field: {"order": sort_order}}] + _pit_sort_with_field( "__mirrored_id" ) pit = await self.client.create_pit( index=self._index_name, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None skipped = 0 while skipped < skip_count: batch = min(page_size, skip_count - skipped) body = { "query": query, "sort": sort_clause, "size": batch, "pit": {"id": pit_id, "keep_alive": "1m"}, } if search_after: body["search_after"] = search_after resp = await self.client.search(body=body) hits = resp["hits"]["hits"] if not hits: return [], total_count search_after = hits[-1]["sort"] skipped += len(hits) body = { "query": query, "sort": sort_clause, "size": page_size, "pit": {"id": pit_id, "keep_alive": "1m"}, } if search_after: body["search_after"] = search_after response = await self.client.search(body=body) finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass documents = [] for hit in response["hits"]["hits"]: try: data = self._prepare_doc_status_data(hit["_source"]) documents.append((hit["_id"], DocProcessingStatus(**data))) except (KeyError, TypeError) as e: logger.error( f"[{self.workspace}] Error parsing doc {hit['_id']}: {e}" ) return documents, total_count except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return [], 0 logger.error(f"[{self.workspace}] Error in paginated query: {e}") return [], 0 async def get_all_status_counts(self) -> dict[str, int]: """Get document counts for all statuses including an 'all' total.""" if not self._index_ready: return {} try: body = { "size": 0, "aggs": {"status_counts": {"terms": {"field": "status", "size": 100}}}, } response = await self.client.search(index=self._index_name, body=body) counts = {} total = 0 for bucket in response["aggregations"]["status_counts"]["buckets"]: counts[bucket["key"]] = bucket["doc_count"] total += bucket["doc_count"] counts["all"] = total return counts except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return {} logger.error(f"[{self.workspace}] Error getting all status counts: {e}") return {} async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]: """Find a document status record by its file_path field.""" if not self._index_ready: return None try: body = {"query": {"term": {"file_path": file_path}}, "size": 1} response = await self.client.search(index=self._index_name, body=body) hits = response["hits"]["hits"] if hits: doc = hits[0]["_source"] doc["_id"] = hits[0]["_id"] return doc return None except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return None logger.error(f"[{self.workspace}] Error getting doc by file_path: {e}") return None async def index_done_callback(self) -> None: """Refresh index to make recently indexed documents searchable.""" if not self._index_ready: return try: await self.client.indices.refresh(index=self._index_name) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return except Exception: pass async def is_empty(self) -> bool: """Return True if the index contains no documents.""" if not self._index_ready: return True try: response = await self.client.count(index=self._index_name) return response["count"] == 0 except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return True async def delete(self, ids: list[str]) -> None: """Delete document status records by IDs.""" if not ids: return if not self._index_ready: return if isinstance(ids, set): ids = list(ids) try: # DocStatus needs refresh="wait_for" because downstream readers # (get_docs_by_status, get_docs_paginated, etc.) are search-based # and callers like _validate_and_fix_document_consistency() may # query immediately after deletion without index_done_callback(). actions = [ {"_op_type": "delete", "_index": self._index_name, "_id": doc_id} for doc_id in ids ] await helpers.async_bulk( self.client, actions, raise_on_error=False, refresh="wait_for" ) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return logger.error(f"[{self.workspace}] Error deleting doc statuses: {e}") async def drop(self) -> dict[str, str]: """Delete the entire doc status index.""" try: try: await self.client.indices.delete(index=self._index_name) logger.info( f"[{self.workspace}] Dropped doc status index: {self._index_name}" ) except NotFoundError: logger.info( f"[{self.workspace}] Doc status index already missing during drop: {self._index_name}" ) self._mark_index_missing() return {"status": "success", "message": f"Index {self._index_name} dropped"} except OpenSearchException as e: self._mark_index_missing() logger.error(f"[{self.workspace}] Error dropping doc status index: {e}") return {"status": "error", "message": str(e)} except Exception as e: self._mark_index_missing() logger.error( f"[{self.workspace}] Unexpected error dropping doc status index: {e}" ) return {"status": "error", "message": str(e)} @final @dataclass class OpenSearchGraphStorage(BaseGraphStorage): """Graph storage using OpenSearch with separate nodes and edges indices. Supports two BFS traversal strategies: - PPL graphlookup (server-side BFS, requires OpenSearch SQL plugin with Calcite engine) - Application-level batched BFS (fallback, works on any OpenSearch 3.x+) The strategy is auto-detected during initialize() and can be overridden via the OPENSEARCH_USE_PPL_GRAPHLOOKUP environment variable (true/false). """ client: AsyncOpenSearch = field(default=None) _nodes_index: str = field(default="", init=False) _edges_index: str = field(default="", init=False) _indices_ready: bool = field(default=False, init=False) _nodes_dirty: bool = field(default=False, init=False) _edges_dirty: bool = field(default=False, init=False) _ppl_graphlookup_available: bool = field(default=False, init=False) def __init__(self, namespace, global_config, embedding_func, workspace=None): super().__init__( namespace=namespace, workspace=workspace or "", global_config=global_config, embedding_func=embedding_func, ) self.__post_init__() def __post_init__(self): self.workspace, self.final_namespace, base_name = _build_index_name( self.workspace, self.namespace ) self._nodes_index = f"{base_name}-nodes" self._edges_index = f"{base_name}-edges" async def initialize(self): """Initialize client, create indices, and detect PPL graphlookup support.""" async with get_data_init_lock(): if self.client is None: self.client = await ClientManager.get_client() await self._create_indices_if_not_exist() self._indices_ready = True self._nodes_dirty = False self._edges_dirty = False await self._detect_ppl_graphlookup() logger.debug( f"[{self.workspace}] OpenSearch Graph storage initialized: " f"{self._nodes_index}, {self._edges_index} " f"(PPL graphlookup: {self._ppl_graphlookup_available})" ) async def _ensure_indices_ready(self): """Recreate graph indices after drop before the next write.""" if self._indices_ready: return async with get_data_init_lock(): if self.client is None: self.client = await ClientManager.get_client() if not self._indices_ready: await self._create_indices_if_not_exist() self._indices_ready = True def _mark_indices_missing(self): """Mark graph indices as unavailable for subsequent read short-circuiting.""" self._indices_ready = False self._nodes_dirty = False self._edges_dirty = False async def _refresh_graph_indices_if_dirty( self, *, refresh_nodes: bool = False, refresh_edges: bool = False ) -> None: """Refresh graph indices only when prior writes made search views stale.""" if not self._indices_ready: return if not ( (refresh_nodes and self._nodes_dirty) or (refresh_edges and self._edges_dirty) ): return try: async with get_data_init_lock(): if refresh_nodes and self._nodes_dirty: await self.client.indices.refresh(index=self._nodes_index) self._nodes_dirty = False if refresh_edges and self._edges_dirty: await self.client.indices.refresh(index=self._edges_index) self._edges_dirty = False except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return raise async def _detect_ppl_graphlookup(self): """Detect whether PPL graphlookup command is available on this cluster.""" env_override = os.environ.get("OPENSEARCH_USE_PPL_GRAPHLOOKUP", "").lower() if env_override == "true": self._ppl_graphlookup_available = True return if env_override == "false": self._ppl_graphlookup_available = False return # Auto-detect by sending a minimal PPL query try: await self.client.transport.perform_request( "POST", "/_plugins/_ppl", body={"query": f"source = {self._edges_index} | head 0"}, ) # PPL endpoint works; now test graphlookup syntax with a no-op query await self.client.transport.perform_request( "POST", "/_plugins/_ppl", body={ "query": ( f"source = {self._edges_index} | head 1 " f"| graphLookup {self._edges_index} " f"start=source_node_id edge=target_node_id-->source_node_id " f"maxDepth=0 as _gl_probe" ) }, ) self._ppl_graphlookup_available = True logger.info( f"[{self.workspace}] PPL graphlookup is available, using server-side BFS" ) except Exception: self._ppl_graphlookup_available = False logger.info( f"[{self.workspace}] PPL graphlookup not available, using client-side BFS" ) async def _create_indices_if_not_exist(self): try: if not await self.client.indices.exists(index=self._nodes_index): body = { "mappings": { "dynamic": True, "properties": { "entity_id": {"type": "keyword"}, "entity_type": {"type": "keyword"}, "description": {"type": "text"}, "source_id": {"type": "text"}, "source_ids": {"type": "keyword"}, "file_path": {"type": "keyword"}, "created_at": {"type": "long"}, }, }, "settings": { "index": { "number_of_shards": _get_index_number_of_shards(), "number_of_replicas": _get_index_number_of_replicas(), } }, } await self.client.indices.create(index=self._nodes_index, body=body) logger.info( f"[{self.workspace}] Created nodes index: {self._nodes_index}" ) except RequestError as e: if "resource_already_exists_exception" not in str(e): raise try: if not await self.client.indices.exists(index=self._edges_index): body = { "mappings": { "dynamic": True, "properties": { "source_node_id": {"type": "keyword"}, "target_node_id": {"type": "keyword"}, "relationship": {"type": "keyword"}, "description": {"type": "text"}, "weight": {"type": "float"}, "keywords": {"type": "text"}, "source_id": {"type": "text"}, "source_ids": {"type": "keyword"}, "file_path": {"type": "keyword"}, "created_at": {"type": "long"}, }, }, "settings": { "index": { "number_of_shards": _get_index_number_of_shards(), "number_of_replicas": _get_index_number_of_replicas(), } }, } await self.client.indices.create(index=self._edges_index, body=body) logger.info( f"[{self.workspace}] Created edges index: {self._edges_index}" ) except RequestError as e: if "resource_already_exists_exception" not in str(e): raise async def finalize(self): """Release the OpenSearch client connection.""" if self.client is not None: await ClientManager.release_client(self.client) self.client = None # --- Basic queries --- async def has_node(self, node_id: str) -> bool: """Check whether a node exists in the graph.""" if not self._indices_ready: return False try: return await self.client.exists(index=self._nodes_index, id=node_id) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return False async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """Check whether an edge exists between two nodes (bidirectional). Uses mget with the two candidate edge IDs so the check is real-time (translog-backed), consistent with has_node() and independent of the index refresh cycle. """ if not self._indices_ready: return False try: forward_id = compute_mdhash_id( f"{source_node_id}-{target_node_id}", prefix="edge-" ) reverse_id = compute_mdhash_id( f"{target_node_id}-{source_node_id}", prefix="edge-" ) response = await self.client.mget( index=self._edges_index, body={"ids": [forward_id, reverse_id]} ) return any(doc.get("found") for doc in response.get("docs", [])) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return False async def node_degree(self, node_id: str) -> int: """Count the number of edges connected to a node.""" if not self._indices_ready: return 0 try: await self._refresh_graph_indices_if_dirty(refresh_edges=True) response = await self.client.count( index=self._edges_index, body={ "query": { "bool": { "should": [ {"term": {"source_node_id": node_id}}, {"term": {"target_node_id": node_id}}, ] } } }, ) return response.get("count", 0) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return 0 async def edge_degree(self, src_id: str, tgt_id: str) -> int: """Sum of degrees of both endpoint nodes.""" src_degree = await self.node_degree(src_id) tgt_degree = await self.node_degree(tgt_id) return src_degree + tgt_degree async def get_node(self, node_id: str) -> dict[str, str] | None: """Get a node document by ID, or None if not found.""" if not self._indices_ready: return None try: response = await _mget_optional_doc(self.client, self._nodes_index, node_id) if response is None: return None doc = response["_source"] doc["_id"] = response["_id"] return doc except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return None async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: """Get an edge between two nodes (bidirectional), or None. Uses mget with the two candidate edge IDs so the read is real-time (translog-backed), consistent with get_node() and independent of the index refresh cycle. """ if not self._indices_ready: return None try: forward_id = compute_mdhash_id( f"{source_node_id}-{target_node_id}", prefix="edge-" ) reverse_id = compute_mdhash_id( f"{target_node_id}-{source_node_id}", prefix="edge-" ) response = await self.client.mget( index=self._edges_index, body={"ids": [forward_id, reverse_id]} ) for doc in response.get("docs", []): if doc.get("found"): result = doc["_source"] result["_id"] = doc["_id"] return result return None except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return None async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Get all (source, target) edge tuples connected to a node.""" if not self._indices_ready: return None try: await self._refresh_graph_indices_if_dirty(refresh_edges=True) query = { "bool": { "should": [ {"term": {"source_node_id": source_node_id}}, {"term": {"target_node_id": source_node_id}}, ] } } edges = [] pit = await self.client.create_pit( index=self._edges_index, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while True: body = { "query": query, "_source": ["source_node_id", "target_node_id"], "size": 10000, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_composite_key( "source_node_id", "target_node_id" ), } if search_after: body["search_after"] = search_after response = await self.client.search(body=body) hits = response["hits"]["hits"] if not hits: break for hit in hits: edges.append( ( hit["_source"]["source_node_id"], hit["_source"]["target_node_id"], ) ) search_after = hits[-1]["sort"] if len(hits) < 10000: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass return edges except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return None # --- Batch operations --- async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """Batch-fetch multiple nodes by ID.""" if not self._indices_ready: return {} try: response = await self.client.mget( index=self._nodes_index, body={"ids": node_ids} ) result = {} for doc in response["docs"]: if doc.get("found"): data = doc["_source"] data["_id"] = doc["_id"] result[doc["_id"]] = data return result except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return {} async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """Batch-fetch edge counts for multiple nodes using aggregations.""" if not node_ids: return {} if not self._indices_ready: return {} try: await self._refresh_graph_indices_if_dirty(refresh_edges=True) # Use a single query with aggregations for both source and target body = { "size": 0, "query": { "bool": { "should": [ {"terms": {"source_node_id": node_ids}}, {"terms": {"target_node_id": node_ids}}, ] } }, "aggs": { "source_degrees": { "terms": { "field": "source_node_id", "size": len(node_ids) * 2, } }, "target_degrees": { "terms": { "field": "target_node_id", "size": len(node_ids) * 2, } }, }, } response = await self.client.search(index=self._edges_index, body=body) result = {} for bucket in response["aggregations"]["source_degrees"]["buckets"]: if bucket["key"] in node_ids: result[bucket["key"]] = ( result.get(bucket["key"], 0) + bucket["doc_count"] ) for bucket in response["aggregations"]["target_degrees"]["buckets"]: if bucket["key"] in node_ids: result[bucket["key"]] = ( result.get(bucket["key"], 0) + bucket["doc_count"] ) return result except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return {} async def get_nodes_edges_batch( self, node_ids: list[str] ) -> dict[str, list[tuple[str, str]]]: """Batch-fetch edge tuples for multiple nodes.""" result = {nid: [] for nid in node_ids} if not self._indices_ready: return result try: await self._refresh_graph_indices_if_dirty(refresh_edges=True) query = { "bool": { "should": [ {"terms": {"source_node_id": node_ids}}, {"terms": {"target_node_id": node_ids}}, ] } } pit = await self.client.create_pit( index=self._edges_index, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while True: body = { "query": query, "_source": ["source_node_id", "target_node_id"], "size": 10000, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_composite_key( "source_node_id", "target_node_id" ), } if search_after: body["search_after"] = search_after response = await self.client.search(body=body) hits = response["hits"]["hits"] if not hits: break for hit in hits: src = hit["_source"]["source_node_id"] tgt = hit["_source"]["target_node_id"] if src in result: result[src].append((src, tgt)) if tgt in result: result[tgt].append((src, tgt)) search_after = hits[-1]["sort"] if len(hits) < 10000: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() pass return result # --- Upsert operations --- async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """Insert or update a node. Adds entity_id for PPL compatibility.""" try: await self._ensure_indices_ready() doc = {k: v for k, v in node_data.items() if k != "_id"} doc["entity_id"] = node_id if node_data.get("source_id", ""): doc["source_ids"] = node_data["source_id"].split(GRAPH_FIELD_SEP) # No per-operation refresh: node reads use ID-based mget/exists # (translog, real-time). Search visibility after index_done_callback(). await self.client.index(index=self._nodes_index, id=node_id, body=doc) self._nodes_dirty = True except OpenSearchException as e: logger.error(f"[{self.workspace}] Error upserting node {node_id}: {e}") async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: """Insert or update an edge with deterministic ID for bidirectional handling.""" try: await self._ensure_indices_ready() # Ensure source node exists (don't overwrite if it already has data) if not await self.has_node(source_node_id): await self.upsert_node(source_node_id, {}) doc = {k: v for k, v in edge_data.items() if k != "_id"} doc["source_node_id"] = source_node_id doc["target_node_id"] = target_node_id if edge_data.get("source_id", ""): doc["source_ids"] = edge_data["source_id"].split(GRAPH_FIELD_SEP) # Use a deterministic ID for the edge so upserts work edge_id = compute_mdhash_id( f"{source_node_id}-{target_node_id}", prefix="edge-" ) # Check if reverse edge exists reverse_id = compute_mdhash_id( f"{target_node_id}-{source_node_id}", prefix="edge-" ) try: if await self.client.exists(index=self._edges_index, id=reverse_id): edge_id = reverse_id except OpenSearchException: pass await self.client.index(index=self._edges_index, id=edge_id, body=doc) self._edges_dirty = True except OpenSearchException as e: logger.error( f"[{self.workspace}] Error upserting edge {source_node_id}->{target_node_id}: {e}" ) async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None: """Batch insert/update multiple nodes using the OpenSearch bulk API. Args: nodes: List of (node_id, node_data) tuples. """ if not nodes: return try: await self._ensure_indices_ready() actions = [] for node_id, node_data in nodes: doc = {k: v for k, v in node_data.items() if k != "_id"} doc["entity_id"] = node_id if node_data.get("source_id", ""): doc["source_ids"] = node_data["source_id"].split(GRAPH_FIELD_SEP) actions.append( { "_op_type": "index", "_index": self._nodes_index, "_id": node_id, "_source": doc, } ) await helpers.async_bulk(self.client, actions) self._nodes_dirty = True except OpenSearchException as e: logger.error(f"[{self.workspace}] Error during batch node upsert: {e}") async def has_nodes_batch(self, node_ids: list[str]) -> set[str]: """Check existence of multiple nodes using a single mget request. Args: node_ids: List of node IDs to check. Returns: Set of node_ids that exist in the graph. """ if not node_ids: return set() if not self._indices_ready: return set() try: response = await self.client.mget( index=self._nodes_index, body={"ids": node_ids} ) return {doc["_id"] for doc in response.get("docs", []) if doc.get("found")} except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return set() async def upsert_edges_batch( self, edges: list[tuple[str, str, dict[str, str]]] ) -> None: """Batch insert/update multiple edges using the OpenSearch bulk API. Replicates the bidirectional edge-ID logic of upsert_edge(): a canonical forward ID is used unless a reverse-direction document already exists, in which case the reverse ID is used so the update lands on the existing doc. The reverse-ID look-up is done in a single mget call before the bulk write. Args: edges: List of (source_node_id, target_node_id, edge_data) tuples. """ if not edges: return try: await self._ensure_indices_ready() # Ensure all source nodes exist (mirrors upsert_edge behaviour) source_ids = list({src for src, _tgt, _data in edges}) existing_sources = await self.has_nodes_batch(source_ids) missing_sources = [ (nid, {}) for nid in source_ids if nid not in existing_sources ] if missing_sources: await self.upsert_nodes_batch(missing_sources) # Compute forward and reverse edge IDs, then batch-check which # reverse-direction docs already exist (one mget instead of N exists). forward_ids = [ compute_mdhash_id(f"{src}-{tgt}", prefix="edge-") for src, tgt, _ in edges ] reverse_ids = [ compute_mdhash_id(f"{tgt}-{src}", prefix="edge-") for src, tgt, _ in edges ] try: rev_response = await self.client.mget( index=self._edges_index, body={"ids": reverse_ids} ) existing_reverse = { doc["_id"] for doc in rev_response.get("docs", []) if doc.get("found") } except OpenSearchException: existing_reverse = set() actions = [] reserved_edge_ids = set(existing_reverse) for (src, tgt, edge_data), fwd_id, rev_id in zip( edges, forward_ids, reverse_ids ): edge_id = rev_id if rev_id in reserved_edge_ids else fwd_id reserved_edge_ids.add(edge_id) doc = {k: v for k, v in edge_data.items() if k != "_id"} doc["source_node_id"] = src doc["target_node_id"] = tgt if edge_data.get("source_id", ""): doc["source_ids"] = edge_data["source_id"].split(GRAPH_FIELD_SEP) actions.append( { "_op_type": "index", "_index": self._edges_index, "_id": edge_id, "_source": doc, } ) await helpers.async_bulk(self.client, actions) self._edges_dirty = True except OpenSearchException as e: logger.error(f"[{self.workspace}] Error during batch edge upsert: {e}") # --- Delete operations --- async def delete_node(self, node_id: str) -> None: """Delete a node and all its connected edges. Marks node and edge search views dirty so refresh happens lazily on the next search/count-based graph read. Uses conflicts="proceed" to tolerate already-deleted matches. """ try: # Refresh edge search view so delete_by_query sees all un-flushed writes. await self._refresh_graph_indices_if_dirty(refresh_edges=True) # Delete all edges referencing this node body = { "query": { "bool": { "should": [ {"term": {"source_node_id": node_id}}, {"term": {"target_node_id": node_id}}, ] } } } await self.client.delete_by_query( index=self._edges_index, body=body, params={"conflicts": "proceed"}, ) # Delete the node try: await self.client.delete(index=self._nodes_index, id=node_id) except NotFoundError: pass self._nodes_dirty = True self._edges_dirty = True except OpenSearchException as e: logger.error(f"[{self.workspace}] Error deleting node {node_id}: {e}") async def remove_nodes(self, nodes: list[str]) -> None: """Batch-delete multiple nodes and their connected edges. Marks node and edge search views dirty so refresh happens lazily on the next search/count-based graph read. Uses conflicts="proceed" to tolerate already-deleted matches. """ if not nodes: return logger.info(f"[{self.workspace}] Deleting {len(nodes)} nodes") try: # Refresh edge search view so delete_by_query sees all un-flushed writes. await self._refresh_graph_indices_if_dirty(refresh_edges=True) # Delete edges body = { "query": { "bool": { "should": [ {"terms": {"source_node_id": nodes}}, {"terms": {"target_node_id": nodes}}, ] } } } await self.client.delete_by_query( index=self._edges_index, body=body, params={"conflicts": "proceed"}, ) # Delete nodes actions = [ {"_op_type": "delete", "_index": self._nodes_index, "_id": nid} for nid in nodes ] await helpers.async_bulk(self.client, actions, raise_on_error=False) self._nodes_dirty = True self._edges_dirty = True except OpenSearchException as e: logger.error(f"[{self.workspace}] Error removing nodes: {e}") async def remove_edges(self, edges: list[tuple[str, str]]) -> None: """Batch-delete multiple edges by deterministic ID (real-time). Each edge is stored under one of two candidate IDs: forward = compute_mdhash_id("src-tgt", prefix="edge-") reverse = compute_mdhash_id("tgt-src", prefix="edge-") We delete both candidates for every requested edge so the deletion is effective regardless of which direction was stored. Marks edge search views dirty so refresh happens lazily on the next search/count-based graph read. """ if not edges: return logger.info(f"[{self.workspace}] Deleting {len(edges)} edges") try: operations = [] for src, tgt in edges: for edge_id in ( compute_mdhash_id(f"{src}-{tgt}", prefix="edge-"), compute_mdhash_id(f"{tgt}-{src}", prefix="edge-"), ): operations.append( { "delete": { "_index": self._edges_index, "_id": edge_id, } } ) await self.client.bulk(body=operations) self._edges_dirty = True except OpenSearchException as e: logger.error(f"[{self.workspace}] Error removing edges: {e}") # --- Query operations --- async def get_all_labels(self) -> list[str]: """Get all node IDs (entity names) sorted alphabetically.""" if not self._indices_ready: return [] try: await self._refresh_graph_indices_if_dirty(refresh_nodes=True) labels = [] pit = await self.client.create_pit( index=self._nodes_index, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while True: body = { "query": {"match_all": {}}, "_source": False, "size": 10000, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_field("entity_id"), } if search_after: body["search_after"] = search_after response = await self.client.search(body=body) hits = response["hits"]["hits"] if not hits: break for hit in hits: labels.append(hit["_id"]) search_after = hits[-1]["sort"] if len(hits) < 10000: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass labels.sort() return labels except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return [] async def _collect_node_ids( self, limit: int, exclude_ids: set[str] | None = None ) -> list[str]: """Collect up to `limit` node IDs, optionally skipping known IDs.""" if limit <= 0: return [] excluded = exclude_ids or set() if not excluded and limit <= 10000: body = { "query": {"match_all": {}}, "_source": False, "size": limit, } resp = await self.client.search(index=self._nodes_index, body=body) return [hit["_id"] for hit in resp["hits"]["hits"]] node_ids: list[str] = [] pit = await self.client.create_pit( index=self._nodes_index, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while len(node_ids) < limit: body = { "query": {"match_all": {}}, "_source": False, "size": 10000, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_field("entity_id"), } if search_after: body["search_after"] = search_after resp = await self.client.search(body=body) hits = resp["hits"]["hits"] if not hits: break for hit in hits: node_id = hit["_id"] if node_id in excluded: continue node_ids.append(node_id) if len(node_ids) >= limit: break search_after = hits[-1].get("sort") if len(hits) < 10000: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass return node_ids @staticmethod def _edge_rank_key(edge: dict[str, Any]) -> tuple[int, float]: """Rank traversal edges by shallower depth first, then higher weight.""" depth = edge.get("_depth", edge.get("depth", 0)) try: depth_value = int(depth) except (TypeError, ValueError): depth_value = 0 weight = edge.get("weight", 0) try: weight_value = float(weight) except (TypeError, ValueError): weight_value = 0.0 return (depth_value, -weight_value) async def _append_edges_between_nodes( self, node_ids: list[str], result: KnowledgeGraph ) -> None: """Append all edges whose source and target are both in `node_ids`.""" if not node_ids: return edge_query = { "bool": { "must": [ {"terms": {"source_node_id": node_ids}}, {"terms": {"target_node_id": node_ids}}, ] } } seen_edges = set() pit = await self.client.create_pit( index=self._edges_index, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while True: edge_body = { "query": edge_query, "size": 10000, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_composite_key( "source_node_id", "target_node_id" ), } if search_after: edge_body["search_after"] = search_after edge_resp = await self.client.search(body=edge_body) hits = edge_resp["hits"]["hits"] if not hits: break for hit in hits: e = hit["_source"] eid = f"{e['source_node_id']}-{e['target_node_id']}" if eid not in seen_edges: seen_edges.add(eid) result.edges.append(self._construct_graph_edge(eid, e)) search_after = hits[-1].get("sort") if len(hits) < 10000: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass def _construct_graph_node(self, node_id, node_data: dict) -> KnowledgeGraphNode: return KnowledgeGraphNode( id=node_id, labels=[node_id], properties={ k: v for k, v in node_data.items() if k not in ( "_id", "entity_id", "source_ids", "connected_edges", "edge_count", ) }, ) def _construct_graph_edge(self, edge_id: str, edge: dict) -> KnowledgeGraphEdge: return KnowledgeGraphEdge( id=edge_id, type=edge.get("relationship", ""), source=edge["source_node_id"], target=edge["target_node_id"], properties={ k: v for k, v in edge.items() if k not in ( "_id", "source_node_id", "target_node_id", "relationship", "source_ids", ) }, ) async def get_knowledge_graph( self, node_label: str, max_depth: int = 3, max_nodes: int = None, ) -> KnowledgeGraph: """Retrieve a subgraph via PPL graphlookup (if available) or client-side BFS.""" if not self._indices_ready: return KnowledgeGraph() if max_nodes is None: max_nodes = self.global_config.get("max_graph_nodes", 1000) else: max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) result = KnowledgeGraph() start = time.perf_counter() try: await self._refresh_graph_indices_if_dirty( refresh_nodes=True, refresh_edges=True ) if node_label == "*": result = await self._get_knowledge_graph_all(max_nodes) elif self._ppl_graphlookup_available: result = await self._bfs_subgraph_ppl(node_label, max_depth, max_nodes) else: result = await self._bfs_subgraph(node_label, max_depth, max_nodes) duration = time.perf_counter() - start logger.info( f"[{self.workspace}] Subgraph query in {duration:.4f}s | " f"Nodes: {len(result.nodes)} | Edges: {len(result.edges)} | Truncated: {result.is_truncated}" ) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return KnowledgeGraph() logger.error(f"[{self.workspace}] Graph query failed: {e}") return result async def _get_knowledge_graph_all(self, max_nodes: int) -> KnowledgeGraph: """Get all nodes (up to max_nodes, ranked by degree) and their interconnecting edges.""" result = KnowledgeGraph() if not self._indices_ready: return result try: total = (await self.client.count(index=self._nodes_index))["count"] result.is_truncated = total > max_nodes if result.is_truncated: # Get top nodes by degree body = { "size": 0, "aggs": { "src": { "terms": { "field": "source_node_id", "size": max_nodes, } }, "tgt": { "terms": { "field": "target_node_id", "size": max_nodes, } }, }, } resp = await self.client.search(index=self._edges_index, body=body) degree_map = {} for bucket in resp["aggregations"]["src"]["buckets"]: degree_map[bucket["key"]] = ( degree_map.get(bucket["key"], 0) + bucket["doc_count"] ) for bucket in resp["aggregations"]["tgt"]["buckets"]: degree_map[bucket["key"]] = ( degree_map.get(bucket["key"], 0) + bucket["doc_count"] ) top_ids = sorted(degree_map, key=degree_map.get, reverse=True)[ :max_nodes ] if len(top_ids) < max_nodes: top_ids.extend( await self._collect_node_ids( max_nodes - len(top_ids), exclude_ids=set(top_ids) ) ) else: top_ids = await self._collect_node_ids(max_nodes) # Fetch node data if top_ids: node_resp = await self.client.mget( index=self._nodes_index, body={"ids": top_ids} ) found_node_ids = [] for doc in node_resp["docs"]: if doc.get("found"): found_node_ids.append(doc["_id"]) result.nodes.append( self._construct_graph_node(doc["_id"], doc["_source"]) ) await self._append_edges_between_nodes(found_node_ids, result) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return result logger.error(f"[{self.workspace}] Error in get_knowledge_graph_all: {e}") return result async def _bfs_subgraph_ppl( self, start_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: """Server-side BFS using PPL graphlookup command. Queries the nodes index for the start node, then uses graphLookup to traverse the edges index with bidirectional BFS. Uses `flatten` to unnest results and `depthField` for depth-based sorting. Falls back to client-side BFS on failure. """ result = KnowledgeGraph() # Verify start node exists start_node = await self.get_node(start_label) if not start_node: return result result.nodes.append(self._construct_graph_node(start_label, start_node)) if max_depth == 0: return result # PPL maxDepth=0 means 1 hop (direct match), so max_depth-1 ppl_depth = max(0, max_depth - 1) escaped = self._escape_ppl(start_label) ppl_query = ( f"source = {self._nodes_index}" f" | where entity_id = '{escaped}'" f" | graphLookup {self._edges_index}" f" start=entity_id" f" edge=target_node_id<->source_node_id" f" maxDepth={ppl_depth}" f" depthField=_depth" f" usePIT=true" f" as connected_edges" ) try: resp = await self.client.transport.perform_request( "POST", "/_plugins/_ppl", body={"query": ppl_query}, ) except Exception as e: logger.warning( f"[{self.workspace}] PPL graphlookup failed, falling back to client BFS: {e}" ) return await self._bfs_subgraph(start_label, max_depth, max_nodes) # Parse PPL response — schema-driven to avoid fragile positional access try: datarows = resp.get("datarows", []) schema = [col["name"] for col in resp.get("schema", [])] ce_idx = ( schema.index("connected_edges") if "connected_edges" in schema else -1 ) # Collect all edge rows from connected_edges arrays all_edge_rows = [] for row in datarows: edges_arr = row[ce_idx] if ce_idx >= 0 else [] if isinstance(edges_arr, list): all_edge_rows.extend(edges_arr) if not all_edge_rows: return result if isinstance(all_edge_rows[0], dict): sorted_edge_rows = sorted(all_edge_rows, key=self._edge_rank_key) else: # Positional array — column positions are unknown, fall back to client BFS logger.warning( f"[{self.workspace}] PPL returned positional arrays, falling back to client BFS" ) return await self._bfs_subgraph(start_label, max_depth, max_nodes) except (KeyError, IndexError, TypeError, ValueError) as e: logger.warning( f"[{self.workspace}] Error parsing PPL response, falling back: {e}" ) return await self._bfs_subgraph(start_label, max_depth, max_nodes) ordered_node_ids = [start_label] discovered_nodes = {start_label} for edge_row in sorted_edge_rows: for node_id in ( edge_row.get("source_node_id"), edge_row.get("target_node_id"), ): if not node_id or node_id in discovered_nodes: continue discovered_nodes.add(node_id) if len(ordered_node_ids) < max_nodes: ordered_node_ids.append(node_id) result.is_truncated = len(discovered_nodes) > max_nodes # Batch fetch node data (start node already added) new_node_ids = [nid for nid in ordered_node_ids if nid != start_label] if new_node_ids: node_resp = await self.client.mget( index=self._nodes_index, body={"ids": new_node_ids} ) for doc in node_resp["docs"]: if doc.get("found"): result.nodes.append( self._construct_graph_node(doc["_id"], doc["_source"]) ) await self._append_edges_between_nodes(ordered_node_ids, result) return result @staticmethod def _escape_ppl(value: str) -> str: """Escape a string for safe inclusion in a PPL single-quoted literal.""" return value.replace("\\", "\\\\").replace("'", "\\'") async def _bfs_subgraph( self, start_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: """BFS traversal from a starting node, batching neighbor lookups per level.""" result = KnowledgeGraph() seen_nodes = set() # Verify start node exists start_node = await self.get_node(start_label) if not start_node: return result seen_nodes.add(start_label) result.nodes.append(self._construct_graph_node(start_label, start_node)) current_level = [start_label] for _ in range(max_depth): if not current_level or len(seen_nodes) >= max_nodes: break # Batch fetch all edges for current level body = { "query": { "bool": { "should": [ {"terms": {"source_node_id": current_level}}, {"terms": {"target_node_id": current_level}}, ] } }, "_source": ["source_node_id", "target_node_id"], "size": 10000, } try: resp = await self.client.search(index=self._edges_index, body=body) except OpenSearchException: break next_level = set() for hit in resp["hits"]["hits"]: src = hit["_source"]["source_node_id"] tgt = hit["_source"]["target_node_id"] if src not in seen_nodes: next_level.add(src) if tgt not in seen_nodes: next_level.add(tgt) # Limit to max_nodes new_ids = [] for nid in next_level: if len(seen_nodes) + len(new_ids) >= max_nodes: break new_ids.append(nid) if new_ids: # Batch fetch node data node_resp = await self.client.mget( index=self._nodes_index, body={"ids": new_ids} ) for doc in node_resp["docs"]: if doc.get("found"): seen_nodes.add(doc["_id"]) result.nodes.append( self._construct_graph_node(doc["_id"], doc["_source"]) ) current_level = new_ids # Fetch all edges between seen nodes using PIT scrolling all_ids = list(seen_nodes) if all_ids: try: await self._append_edges_between_nodes(all_ids, result) except OpenSearchException: pass result.is_truncated = len(seen_nodes) >= max_nodes return result async def get_all_nodes(self) -> list[dict]: """Get all nodes with their properties.""" if not self._indices_ready: return [] try: await self._refresh_graph_indices_if_dirty(refresh_nodes=True) nodes = [] pit = await self.client.create_pit( index=self._nodes_index, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while True: body = { "query": {"match_all": {}}, "size": 10000, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_field("entity_id"), } if search_after: body["search_after"] = search_after response = await self.client.search(body=body) hits = response["hits"]["hits"] if not hits: break for hit in hits: node = hit["_source"] node["id"] = hit["_id"] nodes.append(node) search_after = hits[-1]["sort"] if len(hits) < 10000: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass return nodes except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return [] async def get_all_edges(self) -> list[dict]: """Get all edges with source/target fields added.""" if not self._indices_ready: return [] try: await self._refresh_graph_indices_if_dirty(refresh_edges=True) edges = [] pit = await self.client.create_pit( index=self._edges_index, params={"keep_alive": "1m"} ) pit_id = pit["pit_id"] try: search_after = None while True: body = { "query": {"match_all": {}}, "size": 10000, "pit": {"id": pit_id, "keep_alive": "1m"}, "sort": _pit_sort_with_composite_key( "source_node_id", "target_node_id" ), } if search_after: body["search_after"] = search_after response = await self.client.search(body=body) hits = response["hits"]["hits"] if not hits: break for hit in hits: edge = hit["_source"] edge["source"] = edge.get("source_node_id") edge["target"] = edge.get("target_node_id") edges.append(edge) search_after = hits[-1]["sort"] if len(hits) < 10000: break finally: try: await self.client.delete_pit(body={"pit_id": [pit_id]}) except Exception: pass return edges except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return [] async def get_popular_labels(self, limit: int = 300) -> list[str]: """Get node labels ranked by edge degree (most connected first).""" if not self._indices_ready: return [] try: await self._refresh_graph_indices_if_dirty(refresh_edges=True) body = { "size": 0, "aggs": { "src": {"terms": {"field": "source_node_id", "size": limit * 2}}, "tgt": {"terms": {"field": "target_node_id", "size": limit * 2}}, }, } response = await self.client.search(index=self._edges_index, body=body) degree_map = {} for bucket in response["aggregations"]["src"]["buckets"]: degree_map[bucket["key"]] = ( degree_map.get(bucket["key"], 0) + bucket["doc_count"] ) for bucket in response["aggregations"]["tgt"]["buckets"]: degree_map[bucket["key"]] = ( degree_map.get(bucket["key"], 0) + bucket["doc_count"] ) sorted_labels = sorted(degree_map, key=degree_map.get, reverse=True)[:limit] return sorted_labels except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return [] async def search_labels(self, query: str, limit: int = 50) -> list[str]: """Search node labels with wildcard and prefix matching.""" query = query.strip() if not query: return [] if not self._indices_ready: return [] try: await self._refresh_graph_indices_if_dirty(refresh_nodes=True) body = { "query": { "bool": { "should": [ {"term": {"entity_id": {"value": query, "boost": 10}}}, { "prefix": { "entity_id": {"value": query.lower(), "boost": 5} } }, { "wildcard": { "entity_id": { "value": f"*{query.lower()}*", "case_insensitive": True, "boost": 2, } } }, ] } }, "_source": False, "size": limit, } response = await self.client.search(index=self._nodes_index, body=body) return [hit["_id"] for hit in response["hits"]["hits"]] except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return [] async def index_done_callback(self) -> None: """Refresh both node and edge indices.""" if not self._indices_ready: return try: await self._refresh_graph_indices_if_dirty( refresh_nodes=True, refresh_edges=True ) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_indices_missing() return except Exception: pass async def drop(self) -> dict[str, str]: """Delete both node and edge indices.""" errors = [] for idx in (self._nodes_index, self._edges_index): try: await self.client.indices.delete(index=idx) logger.info(f"[{self.workspace}] Dropped graph index: {idx}") except NotFoundError: logger.info( f"[{self.workspace}] Graph index already missing during drop: {idx}" ) except OpenSearchException as e: errors.append(f"{idx}: {e}") logger.error( f"[{self.workspace}] Error dropping graph index {idx}: {e}" ) except Exception as e: errors.append(f"{idx}: {e}") logger.error( f"[{self.workspace}] Unexpected error dropping graph index {idx}: {e}" ) self._mark_indices_missing() if errors: return { "status": "error", "message": "Failed to drop graph indices: " + "; ".join(errors), } try: logger.info(f"[{self.workspace}] Dropped graph indices") return {"status": "success", "message": "Graph indices dropped"} except Exception as e: logger.error(f"[{self.workspace}] Error finalizing graph drop: {e}") return {"status": "error", "message": str(e)} @final @dataclass class OpenSearchVectorDBStorage(BaseVectorStorage): """Vector storage using OpenSearch k-NN plugin with corrected cosine score handling.""" client: AsyncOpenSearch = field(default=None) _index_name: str = field(default="", init=False) _index_ready: bool = field(default=False, init=False) def __init__( self, namespace, global_config, embedding_func, workspace=None, meta_fields=None ): super().__init__( namespace=namespace, workspace=workspace or "", global_config=global_config, embedding_func=embedding_func, meta_fields=meta_fields or set(), ) self.__post_init__() def __post_init__(self): self._validate_embedding_func() self.workspace, self.final_namespace, self._index_name = _build_index_name( self.workspace, self.namespace ) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") if cosine_threshold is None: raise ValueError( "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" ) self.cosine_better_than_threshold = cosine_threshold self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): """Initialize client and create k-NN vector index.""" async with get_data_init_lock(): if self.client is None: self.client = await ClientManager.get_client() await self._create_knn_index_if_not_exists() self._index_ready = True logger.debug( f"[{self.workspace}] OpenSearch Vector storage initialized: {self._index_name}" ) async def _ensure_index_ready(self): """Recreate the vector index before the next write if it is missing.""" if self._index_ready: return async with get_data_init_lock(): if self.client is None: self.client = await ClientManager.get_client() if not self._index_ready: await self._create_knn_index_if_not_exists() self._index_ready = True def _mark_index_missing(self): """Mark the vector index as unavailable for subsequent read short-circuiting.""" self._index_ready = False async def _create_knn_index_if_not_exists(self): try: if await self.client.indices.exists(index=self._index_name): # Validate existing index dimension try: mapping = await self.client.indices.get_mapping( index=self._index_name ) existing_dim = ( mapping[self._index_name]["mappings"]["properties"] .get("vector", {}) .get("dimension") ) expected_dim = self.embedding_func.embedding_dim if existing_dim is not None and existing_dim != expected_dim: raise ValueError( f"Vector dimension mismatch! Index '{self._index_name}' has " f"dimension {existing_dim}, but current embedding model expects " f"dimension {expected_dim}. Please drop the existing index or " f"use an embedding model with matching dimensions." ) except (KeyError, TypeError): logger.warning( f"[{self.workspace}] Could not read vector mapping for index " f"'{self._index_name}'; skipping dimension validation" ) return ef_construction = int( _get_opensearch_env("OPENSEARCH_KNN_EF_CONSTRUCTION", "200") ) m = int(_get_opensearch_env("OPENSEARCH_KNN_M", "16")) ef_search = int(_get_opensearch_env("OPENSEARCH_KNN_EF_SEARCH", "100")) body = { "settings": { "index": { "knn": True, "knn.algo_param.ef_search": ef_search, "number_of_shards": _get_index_number_of_shards(), "number_of_replicas": _get_index_number_of_replicas(), } }, "mappings": { "properties": { "vector": { "type": "knn_vector", "dimension": self.embedding_func.embedding_dim, "method": { "name": "hnsw", "space_type": "cosinesimil", "engine": "lucene", "parameters": { "ef_construction": ef_construction, "m": m, }, }, }, "content": {"type": "text"}, "entity_name": {"type": "keyword"}, "src_id": {"type": "keyword"}, "tgt_id": {"type": "keyword"}, "file_path": {"type": "keyword"}, "created_at": {"type": "long"}, }, "dynamic": True, }, } await self.client.indices.create(index=self._index_name, body=body) logger.info( f"[{self.workspace}] Created k-NN index: {self._index_name} " f"(dim={self.embedding_func.embedding_dim})" ) except RequestError as e: if "resource_already_exists_exception" not in str(e): logger.error(f"[{self.workspace}] Error creating k-NN index: {e}") raise except OpenSearchException as e: logger.error(f"[{self.workspace}] Error creating k-NN index: {e}") raise async def finalize(self): """Release the OpenSearch client connection.""" if self.client is not None: await ClientManager.release_client(self.client) self.client = None async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Generate embeddings and upsert vectors in batches.""" if not data: return await self._ensure_index_ready() logger.debug( f"[{self.workspace}] Upserting {len(data)} vectors to {self.namespace}" ) current_time = int(time.time()) list_data = [] for i, (k, v) in enumerate(data.items(), start=1): list_data.append( { "_id": k, "created_at": current_time, **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, } ) await _cooperative_yield(i) contents = [v["content"] for v in data.values()] batches = [ contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] embeddings_list = await asyncio.gather( *[self.embedding_func(batch, context="document") for batch in batches] ) embeddings = np.concatenate(embeddings_list) assert len(embeddings) == len( list_data ), f"Embedding count mismatch: expected {len(list_data)}, got {len(embeddings)}" for i, doc in enumerate(list_data, start=1): doc["vector"] = embeddings[i - 1].tolist() await _cooperative_yield(i) actions = [] for i, doc in enumerate(list_data, start=1): actions.append( { "_op_type": "index", "_index": self._index_name, "_id": doc["_id"], "_source": {k: v for k, v in doc.items() if k != "_id"}, } ) await _cooperative_yield(i) try: # No per-operation refresh: immediate reads use ID-based mget (translog), # k-NN search visibility is guaranteed after index_done_callback() batch refresh. success, failed = await helpers.async_bulk( self.client, actions, raise_on_error=False ) if failed: logger.warning( f"[{self.workspace}] {len(failed)} vectors failed to upsert" ) except OpenSearchException as e: logger.error(f"[{self.workspace}] Error upserting vectors: {e}") raise async def query( self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: """k-NN similarity search with cosine score conversion for lucene engine.""" if not self._index_ready: return [] if query_embedding is not None: query_vector = ( query_embedding.tolist() if hasattr(query_embedding, "tolist") else list(query_embedding) ) else: embedding = await self.embedding_func([query], context="query", _priority=5) query_vector = embedding[0].tolist() search_body = { "size": top_k, "query": {"knn": {"vector": {"vector": query_vector, "k": top_k}}}, "_source": {"excludes": ["vector"]}, } try: response = await self.client.search( index=self._index_name, body=search_body ) results = [] for hit in response["hits"]["hits"]: # OpenSearch k-NN with lucene engine and cosinesimil space type # returns scores that can be used directly as similarity measure. score = hit["_score"] if score >= self.cosine_better_than_threshold: doc = hit["_source"] doc["id"] = hit["_id"] doc["distance"] = score results.append(doc) logger.info( f"[{self.workspace}] Vector query on {self._index_name}: " f"top_k={top_k}, threshold={self.cosine_better_than_threshold}, " f"total_hits={len(response['hits']['hits'])}, " f"passed_filter={len(results)}, " f"score_range=[{min((h['_score'] for h in response['hits']['hits']), default=0):.4f}, " f"{max((h['_score'] for h in response['hits']['hits']), default=0):.4f}]" ) return results except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return [] logger.error(f"[{self.workspace}] Error querying vectors: {e}") return [] async def index_done_callback(self) -> None: """Refresh index to make recently indexed vectors searchable.""" if not self._index_ready: return try: await self.client.indices.refresh(index=self._index_name) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return except Exception: pass async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get a vector document by ID.""" if not self._index_ready: return None try: response = await _mget_optional_doc(self.client, self._index_name, id) if response is None: return None doc = response["_source"] doc["id"] = response["_id"] return doc except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return None logger.error(f"[{self.workspace}] Error getting vector {id}: {e}") return None async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get multiple vector documents by IDs, preserving order.""" if not ids: return [] if not self._index_ready: return [None] * len(ids) try: response = await self.client.mget(index=self._index_name, body={"ids": ids}) doc_map = {} for doc in response["docs"]: if doc.get("found"): data = doc["_source"] data["id"] = doc["_id"] doc_map[doc["_id"]] = data return [doc_map.get(id) for id in ids] except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return [None] * len(ids) logger.error(f"[{self.workspace}] Error getting vectors by ids: {e}") return [None] * len(ids) async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]: """Get only the vector embeddings for given IDs.""" if not ids: return {} if not self._index_ready: return {} try: response = await self.client.mget( index=self._index_name, body={"ids": ids}, _source_includes=["vector"] ) result = {} for doc in response["docs"]: if doc.get("found") and "vector" in doc.get("_source", {}): result[doc["_id"]] = doc["_source"]["vector"] return result except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return {} logger.error(f"[{self.workspace}] Error getting vectors: {e}") return {} async def delete(self, ids: list[str]) -> None: """Delete vectors by their IDs.""" if not ids: return if not self._index_ready: return if isinstance(ids, set): ids = list(ids) try: # No per-operation refresh: search visibility after index_done_callback(). actions = [ {"_op_type": "delete", "_index": self._index_name, "_id": doc_id} for doc_id in ids ] result = await helpers.async_bulk( self.client, actions, raise_on_error=False ) logger.debug( f"[{self.workspace}] Deleted {result[0]} vectors from {self.namespace}" ) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return logger.error(f"[{self.workspace}] Error deleting vectors: {e}") async def delete_entity(self, entity_name: str) -> None: """Delete an entity vector by computing its hash ID.""" if not self._index_ready: return try: # No per-operation refresh: search visibility after index_done_callback(). entity_id = compute_mdhash_id(entity_name, prefix="ent-") try: await self.client.delete(index=self._index_name, id=entity_id) logger.debug(f"[{self.workspace}] Deleted entity {entity_name}") except NotFoundError as e: if _is_missing_index_error(e): self._mark_index_missing() return logger.debug(f"[{self.workspace}] Entity {entity_name} not found") except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relation vectors where entity appears as src or tgt.""" if not self._index_ready: return try: # No per-operation refresh: search visibility after index_done_callback(). body = { "query": { "bool": { "should": [ {"term": {"src_id": entity_name}}, {"term": {"tgt_id": entity_name}}, ] } } } # conflicts="proceed" tolerates stale search view after refresh removal. await self.client.delete_by_query( index=self._index_name, body=body, params={"conflicts": "proceed"} ) logger.debug( f"[{self.workspace}] Deleted relations for entity {entity_name}" ) except OpenSearchException as e: if _is_missing_index_error(e): self._mark_index_missing() return logger.error( f"[{self.workspace}] Error deleting relations for {entity_name}: {e}" ) async def drop(self) -> dict[str, str]: """Delete and recreate the vector index.""" try: try: await self.client.indices.delete(index=self._index_name) logger.info( f"[{self.workspace}] Dropped vector index: {self._index_name}" ) except NotFoundError: logger.info( f"[{self.workspace}] Vector index already missing during drop: {self._index_name}" ) # Recreate the index await self._create_knn_index_if_not_exists() self._index_ready = True logger.info( f"[{self.workspace}] Dropped and recreated vector index: {self._index_name}" ) return { "status": "success", "message": f"Vector index {self._index_name} dropped and recreated", } except OpenSearchException as e: self._mark_index_missing() logger.error(f"[{self.workspace}] Error dropping vector index: {e}") return {"status": "error", "message": str(e)} except Exception as e: self._mark_index_missing() logger.error( f"[{self.workspace}] Unexpected error dropping vector index: {e}" ) return {"status": "error", "message": str(e)}