Files
caoxiaozhu 68f663f2f4 feat: 重构知识库系统,移除Hermes集成,增强RAG和同步功能
主要变更:
- 移除Hermes智能体及相关回调服务
- 新增知识库RAG、同步、调度、规范化和索引任务服务
- 重构orchestrator服务,增强运行时聊天功能
- 更新前端聊天、政策制度、设置等页面样式和逻辑
- 更新expense_claims和document_intelligence服务
- 删除llm_wiki相关服务和测试文件
- 更新docker-compose配置和启动脚本
2026-05-17 08:38:41 +00:00

3046 lines
119 KiB
Python

"""
OpenSearch Storage Implementation for LightRAG
This module provides OpenSearch-based storage backends for LightRAG,
including KV storage, document status storage, graph storage, and vector storage.
Requirements:
- opensearch-py >= 3.0.0
- OpenSearch 3.x or higher with k-NN plugin enabled
"""
import os
import re
import ssl as ssl_module
import time
import asyncio
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Union, final
import numpy as np
import configparser
from ..base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
)
from ..utils import logger, compute_mdhash_id, _cooperative_yield
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock
import pipmaster as pm
if not pm.is_installed("opensearch-py"):
pm.install("opensearch-py")
from opensearchpy import AsyncOpenSearch, helpers # type: ignore
from opensearchpy.exceptions import OpenSearchException, NotFoundError, RequestError # type: ignore
# Optional file-based settings, read from the [opensearch] section by
# _get_opensearch_env. ConfigParser.read() silently skips a missing
# config.ini, leaving environment variables and fallbacks in charge.
config = configparser.ConfigParser()
config.read("config.ini", encoding="utf-8")
def _get_opensearch_env(key, fallback):
    """Resolve an OpenSearch setting: environment variable first, then config.ini.

    The config.ini key is *key* with ``OPENSEARCH_`` removed and lowercased
    (``OPENSEARCH_HOSTS`` -> ``hosts``) and is looked up in the ``[opensearch]``
    section; *fallback* is used when neither source provides a value.
    """
    ini_key = key.replace("OPENSEARCH_", "").lower()
    ini_value = config.get("opensearch", ini_key, fallback=fallback)
    return os.environ.get(key, ini_value)
def _get_index_number_of_shards() -> int:
    """Primary shard count used when creating new indices (default 1)."""
    configured = _get_opensearch_env("OPENSEARCH_NUMBER_OF_SHARDS", "1")
    return int(configured)
def _get_index_number_of_replicas() -> int:
    """Replica count used when creating new indices (default 0)."""
    configured = _get_opensearch_env("OPENSEARCH_NUMBER_OF_REPLICAS", "0")
    return int(configured)
def _sanitize_index_name(name: str) -> str:
    """Turn *name* into a valid OpenSearch index name.

    The name is lowercased and every character outside ``[a-z0-9_-]`` becomes
    ``_``. Because index names may not begin with ``-``, ``_`` or ``+``, such
    results get an ``x`` prefix. An empty input stays empty.
    """
    cleaned = re.sub(r"[^a-z0-9_-]", "_", name.lower())
    if cleaned.startswith(("-", "_", "+")):
        return f"x{cleaned}"
    return cleaned
# Detected at first connection; True when OpenSearch >= 3.3.0.
_shard_doc_supported: bool | None = None
def _pit_sort_with_field(field: str) -> list[dict]:
    """Build the PIT tiebreaker sort clause for pagination on a unique *field*.

    The order is always ascending: this clause only guarantees a stable,
    gap-free walk, and any business sort is prepended by the caller.

    - ``_shard_doc`` supported (>= 3.3.0): ``_shard_doc`` alone, which is
      already unique within a PIT.
    - Otherwise (< 3.3.0): *field* first (unique), then ``_doc``.
    """
    if not _shard_doc_supported:
        return [{field: {"order": "asc"}}, {"_doc": "asc"}]
    return [{"_shard_doc": "asc"}]
def _pit_sort_with_composite_key(*fields: str) -> list[dict]:
    """Build a PIT sort clause whose *fields* together form a unique key.

    - ``_shard_doc`` supported (>= 3.3.0): ``_shard_doc`` only; *fields* are
      not used.
    - Otherwise (< 3.3.0): each of *fields* ascending, then ``_doc``.
    """
    if _shard_doc_supported:
        return [{"_shard_doc": "asc"}]
    clauses = [{name: {"order": "asc"}} for name in fields]
    clauses.append({"_doc": "asc"})
    return clauses
async def _detect_shard_doc_support(client: AsyncOpenSearch) -> bool:
    """Probe the cluster version and report whether ``_shard_doc`` sorting exists.

    ``_shard_doc`` needs OpenSearch >= 3.3.0. Any failure while querying or
    parsing the version is logged, and the function then returns False
    (not supported).
    """
    try:
        cluster_info = await client.info()
        version_str = cluster_info.get("version", {}).get("number", "0.0.0")
        # Drop pre-release suffixes per component ("3.3.0-SNAPSHOT" -> 3, 3);
        # a non-numeric or absent component counts as 0.
        numbers = [
            int(head) if head.isdigit() else 0
            for head in (chunk.split("-")[0] for chunk in version_str.split(".")[:2])
        ]
        major = numbers[0]
        minor = numbers[1] if len(numbers) > 1 else 0
        supported = (major, minor) >= (3, 3)
        label = "supported" if supported else "not supported, using field+_doc fallback"
        logger.info(f"OpenSearch version {version_str}: _shard_doc {label}")
        return supported
    except Exception as e:
        logger.warning(
            f"Failed to detect OpenSearch version, assuming _shard_doc not supported: {e}"
        )
        return False
class ClientManager:
    """Process-wide owner of one shared, reference-counted AsyncOpenSearch client.

    Storages call ``get_client()`` when they initialize and ``release_client()``
    when they finalize. The connection is closed once the last reference is
    released. Connection settings are read through ``_get_opensearch_env``
    (environment variable first, then the ``[opensearch]`` section of
    config.ini).
    """

    _instances = {"client": None, "ref_count": 0}
    _lock = asyncio.Lock()

    @classmethod
    async def get_client(cls) -> AsyncOpenSearch:
        """Return the shared client, creating it on first use, and add a reference.

        When the client is created, this also probes the cluster version and
        stores the result in the module-level ``_shard_doc_supported`` flag.
        """
        global _shard_doc_supported
        async with cls._lock:
            if cls._instances["client"] is None:
                hosts_str = _get_opensearch_env("OPENSEARCH_HOSTS", "localhost:9200")
                hosts = [h.strip() for h in hosts_str.split(",") if h.strip()]
                username = _get_opensearch_env("OPENSEARCH_USER", "admin")
                password = _get_opensearch_env("OPENSEARCH_PASSWORD", "admin")
                use_ssl = _get_opensearch_env("OPENSEARCH_USE_SSL", "true").lower() in (
                    "true",
                    "1",
                    "yes",
                )
                verify_certs = _get_opensearch_env(
                    "OPENSEARCH_VERIFY_CERTS", "false"
                ).lower() in ("true", "1", "yes")
                timeout = int(_get_opensearch_env("OPENSEARCH_TIMEOUT", "30"))
                max_retries = int(_get_opensearch_env("OPENSEARCH_MAX_RETRIES", "3"))
                ssl_context = None
                if use_ssl and not verify_certs:
                    # TLS without verification: accept any certificate and hostname.
                    ssl_context = ssl_module.create_default_context()
                    ssl_context.check_hostname = False
                    ssl_context.verify_mode = ssl_module.CERT_NONE
                client = AsyncOpenSearch(
                    hosts=hosts,
                    http_auth=(username, password) if username else None,
                    use_ssl=use_ssl,
                    verify_certs=verify_certs,
                    ssl_context=ssl_context,
                    ssl_show_warn=False,
                    timeout=timeout,
                    max_retries=max_retries,
                    retry_on_timeout=True,
                )
                cls._instances["client"] = client
                cls._instances["ref_count"] = 0
                _shard_doc_supported = await _detect_shard_doc_support(client)
                logger.info(f"OpenSearch client connected to {hosts}")
            cls._instances["ref_count"] += 1
            return cls._instances["client"]

    @classmethod
    async def release_client(cls, client: AsyncOpenSearch):
        """Drop one reference to *client* and close it when none remain.

        A *client* that is not the currently shared instance (or None) is
        ignored. Errors raised while closing are logged. The client is
        discarded in either case, so the next ``get_client()`` builds a fresh one.
        """
        global _shard_doc_supported
        async with cls._lock:
            if client is not None and client is cls._instances["client"]:
                cls._instances["ref_count"] -= 1
                if cls._instances["ref_count"] <= 0:
                    try:
                        await cls._instances["client"].close()
                    except Exception as e:
                        # Best effort: still reset state below, but surface the
                        # failure instead of hiding it.
                        logger.warning(f"Error closing OpenSearch client: {e}")
                    cls._instances["client"] = None
                    cls._instances["ref_count"] = 0
                    # The next client may talk to a different cluster version.
                    _shard_doc_supported = None
                    logger.info("OpenSearch client connection closed")
def _resolve_workspace(workspace: str, namespace: str):
    """Pick the effective workspace.

    A non-blank ``OPENSEARCH_WORKSPACE`` environment variable (whitespace
    stripped) takes precedence and is logged. Otherwise *workspace* is
    returned unchanged. *namespace* is used only in the log message.
    """
    override = (os.environ.get("OPENSEARCH_WORKSPACE") or "").strip()
    if not override:
        return workspace
    logger.info(
        f"Using OPENSEARCH_WORKSPACE: '{override}' (overriding '{workspace}/{namespace}')"
    )
    return override
def _build_index_name(workspace: str, namespace: str) -> tuple[str, str, str]:
    """Derive the storage naming triple for a workspace/namespace pair.

    Returns ``(effective_workspace, final_namespace, index_name)``:
    ``final_namespace`` is ``"<workspace>_<namespace>"`` when a workspace is in
    effect, otherwise just *namespace* (the workspace is then normalized to
    ``""``). ``index_name`` is the sanitized ``final_namespace``.
    """
    effective = _resolve_workspace(workspace, namespace) or ""
    final_ns = f"{effective}_{namespace}" if effective else namespace
    return effective, final_ns, _sanitize_index_name(final_ns)
async def _mget_optional_doc(
    client: AsyncOpenSearch, index_name: str, doc_id: str
) -> dict[str, Any] | None:
    """Look up *doc_id* with a one-element mget.

    Returns the raw mget entry (with ``_id``/``_source``) when the document
    exists, or None when the response has no entries or reports it not found.
    """
    response = await client.mget(index=index_name, body={"ids": [doc_id]})
    docs = response.get("docs", [])
    if docs and docs[0].get("found"):
        return docs[0]
    return None
def _is_missing_index_error(exc: Exception) -> bool:
    """Tell whether *exc* reports a missing target index.

    The check looks for ``index_not_found_exception`` in the exception's
    string form.
    """
    message = str(exc)
    return message.find("index_not_found_exception") != -1
async def _verify_mirrored_id_mapping(client: AsyncOpenSearch, index_name: str) -> None:
    """Refuse to use an existing index that lacks the ``__mirrored_id`` mapping.

    This matters only when ``_shard_doc`` is unavailable (OpenSearch < 3.3.0).
    There, ``__mirrored_id`` is the cross-shard PIT pagination tiebreaker, and
    sorting a multi-shard index by an unmapped field can skip or repeat
    documents. Indices created by older LightRAG releases lack the field.

    The check is skipped if the mapping cannot be read.

    Raises:
        RuntimeError: the index mapping has no ``__mirrored_id`` property.
    """
    if _shard_doc_supported:
        return
    try:
        mapping = await client.indices.get_mapping(index=index_name)
    except OpenSearchException:
        return
    properties = (
        mapping.get(index_name, {}).get("mappings", {}).get("properties", {})
    )
    if "__mirrored_id" in properties:
        return
    raise RuntimeError(
        f"Index '{index_name}' lacks the '__mirrored_id' keyword mapping "
        f"required for stable PIT pagination on OpenSearch < 3.3.0. "
        f"This index was likely created by an older LightRAG release. "
        f"Please reindex the data, or upgrade the cluster to OpenSearch >= 3.3.0."
    )
@final
@dataclass
class OpenSearchKVStorage(BaseKVStorage):
    """Key-value storage backed by one OpenSearch index per namespace.

    The index uses dynamic mapping, so any namespace schema can be stored.
    Every document also carries ``__mirrored_id``, a keyword copy of its
    ``_id`` that serves as the PIT pagination tiebreaker when ``_shard_doc``
    is unavailable. It is stripped before documents are returned to callers.

    ``_index_ready`` records whether the index is known to exist. Reads
    short-circuit to empty results while it is False. Writes first go through
    ``_ensure_index_ready``, which recreates the index (e.g. after ``drop``).
    """

    client: AsyncOpenSearch = field(default=None)
    _index_name: str = field(default="", init=False)
    _index_ready: bool = field(default=False, init=False)

    def __init__(self, namespace, global_config, embedding_func, workspace=None):
        super().__init__(
            namespace=namespace,
            workspace=workspace or "",
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # NOTE(review): if the base dataclass __init__ already runs
        # __post_init__, this is a second call. Re-applying _build_index_name
        # appears to yield the same names, so the repeat looks harmless.
        self.__post_init__()

    def __post_init__(self):
        # Resolve the effective workspace and derive the sanitized index name.
        self.workspace, self.final_namespace, self._index_name = _build_index_name(
            self.workspace, self.namespace
        )

    async def initialize(self):
        """Acquire the shared client and create the index if it does not exist."""
        async with get_data_init_lock():
            if self.client is None:
                self.client = await ClientManager.get_client()
            await self._create_index_if_not_exists()
            self._index_ready = True
            logger.debug(
                f"[{self.workspace}] OpenSearch KV storage initialized: {self._index_name}"
            )

    async def _ensure_index_ready(self):
        """Recreate the KV index (e.g. after drop) before the next write."""
        if self._index_ready:
            return
        async with get_data_init_lock():
            if self.client is None:
                self.client = await ClientManager.get_client()
            # Re-check under the lock: another coroutine may have recreated it.
            if not self._index_ready:
                await self._create_index_if_not_exists()
                self._index_ready = True

    def _mark_index_missing(self):
        """Mark the KV index as unavailable so reads short-circuit until recreated."""
        self._index_ready = False

    async def _create_index_if_not_exists(self):
        """Create the index, or validate the ``__mirrored_id`` mapping of an existing one.

        A creation race (``resource_already_exists_exception``) is tolerated.
        Other OpenSearch errors are re-raised, and the mapping check may raise
        RuntimeError.
        """
        try:
            if not await self.client.indices.exists(index=self._index_name):
                # Use dynamic mapping so any namespace schema works
                body = {
                    "mappings": {
                        "dynamic": True,
                        "properties": {
                            "__mirrored_id": {"type": "keyword"},
                        },
                    },
                    "settings": {
                        "index": {
                            "number_of_shards": _get_index_number_of_shards(),
                            "number_of_replicas": _get_index_number_of_replicas(),
                        },
                    },
                }
                await self.client.indices.create(index=self._index_name, body=body)
                logger.info(f"[{self.workspace}] Created index: {self._index_name}")
            else:
                await _verify_mirrored_id_mapping(self.client, self._index_name)
        except RequestError as e:
            # Another process may have created the index between exists() and create().
            if "resource_already_exists_exception" not in str(e):
                raise
        except OpenSearchException as e:
            logger.error(f"[{self.workspace}] Error creating index: {e}")
            raise

    async def finalize(self):
        """Release this storage's reference to the shared OpenSearch client."""
        if self.client is not None:
            await ClientManager.release_client(self.client)
            self.client = None

    async def _iter_raw_docs(
        self, batch_size: int = 1000
    ) -> AsyncIterator[list[dict[str, Any]]]:
        """Yield batches of raw OpenSearch hits using PIT + search_after pagination.

        The PIT is always deleted afterwards (best effort). A missing index
        ends the iteration quietly, while other OpenSearch errors are logged
        and re-raised.
        """
        if not self._index_ready:
            return
        try:
            pit = await self.client.create_pit(
                index=self._index_name, params={"keep_alive": "1m"}
            )
            pit_id = pit["pit_id"]
            try:
                search_after = None
                while True:
                    body = {
                        "query": {"match_all": {}},
                        "size": batch_size,
                        "pit": {"id": pit_id, "keep_alive": "1m"},
                        "sort": _pit_sort_with_field("__mirrored_id"),
                    }
                    if search_after:
                        body["search_after"] = search_after
                    response = await self.client.search(body=body)
                    hits = response["hits"]["hits"]
                    if not hits:
                        break
                    yield hits
                    search_after = hits[-1]["sort"]
                    if len(hits) < batch_size:
                        break
            finally:
                try:
                    await self.client.delete_pit(body={"pit_id": [pit_id]})
                except Exception:
                    pass
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return
            logger.error(f"[{self.workspace}] Error scanning documents: {e}")
            raise

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Fetch one document by ID, or None if it is absent or the lookup fails.

        The returned dict includes ``_id``, and ``create_time``/``update_time``
        default to 0 when the document lacks them.
        """
        if not self._index_ready:
            return None
        try:
            response = await _mget_optional_doc(self.client, self._index_name, id)
            if response is None:
                return None
            doc = response["_source"]
            doc.pop("__mirrored_id", None)
            doc["_id"] = response["_id"]
            doc.setdefault("create_time", 0)
            doc.setdefault("update_time", 0)
            return doc
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return None
            logger.error(f"[{self.workspace}] Error getting document {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any] | None]:
        """Fetch several documents, returning a list aligned with *ids*.

        Entries for missing documents are None. Every entry is None when the
        index is unavailable or the lookup fails.
        """
        if not self._index_ready:
            return [None] * len(ids)
        if not ids:
            # An mget with no IDs fails request validation; skip the round trip.
            return []
        try:
            response = await self.client.mget(index=self._index_name, body={"ids": ids})
            doc_map = {}
            for doc in response["docs"]:
                if doc.get("found"):
                    data = doc["_source"]
                    data.pop("__mirrored_id", None)
                    data["_id"] = doc["_id"]
                    data.setdefault("create_time", 0)
                    data.setdefault("update_time", 0)
                    doc_map[doc["_id"]] = data
            return [doc_map.get(id) for id in ids]
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return [None] * len(ids)
            logger.error(f"[{self.workspace}] Error getting documents: {e}")
            return [None] * len(ids)

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Return the subset of *keys* that are not stored.

        If the lookup cannot be performed, all *keys* are reported as missing.
        """
        if not self._index_ready:
            return keys
        if not keys:
            # An mget with no IDs fails request validation; nothing to filter.
            return set()
        try:
            response = await self.client.mget(
                index=self._index_name, body={"ids": list(keys)}, _source=False
            )
            existing_ids = {doc["_id"] for doc in response["docs"] if doc.get("found")}
            return keys - existing_ids
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return keys
            logger.error(f"[{self.workspace}] Error filtering keys: {e}")
            return keys

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Index (insert or fully replace) documents, stamping timestamps.

        Each value in *data* is mutated in place. ``update_time`` is set to
        now, and ``create_time`` is set to now only if absent. Bulk
        item failures are counted and logged. OpenSearch transport errors are
        logged and re-raised.
        """
        if not data:
            return
        await self._ensure_index_ready()
        logger.debug(
            f"[{self.workspace}] Upserting {len(data)} documents to {self.namespace}"
        )
        current_time = int(time.time())
        actions = []
        for i, (doc_id, doc_data) in enumerate(data.items(), start=1):
            doc_data["update_time"] = current_time
            doc_data.setdefault("create_time", current_time)
            source = {k: v for k, v in doc_data.items() if k != "_id"}
            source["__mirrored_id"] = doc_id
            actions.append(
                {
                    "_op_type": "index",
                    "_index": self._index_name,
                    "_id": doc_id,
                    "_source": source,
                }
            )
            await _cooperative_yield(i)
        try:
            # No per-operation refresh: immediate reads use ID-based mget (translog),
            # search visibility is guaranteed after index_done_callback() batch refresh.
            success, failed = await helpers.async_bulk(
                self.client, actions, raise_on_error=False
            )
            if failed:
                logger.warning(
                    f"[{self.workspace}] {len(failed)} documents failed to upsert"
                )
        except OpenSearchException as e:
            logger.error(f"[{self.workspace}] Error upserting documents: {e}")
            raise

    async def index_done_callback(self) -> None:
        """Refresh the index so recently written documents become searchable.

        Failures are logged rather than raised. A missing index only marks
        the storage as not ready.
        """
        if not self._index_ready:
            return
        try:
            await self.client.indices.refresh(index=self._index_name)
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return
            logger.warning(
                f"[{self.workspace}] Error refreshing index {self._index_name}: {e}"
            )
        except Exception as e:
            logger.warning(
                f"[{self.workspace}] Unexpected error refreshing index {self._index_name}: {e}"
            )

    async def is_empty(self) -> bool:
        """Return True if the index holds no documents.

        Also returns True when the index is missing or the count fails; such
        failures are logged.
        """
        if not self._index_ready:
            return True
        try:
            response = await self.client.count(index=self._index_name)
            return response["count"] == 0
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
            else:
                logger.warning(
                    f"[{self.workspace}] Error counting documents in {self._index_name}: {e}"
                )
        return True

    async def delete(self, ids: list[str]) -> None:
        """Delete documents by ID (a set is also accepted). Errors are logged."""
        if not ids:
            return
        if not self._index_ready:
            return
        if isinstance(ids, set):
            ids = list(ids)
        try:
            # No per-operation refresh: immediate reads use ID-based mget (translog),
            # search visibility is guaranteed after index_done_callback() batch refresh.
            actions = [
                {"_op_type": "delete", "_index": self._index_name, "_id": doc_id}
                for doc_id in ids
            ]
            success, _ = await helpers.async_bulk(
                self.client, actions, raise_on_error=False
            )
            logger.info(
                f"[{self.workspace}] Deleted {success} documents from {self.namespace}"
            )
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return
            logger.error(f"[{self.workspace}] Error deleting documents: {e}")

    async def drop(self) -> dict[str, str]:
        """Delete the whole index and report ``{"status", "message"}``.

        An already-missing index counts as success. In every outcome the
        storage is marked not ready, so the next write recreates the index.
        """
        try:
            try:
                await self.client.indices.delete(index=self._index_name)
                logger.info(f"[{self.workspace}] Dropped index: {self._index_name}")
            except NotFoundError:
                logger.info(
                    f"[{self.workspace}] Index already missing during drop: {self._index_name}"
                )
            self._mark_index_missing()
            return {"status": "success", "message": f"Index {self._index_name} dropped"}
        except OpenSearchException as e:
            self._mark_index_missing()
            logger.error(f"[{self.workspace}] Error dropping index: {e}")
            return {"status": "error", "message": str(e)}
        except Exception as e:
            self._mark_index_missing()
            logger.error(f"[{self.workspace}] Unexpected error dropping index: {e}")
            return {"status": "error", "message": str(e)}
@final
@dataclass
class OpenSearchDocStatusStorage(DocStatusStorage):
    """Document processing-status storage backed by one OpenSearch index.

    ``status``, ``file_path``, ``track_id`` and ``__mirrored_id`` are mapped
    as keywords, and ``created_at``/``updated_at`` as dates. Other fields are
    mapped dynamically. ``__mirrored_id`` (a copy of ``_id`` used as the PIT
    pagination tiebreaker) is internal and is never returned to callers.

    ``_index_ready`` records whether the index is known to exist. Reads
    short-circuit to empty results while it is False. Writes first go through
    ``_ensure_index_ready``, which recreates the index (e.g. after ``drop``).
    """

    client: AsyncOpenSearch = field(default=None)
    _index_name: str = field(default="", init=False)
    _index_ready: bool = field(default=False, init=False)

    def __init__(self, namespace, global_config, embedding_func, workspace=None):
        super().__init__(
            namespace=namespace,
            workspace=workspace or "",
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # NOTE(review): if the base dataclass __init__ already runs
        # __post_init__, this is a second call. Re-applying _build_index_name
        # appears to yield the same names, so the repeat looks harmless.
        self.__post_init__()

    def __post_init__(self):
        # Resolve the effective workspace and derive the sanitized index name.
        self.workspace, self.final_namespace, self._index_name = _build_index_name(
            self.workspace, self.namespace
        )

    def _prepare_doc_status_data(self, doc: dict[str, Any]) -> dict[str, Any]:
        """Normalize a raw ``_source`` dict into ``DocProcessingStatus`` kwargs.

        Works on a shallow copy. It removes ``_id`` and ``__mirrored_id``, and
        fills defaults for ``file_path`` ("no-file-path"), ``metadata`` ({}) and
        ``error_msg`` (None). An ``error`` field is moved into ``error_msg`` when
        ``error_msg`` is empty; otherwise it is dropped.
        """
        data = doc.copy()
        data.pop("_id", None)
        data.pop("__mirrored_id", None)
        if "file_path" not in data:
            data["file_path"] = "no-file-path"
        data.setdefault("metadata", {})
        data.setdefault("error_msg", None)
        if "error" in data:
            if not data.get("error_msg"):
                data["error_msg"] = data.pop("error")
            else:
                data.pop("error", None)
        return data

    async def initialize(self):
        """Acquire the shared client and create the doc status index if needed."""
        async with get_data_init_lock():
            if self.client is None:
                self.client = await ClientManager.get_client()
            await self._create_index_if_not_exists()
            self._index_ready = True
            logger.debug(
                f"[{self.workspace}] OpenSearch DocStatus storage initialized: {self._index_name}"
            )

    async def _ensure_index_ready(self):
        """Recreate the doc status index (e.g. after drop) before the next write."""
        if self._index_ready:
            return
        async with get_data_init_lock():
            if self.client is None:
                self.client = await ClientManager.get_client()
            # Re-check under the lock: another coroutine may have recreated it.
            if not self._index_ready:
                await self._create_index_if_not_exists()
                self._index_ready = True

    def _mark_index_missing(self):
        """Mark the doc status index as unavailable so reads short-circuit."""
        self._index_ready = False

    async def _create_index_if_not_exists(self):
        """Create the index with explicit field mappings, or validate an existing one.

        A creation race (``resource_already_exists_exception``) is tolerated.
        Other OpenSearch errors are re-raised, and the mapping check may raise
        RuntimeError.
        """
        try:
            if not await self.client.indices.exists(index=self._index_name):
                body = {
                    "mappings": {
                        "dynamic": True,
                        "properties": {
                            "__mirrored_id": {"type": "keyword"},
                            "status": {"type": "keyword"},
                            "file_path": {"type": "keyword"},
                            "track_id": {"type": "keyword"},
                            "created_at": {"type": "date"},
                            "updated_at": {"type": "date"},
                        },
                    },
                    "settings": {
                        "index": {
                            "number_of_shards": _get_index_number_of_shards(),
                            "number_of_replicas": _get_index_number_of_replicas(),
                        },
                    },
                }
                await self.client.indices.create(index=self._index_name, body=body)
                logger.info(
                    f"[{self.workspace}] Created doc status index: {self._index_name}"
                )
            else:
                await _verify_mirrored_id_mapping(self.client, self._index_name)
        except RequestError as e:
            # Another process may have created the index between exists() and create().
            if "resource_already_exists_exception" not in str(e):
                raise
        except OpenSearchException as e:
            logger.error(f"[{self.workspace}] Error creating doc status index: {e}")
            raise

    async def finalize(self):
        """Release this storage's reference to the shared OpenSearch client."""
        if self.client is not None:
            await ClientManager.release_client(self.client)
            self.client = None

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        """Fetch one status record (with ``_id``), or None if absent or on error."""
        if not self._index_ready:
            return None
        try:
            response = await _mget_optional_doc(self.client, self._index_name, id)
            if response is None:
                return None
            doc = response["_source"]
            # Internal pagination field; keep it out of caller-facing records.
            doc.pop("__mirrored_id", None)
            doc["_id"] = response["_id"]
            return doc
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return None
            logger.error(f"[{self.workspace}] Error getting doc status {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any] | None]:
        """Fetch several status records, returning a list aligned with *ids*.

        Entries for missing records are None. Every entry is None when the
        index is unavailable or the lookup fails.
        """
        if not self._index_ready:
            return [None] * len(ids)
        if not ids:
            # An mget with no IDs fails request validation; skip the round trip.
            return []
        try:
            response = await self.client.mget(index=self._index_name, body={"ids": ids})
            doc_map = {}
            for doc in response["docs"]:
                if doc.get("found"):
                    data = doc["_source"]
                    data.pop("__mirrored_id", None)
                    data["_id"] = doc["_id"]
                    doc_map[doc["_id"]] = data
            return [doc_map.get(id) for id in ids]
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return [None] * len(ids)
            logger.error(f"[{self.workspace}] Error getting doc statuses: {e}")
            return [None] * len(ids)

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Return the subset of *keys* that are not stored.

        If the lookup cannot be performed, all *keys* are reported as missing.
        """
        if not self._index_ready:
            return keys
        if not keys:
            # An mget with no IDs fails request validation; nothing to filter.
            return set()
        try:
            response = await self.client.mget(
                index=self._index_name, body={"ids": list(keys)}, _source=False
            )
            existing_ids = {doc["_id"] for doc in response["docs"] if doc.get("found")}
            return keys - existing_ids
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return keys
            logger.error(f"[{self.workspace}] Error filtering keys: {e}")
            return keys

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Index (insert or fully replace) status records and wait for searchability.

        Each value in *data* is mutated in place to default ``chunks_list`` to
        []. Bulk item failures are counted and logged. OpenSearch transport
        errors are logged, not raised.
        """
        if not data:
            return
        await self._ensure_index_ready()
        logger.debug(f"[{self.workspace}] Upserting {len(data)} doc statuses")
        actions = []
        for i, (k, v) in enumerate(data.items(), start=1):
            v.setdefault("chunks_list", [])
            source = {fk: fv for fk, fv in v.items() if fk != "_id"}
            source["__mirrored_id"] = k
            actions.append(
                {
                    "_op_type": "index",
                    "_index": self._index_name,
                    "_id": k,
                    "_source": source,
                }
            )
            await _cooperative_yield(i)
        try:
            # DocStatus needs refresh="wait_for" because get_docs_by_status
            # (search-based) is called immediately after enqueue upserts.
            _, failed = await helpers.async_bulk(
                self.client, actions, raise_on_error=False, refresh="wait_for"
            )
            if failed:
                logger.warning(
                    f"[{self.workspace}] {len(failed)} doc statuses failed to upsert"
                )
        except OpenSearchException as e:
            logger.error(f"[{self.workspace}] Error upserting doc statuses: {e}")

    async def get_status_counts(self) -> dict[str, int]:
        """Count records per ``status`` value (up to 100 distinct statuses)."""
        if not self._index_ready:
            return {}
        try:
            body = {
                "size": 0,
                "aggs": {"status_counts": {"terms": {"field": "status", "size": 100}}},
            }
            response = await self.client.search(index=self._index_name, body=body)
            return {
                bucket["key"]: bucket["doc_count"]
                for bucket in response["aggregations"]["status_counts"]["buckets"]
            }
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return {}
            logger.error(f"[{self.workspace}] Error getting status counts: {e}")
            return {}

    async def _search_all_docs(self, query: dict) -> dict[str, DocProcessingStatus]:
        """Collect every record matching *query* using PIT + search_after.

        Records that fail to convert to ``DocProcessingStatus`` are logged and
        skipped. On a non-missing-index error, the records gathered so far are
        returned.
        """
        if not self._index_ready:
            return {}
        result = {}
        batch_size = 10000
        try:
            pit = await self.client.create_pit(
                index=self._index_name, params={"keep_alive": "1m"}
            )
            pit_id = pit["pit_id"]
            try:
                search_after = None
                while True:
                    body = {
                        "query": query,
                        "size": batch_size,
                        "pit": {"id": pit_id, "keep_alive": "1m"},
                        "sort": _pit_sort_with_field("__mirrored_id"),
                    }
                    if search_after:
                        body["search_after"] = search_after
                    response = await self.client.search(body=body)
                    hits = response["hits"]["hits"]
                    if not hits:
                        break
                    for hit in hits:
                        try:
                            data = self._prepare_doc_status_data(hit["_source"])
                            result[hit["_id"]] = DocProcessingStatus(**data)
                        except (KeyError, TypeError) as e:
                            logger.error(
                                f"[{self.workspace}] Error parsing doc {hit['_id']}: {e}"
                            )
                    search_after = hits[-1]["sort"]
                    if len(hits) < batch_size:
                        break
            finally:
                try:
                    await self.client.delete_pit(body={"pit_id": [pit_id]})
                except Exception:
                    pass
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return {}
            logger.error(f"[{self.workspace}] Error fetching docs: {e}")
        return result

    async def get_docs_by_status(
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents with the given processing status."""
        return await self.get_docs_by_statuses([status])

    async def get_docs_by_statuses(
        self, statuses: list[DocStatus]
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents whose status is any of *statuses*, in one scan.

        A single ``terms`` query (the multi-value form of ``term``) replaces one
        full PIT scan per status.
        """
        if not statuses:
            return {}
        status_values = [s.value for s in statuses]
        return await self._search_all_docs({"terms": {"status": status_values}})

    async def get_docs_by_track_id(
        self, track_id: str
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents tagged with *track_id*."""
        return await self._search_all_docs({"term": {"track_id": track_id}})

    async def get_docs_paginated(
        self,
        status_filter: DocStatus | None = None,
        page: int = 1,
        page_size: int = 50,
        sort_field: str = "updated_at",
        sort_direction: str = "desc",
    ) -> tuple[list[tuple[str, DocProcessingStatus]], int]:
        """Return one page of ``(doc_id, status)`` pairs plus the total match count.

        Args:
            status_filter: restrict to this status; None means all documents.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: clamped to the range [10, 200].
            sort_field: "created_at", "updated_at", "file_path" or "id"/"_id";
                anything else falls back to "updated_at".
            sort_direction: "asc" for ascending; anything else is descending.

        Earlier pages are walked with PIT + search_after, fetching only sort
        values for the skipped hits. Errors yield ``([], 0)``.
        """
        if not self._index_ready:
            return [], 0
        page = max(1, page)
        page_size = max(10, min(200, page_size))
        if sort_field == "id":
            sort_field = "_id"
        if sort_field not in ("created_at", "updated_at", "_id", "file_path"):
            sort_field = "updated_at"
        sort_order = "asc" if sort_direction.lower() == "asc" else "desc"
        query = {"match_all": {}}
        if status_filter is not None:
            query = {"term": {"status": status_filter.value}}
        skip_count = (page - 1) * page_size
        # Skipped hits carry no _source, so they can be walked in larger batches.
        skip_batch_size = 1000
        try:
            count_resp = await self.client.count(
                index=self._index_name, body={"query": query}
            )
            total_count = count_resp.get("count", 0)
            if total_count == 0 or skip_count >= total_count:
                return [], total_count
            sort_clause = [{sort_field: {"order": sort_order}}] + _pit_sort_with_field(
                "__mirrored_id"
            )
            pit = await self.client.create_pit(
                index=self._index_name, params={"keep_alive": "1m"}
            )
            pit_id = pit["pit_id"]
            try:
                search_after = None
                skipped = 0
                while skipped < skip_count:
                    body = {
                        "query": query,
                        "sort": sort_clause,
                        "size": min(skip_batch_size, skip_count - skipped),
                        "pit": {"id": pit_id, "keep_alive": "1m"},
                        # Only the sort values are needed to advance the cursor.
                        "_source": False,
                    }
                    if search_after:
                        body["search_after"] = search_after
                    resp = await self.client.search(body=body)
                    hits = resp["hits"]["hits"]
                    if not hits:
                        return [], total_count
                    search_after = hits[-1]["sort"]
                    skipped += len(hits)
                body = {
                    "query": query,
                    "sort": sort_clause,
                    "size": page_size,
                    "pit": {"id": pit_id, "keep_alive": "1m"},
                }
                if search_after:
                    body["search_after"] = search_after
                response = await self.client.search(body=body)
            finally:
                try:
                    await self.client.delete_pit(body={"pit_id": [pit_id]})
                except Exception:
                    pass
            documents = []
            for hit in response["hits"]["hits"]:
                try:
                    data = self._prepare_doc_status_data(hit["_source"])
                    documents.append((hit["_id"], DocProcessingStatus(**data)))
                except (KeyError, TypeError) as e:
                    logger.error(
                        f"[{self.workspace}] Error parsing doc {hit['_id']}: {e}"
                    )
            return documents, total_count
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return [], 0
            logger.error(f"[{self.workspace}] Error in paginated query: {e}")
            return [], 0

    async def get_all_status_counts(self) -> dict[str, int]:
        """Count records per status, plus an ``"all"`` entry summing the buckets."""
        if not self._index_ready:
            return {}
        try:
            body = {
                "size": 0,
                "aggs": {"status_counts": {"terms": {"field": "status", "size": 100}}},
            }
            response = await self.client.search(index=self._index_name, body=body)
            counts = {}
            total = 0
            for bucket in response["aggregations"]["status_counts"]["buckets"]:
                counts[bucket["key"]] = bucket["doc_count"]
                total += bucket["doc_count"]
            counts["all"] = total
            return counts
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return {}
            logger.error(f"[{self.workspace}] Error getting all status counts: {e}")
            return {}

    async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
        """Return the first record whose ``file_path`` matches exactly, or None."""
        if not self._index_ready:
            return None
        try:
            body = {"query": {"term": {"file_path": file_path}}, "size": 1}
            response = await self.client.search(index=self._index_name, body=body)
            hits = response["hits"]["hits"]
            if hits:
                doc = hits[0]["_source"]
                # Internal pagination field; keep it out of caller-facing records.
                doc.pop("__mirrored_id", None)
                doc["_id"] = hits[0]["_id"]
                return doc
            return None
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return None
            logger.error(f"[{self.workspace}] Error getting doc by file_path: {e}")
            return None

    async def index_done_callback(self) -> None:
        """Refresh the index so recently written records become searchable.

        Failures are logged rather than raised. A missing index only marks
        the storage as not ready.
        """
        if not self._index_ready:
            return
        try:
            await self.client.indices.refresh(index=self._index_name)
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return
            logger.warning(
                f"[{self.workspace}] Error refreshing doc status index {self._index_name}: {e}"
            )
        except Exception as e:
            logger.warning(
                f"[{self.workspace}] Unexpected error refreshing doc status index {self._index_name}: {e}"
            )

    async def is_empty(self) -> bool:
        """Return True if the index holds no records.

        Also returns True when the index is missing or the count fails; such
        failures are logged.
        """
        if not self._index_ready:
            return True
        try:
            response = await self.client.count(index=self._index_name)
            return response["count"] == 0
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
            else:
                logger.warning(
                    f"[{self.workspace}] Error counting doc statuses in {self._index_name}: {e}"
                )
        return True

    async def delete(self, ids: list[str]) -> None:
        """Delete status records by ID (a set is also accepted). Errors are logged."""
        if not ids:
            return
        if not self._index_ready:
            return
        if isinstance(ids, set):
            ids = list(ids)
        try:
            # DocStatus needs refresh="wait_for" because downstream readers
            # (get_docs_by_status, get_docs_paginated, etc.) are search-based
            # and callers like _validate_and_fix_document_consistency() may
            # query immediately after deletion without index_done_callback().
            actions = [
                {"_op_type": "delete", "_index": self._index_name, "_id": doc_id}
                for doc_id in ids
            ]
            await helpers.async_bulk(
                self.client, actions, raise_on_error=False, refresh="wait_for"
            )
        except OpenSearchException as e:
            if _is_missing_index_error(e):
                self._mark_index_missing()
                return
            logger.error(f"[{self.workspace}] Error deleting doc statuses: {e}")

    async def drop(self) -> dict[str, str]:
        """Delete the whole doc status index and report ``{"status", "message"}``.

        An already-missing index counts as success. In every outcome the
        storage is marked not ready, so the next write recreates the index.
        """
        try:
            try:
                await self.client.indices.delete(index=self._index_name)
                logger.info(
                    f"[{self.workspace}] Dropped doc status index: {self._index_name}"
                )
            except NotFoundError:
                logger.info(
                    f"[{self.workspace}] Doc status index already missing during drop: {self._index_name}"
                )
            self._mark_index_missing()
            return {"status": "success", "message": f"Index {self._index_name} dropped"}
        except OpenSearchException as e:
            self._mark_index_missing()
            logger.error(f"[{self.workspace}] Error dropping doc status index: {e}")
            return {"status": "error", "message": str(e)}
        except Exception as e:
            self._mark_index_missing()
            logger.error(
                f"[{self.workspace}] Unexpected error dropping doc status index: {e}"
            )
            return {"status": "error", "message": str(e)}
@final
@dataclass
class OpenSearchGraphStorage(BaseGraphStorage):
"""Graph storage using OpenSearch with separate nodes and edges indices.
Supports two BFS traversal strategies:
- PPL graphlookup (server-side BFS, requires OpenSearch SQL plugin with Calcite engine)
- Application-level batched BFS (fallback, works on any OpenSearch 3.x+)
The strategy is auto-detected during initialize() and can be overridden via
the OPENSEARCH_USE_PPL_GRAPHLOOKUP environment variable (true/false).
"""
client: AsyncOpenSearch = field(default=None)
_nodes_index: str = field(default="", init=False)
_edges_index: str = field(default="", init=False)
_indices_ready: bool = field(default=False, init=False)
_nodes_dirty: bool = field(default=False, init=False)
_edges_dirty: bool = field(default=False, init=False)
_ppl_graphlookup_available: bool = field(default=False, init=False)
def __init__(self, namespace, global_config, embedding_func, workspace=None):
    # Hand-written constructor: forward the shared storage fields to the base
    # class (a None workspace becomes ""), then derive the index names.
    base_fields = {
        "namespace": namespace,
        "workspace": workspace if workspace else "",
        "global_config": global_config,
        "embedding_func": embedding_func,
    }
    super().__init__(**base_fields)
    self.__post_init__()
def __post_init__(self):
    # Normalise workspace/namespace via _build_index_name and derive the two
    # per-storage index names from the returned base name.
    workspace, final_namespace, base_name = _build_index_name(
        self.workspace, self.namespace
    )
    self.workspace = workspace
    self.final_namespace = final_namespace
    self._nodes_index, self._edges_index = (
        f"{base_name}-nodes",
        f"{base_name}-edges",
    )
async def initialize(self):
    """Acquire a client, create the graph indices and probe PPL support.

    Everything runs under the shared data-init lock. Afterwards the indices
    are flagged ready and both search views are flagged clean.
    """
    async with get_data_init_lock():
        if self.client is None:
            self.client = await ClientManager.get_client()
        await self._create_indices_if_not_exist()
        self._indices_ready = True
        self._nodes_dirty = self._edges_dirty = False
        await self._detect_ppl_graphlookup()
        logger.debug(
            f"[{self.workspace}] OpenSearch Graph storage initialized: "
            f"{self._nodes_index}, {self._edges_index} "
            f"(PPL graphlookup: {self._ppl_graphlookup_available})"
        )
async def _ensure_indices_ready(self):
    """Lazily (re)create the graph indices before a write, e.g. after drop().

    Checks the flag, takes the shared data-init lock, then checks again, so
    concurrent callers create the indices at most once per invalidation.
    """
    if not self._indices_ready:
        async with get_data_init_lock():
            if self.client is None:
                self.client = await ClientManager.get_client()
            if self._indices_ready:
                return
            await self._create_indices_if_not_exist()
            self._indices_ready = True
def _mark_indices_missing(self):
    """Flag the graph indices as gone so reads short-circuit until recreated.

    Pending dirty flags are cleared too: there is nothing left to refresh.
    """
    self._indices_ready = self._nodes_dirty = self._edges_dirty = False
async def _refresh_graph_indices_if_dirty(
    self, *, refresh_nodes: bool = False, refresh_edges: bool = False
) -> None:
    """Refresh graph indices only when prior writes made search views stale.

    Args:
        refresh_nodes: Refresh the nodes index if it is flagged dirty.
        refresh_edges: Refresh the edges index if it is flagged dirty.

    Each dirty flag is cleared *before* its refresh is awaited. A write that
    completes while the refresh is in flight therefore re-sets the flag, and
    the next reader refreshes again. Clearing after the await would erase
    that write's flag and leave it invisible to search. If the refresh
    fails, the flag is restored.

    A missing-index error marks the indices missing and returns; any other
    OpenSearch error is re-raised.
    """
    if not self._indices_ready:
        return
    if not (
        (refresh_nodes and self._nodes_dirty)
        or (refresh_edges and self._edges_dirty)
    ):
        return
    try:
        async with get_data_init_lock():
            if refresh_nodes and self._nodes_dirty:
                self._nodes_dirty = False
                try:
                    await self.client.indices.refresh(index=self._nodes_index)
                except BaseException:
                    self._nodes_dirty = True
                    raise
            if refresh_edges and self._edges_dirty:
                self._edges_dirty = False
                try:
                    await self.client.indices.refresh(index=self._edges_index)
                except BaseException:
                    self._edges_dirty = True
                    raise
    except OpenSearchException as e:
        if _is_missing_index_error(e):
            self._mark_indices_missing()
            return
        raise
async def _detect_ppl_graphlookup(self):
    """Decide whether BFS traversal may use server-side PPL ``graphLookup``.

    OPENSEARCH_USE_PPL_GRAPHLOOKUP set to "true" or "false" (case-insensitive,
    surrounding whitespace ignored) forces the choice. Any other value,
    including unset, triggers auto-detection: two probe PPL queries run
    against the edges index, a plain one and one using graphLookup. Any
    probe failure disables the PPL path. This is deliberate best effort, but
    the failure reason is logged at debug level so it can be diagnosed.
    """
    env_override = (
        os.environ.get("OPENSEARCH_USE_PPL_GRAPHLOOKUP", "").strip().lower()
    )
    if env_override == "true":
        self._ppl_graphlookup_available = True
        return
    if env_override == "false":
        self._ppl_graphlookup_available = False
        return
    # Auto-detect by sending a minimal PPL query
    try:
        await self.client.transport.perform_request(
            "POST",
            "/_plugins/_ppl",
            body={"query": f"source = {self._edges_index} | head 0"},
        )
        # PPL endpoint works; now test graphlookup syntax with a no-op query
        await self.client.transport.perform_request(
            "POST",
            "/_plugins/_ppl",
            body={
                "query": (
                    f"source = {self._edges_index} | head 1 "
                    f"| graphLookup {self._edges_index} "
                    f"start=source_node_id edge=target_node_id-->source_node_id "
                    f"maxDepth=0 as _gl_probe"
                )
            },
        )
        self._ppl_graphlookup_available = True
        logger.info(
            f"[{self.workspace}] PPL graphlookup is available, using server-side BFS"
        )
    except Exception as e:
        self._ppl_graphlookup_available = False
        logger.info(
            f"[{self.workspace}] PPL graphlookup not available, using client-side BFS"
        )
        logger.debug(f"[{self.workspace}] PPL graphlookup probe failed: {e}")
async def _create_indices_if_not_exist(self):
    """Create the nodes and edges indices if they do not exist yet.

    The two indices are handled independently, nodes first. Losing a
    creation race (``resource_already_exists_exception``) is tolerated; any
    other request error is re-raised.
    """

    async def _ensure(index_name: str, properties: dict, label: str) -> None:
        # Create one index with the given mapping properties unless present.
        try:
            if await self.client.indices.exists(index=index_name):
                return
            index_body = {
                "mappings": {"dynamic": True, "properties": properties},
                "settings": {
                    "index": {
                        "number_of_shards": _get_index_number_of_shards(),
                        "number_of_replicas": _get_index_number_of_replicas(),
                    }
                },
            }
            await self.client.indices.create(index=index_name, body=index_body)
            logger.info(f"[{self.workspace}] Created {label} index: {index_name}")
        except RequestError as err:
            if "resource_already_exists_exception" not in str(err):
                raise

    await _ensure(
        self._nodes_index,
        {
            "entity_id": {"type": "keyword"},
            "entity_type": {"type": "keyword"},
            "description": {"type": "text"},
            "source_id": {"type": "text"},
            "source_ids": {"type": "keyword"},
            "file_path": {"type": "keyword"},
            "created_at": {"type": "long"},
        },
        "nodes",
    )
    await _ensure(
        self._edges_index,
        {
            "source_node_id": {"type": "keyword"},
            "target_node_id": {"type": "keyword"},
            "relationship": {"type": "keyword"},
            "description": {"type": "text"},
            "weight": {"type": "float"},
            "keywords": {"type": "text"},
            "source_id": {"type": "text"},
            "source_ids": {"type": "keyword"},
            "file_path": {"type": "keyword"},
            "created_at": {"type": "long"},
        },
        "edges",
    )
async def finalize(self):
    """Hand the OpenSearch client back to ClientManager, if one is held."""
    if self.client is None:
        return
    await ClientManager.release_client(self.client)
    self.client = None
# --- Basic queries ---
async def has_node(self, node_id: str) -> bool:
    """Return True if a node document with ID *node_id* exists.

    Returns False when the indices are known to be missing or on any
    OpenSearch error. A missing-index error also marks the indices missing.
    """
    if self._indices_ready:
        try:
            return await self.client.exists(index=self._nodes_index, id=node_id)
        except OpenSearchException as err:
            if _is_missing_index_error(err):
                self._mark_indices_missing()
    return False
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Return True if an edge joins the two nodes, in either direction.

    The edge may be stored under the forward or the reverse deterministic
    ID, so both are fetched in one ``mget``. This is an ID-based lookup and
    does not depend on the index refresh cycle. Returns False when the
    indices are missing or on any OpenSearch error.
    """
    if not self._indices_ready:
        return False
    try:
        candidate_ids = [
            compute_mdhash_id(f"{first}-{second}", prefix="edge-")
            for first, second in (
                (source_node_id, target_node_id),
                (target_node_id, source_node_id),
            )
        ]
        response = await self.client.mget(
            index=self._edges_index, body={"ids": candidate_ids}
        )
        docs = response.get("docs", [])
        return any(doc.get("found") for doc in docs)
    except OpenSearchException as err:
        if _is_missing_index_error(err):
            self._mark_indices_missing()
        return False
async def node_degree(self, node_id: str) -> int:
    """Count the edge documents that have *node_id* as source or target.

    Refreshes the edges index first if it is dirty. Returns 0 when the
    indices are missing or on any OpenSearch error.
    """
    if not self._indices_ready:
        return 0
    touches_node = {
        "bool": {
            "should": [
                {"term": {"source_node_id": node_id}},
                {"term": {"target_node_id": node_id}},
            ]
        }
    }
    try:
        await self._refresh_graph_indices_if_dirty(refresh_edges=True)
        counted = await self.client.count(
            index=self._edges_index, body={"query": touches_node}
        )
        return counted.get("count", 0)
    except OpenSearchException as err:
        if _is_missing_index_error(err):
            self._mark_indices_missing()
        return 0
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return the combined degree of the edge's two endpoints.

    Endpoint degrees come from node_degree(), queried source first.
    """
    endpoint_degrees = [await self.node_degree(node) for node in (src_id, tgt_id)]
    return sum(endpoint_degrees)
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get a node document by ID, or None if not found."""
if not self._indices_ready:
return None
try:
response = await _mget_optional_doc(self.client, self._nodes_index, node_id)
if response is None:
return None
doc = response["_source"]
doc["_id"] = response["_id"]
return doc
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_indices_missing()
return None
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get an edge between two nodes (bidirectional), or None.
Uses mget with the two candidate edge IDs so the read is real-time
(translog-backed), consistent with get_node() and independent of the
index refresh cycle.
"""
if not self._indices_ready:
return None
try:
forward_id = compute_mdhash_id(
f"{source_node_id}-{target_node_id}", prefix="edge-"
)
reverse_id = compute_mdhash_id(
f"{target_node_id}-{source_node_id}", prefix="edge-"
)
response = await self.client.mget(
index=self._edges_index, body={"ids": [forward_id, reverse_id]}
)
for doc in response.get("docs", []):
if doc.get("found"):
result = doc["_source"]
result["_id"] = doc["_id"]
return result
return None
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_indices_missing()
return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Get all (source, target) edge tuples connected to a node."""
if not self._indices_ready:
return None
try:
await self._refresh_graph_indices_if_dirty(refresh_edges=True)
query = {
"bool": {
"should": [
{"term": {"source_node_id": source_node_id}},
{"term": {"target_node_id": source_node_id}},
]
}
}
edges = []
pit = await self.client.create_pit(
index=self._edges_index, params={"keep_alive": "1m"}
)
pit_id = pit["pit_id"]
try:
search_after = None
while True:
body = {
"query": query,
"_source": ["source_node_id", "target_node_id"],
"size": 10000,
"pit": {"id": pit_id, "keep_alive": "1m"},
"sort": _pit_sort_with_composite_key(
"source_node_id", "target_node_id"
),
}
if search_after:
body["search_after"] = search_after
response = await self.client.search(body=body)
hits = response["hits"]["hits"]
if not hits:
break
for hit in hits:
edges.append(
(
hit["_source"]["source_node_id"],
hit["_source"]["target_node_id"],
)
)
search_after = hits[-1]["sort"]
if len(hits) < 10000:
break
finally:
try:
await self.client.delete_pit(body={"pit_id": [pit_id]})
except Exception:
pass
return edges
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_indices_missing()
return None
# --- Batch operations ---
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """Batch-fetch node documents by ID with a single ``mget``.

    Args:
        node_ids: Node IDs to fetch.

    Returns:
        Mapping of found node ID to its document source, with ``_id``
        added. Missing IDs are omitted. Returns {} for empty input without
        contacting the cluster (an ``mget`` with no IDs is rejected by
        OpenSearch), and {} when the indices are missing or on any
        OpenSearch error.
    """
    if not node_ids:
        return {}
    if not self._indices_ready:
        return {}
    try:
        response = await self.client.mget(
            index=self._nodes_index, body={"ids": node_ids}
        )
        result = {}
        for doc in response["docs"]:
            if doc.get("found"):
                data = doc["_source"]
                data["_id"] = doc["_id"]
                result[doc["_id"]] = data
        return result
    except OpenSearchException as e:
        if _is_missing_index_error(e):
            self._mark_indices_missing()
        return {}
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """Return edge counts for several nodes from one aggregation query.

    Degrees are the sum of a terms aggregation over ``source_node_id`` and
    one over ``target_node_id``. Both aggregations are restricted with
    ``include`` to exactly the requested IDs. Without that restriction,
    the buckets of *other* nodes sharing an edge with a requested node
    compete for the limited bucket ``size`` and can push requested nodes
    out of the response, under-counting them. A self-loop edge matches
    both aggregations and so contributes to both sums.

    Args:
        node_ids: Node IDs to count (duplicates are ignored).

    Returns:
        Mapping of node ID to degree. Nodes without edges are absent.
        Returns {} for empty input, missing indices, or any OpenSearch error.
    """
    if not node_ids:
        return {}
    if not self._indices_ready:
        return {}
    wanted = list(dict.fromkeys(node_ids))
    wanted_set = set(wanted)
    try:
        await self._refresh_graph_indices_if_dirty(refresh_edges=True)
        body = {
            "size": 0,
            "query": {
                "bool": {
                    "should": [
                        {"terms": {"source_node_id": wanted}},
                        {"terms": {"target_node_id": wanted}},
                    ]
                }
            },
            "aggs": {
                "source_degrees": {
                    "terms": {
                        "field": "source_node_id",
                        "include": wanted,
                        "size": len(wanted),
                    }
                },
                "target_degrees": {
                    "terms": {
                        "field": "target_node_id",
                        "include": wanted,
                        "size": len(wanted),
                    }
                },
            },
        }
        response = await self.client.search(index=self._edges_index, body=body)
        result: dict[str, int] = {}
        for agg_name in ("source_degrees", "target_degrees"):
            for bucket in response["aggregations"][agg_name]["buckets"]:
                key = bucket["key"]
                if key in wanted_set:
                    result[key] = result.get(key, 0) + bucket["doc_count"]
        return result
    except OpenSearchException as e:
        if _is_missing_index_error(e):
            self._mark_indices_missing()
        return {}
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
    """Batch-fetch the ``(source, target)`` edge pairs touching each node.

    Pages through all edges touching any requested node with a
    point-in-time cursor. Each edge is attached to every requested endpoint.
    A self-loop is attached once, matching get_node_edges(), instead of once
    per endpoint.

    Args:
        node_ids: Node IDs to look up.

    Returns:
        Mapping of each requested node ID to its edge pairs. Every requested
        ID is present, with an empty list if it has no edges. OpenSearch
        errors are swallowed, and whatever was collected so far is returned.
    """
    result: dict[str, list[tuple[str, str]]] = {nid: [] for nid in node_ids}
    if not result or not self._indices_ready:
        return result
    try:
        await self._refresh_graph_indices_if_dirty(refresh_edges=True)
        query = {
            "bool": {
                "should": [
                    {"terms": {"source_node_id": node_ids}},
                    {"terms": {"target_node_id": node_ids}},
                ]
            }
        }
        pit = await self.client.create_pit(
            index=self._edges_index, params={"keep_alive": "1m"}
        )
        pit_id = pit["pit_id"]
        try:
            search_after = None
            while True:
                body = {
                    "query": query,
                    "_source": ["source_node_id", "target_node_id"],
                    "size": 10000,
                    "pit": {"id": pit_id, "keep_alive": "1m"},
                    "sort": _pit_sort_with_composite_key(
                        "source_node_id", "target_node_id"
                    ),
                }
                if search_after:
                    body["search_after"] = search_after
                response = await self.client.search(body=body)
                hits = response["hits"]["hits"]
                if not hits:
                    break
                for hit in hits:
                    src = hit["_source"]["source_node_id"]
                    tgt = hit["_source"]["target_node_id"]
                    edge = (src, tgt)
                    if src in result:
                        result[src].append(edge)
                    # A self-loop is a single edge of its node: don't add it twice.
                    if tgt != src and tgt in result:
                        result[tgt].append(edge)
                search_after = hits[-1]["sort"]
                if len(hits) < 10000:
                    break
        finally:
            try:
                await self.client.delete_pit(body={"pit_id": [pit_id]})
            except Exception:
                pass
    except OpenSearchException as e:
        if _is_missing_index_error(e):
            self._mark_indices_missing()
    return result
# --- Upsert operations ---
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    """Index (insert or fully replace) the document for *node_id*.

    ``entity_id`` is set to the node ID for PPL queries, and a non-empty
    ``source_id`` is also split on GRAPH_FIELD_SEP into ``source_ids``. No
    refresh is requested: ID-based reads (mget/exists) see the write
    immediately, and search reads refresh lazily via the dirty flag.
    OpenSearch errors are logged, not raised.
    """
    try:
        await self._ensure_indices_ready()
        doc = {key: value for key, value in node_data.items() if key != "_id"}
        doc["entity_id"] = node_id
        raw_sources = node_data.get("source_id", "")
        if raw_sources:
            doc["source_ids"] = raw_sources.split(GRAPH_FIELD_SEP)
        await self.client.index(index=self._nodes_index, id=node_id, body=doc)
        self._nodes_dirty = True
    except OpenSearchException as err:
        logger.error(f"[{self.workspace}] Error upserting node {node_id}: {err}")
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
    """Insert or update the edge between two nodes.

    The edge is stored under a deterministic ID derived from
    ``"<source>-<target>"``. If a document already exists under the
    reverse ID, that document is overwritten instead, so a node pair keeps
    one edge document whatever the direction. If has_node() reports the
    source node absent, a placeholder node document is created for it.

    OpenSearch errors are logged, not raised.
    """
    try:
        await self._ensure_indices_ready()
        if not await self.has_node(source_node_id):
            # Use "create", which fails with 409 instead of overwriting. has_node()
            # also answers False on lookup errors, and a plain index of an empty
            # placeholder would then wipe an existing entity's data.
            try:
                await self.client.create(
                    index=self._nodes_index,
                    id=source_node_id,
                    body={"entity_id": source_node_id},
                )
                self._nodes_dirty = True
            except OpenSearchException as e:
                # 409: the node already exists (e.g. created concurrently).
                if getattr(e, "status_code", None) != 409:
                    logger.error(
                        f"[{self.workspace}] Error upserting node {source_node_id}: {e}"
                    )
        doc = {k: v for k, v in edge_data.items() if k != "_id"}
        doc["source_node_id"] = source_node_id
        doc["target_node_id"] = target_node_id
        if edge_data.get("source_id", ""):
            doc["source_ids"] = edge_data["source_id"].split(GRAPH_FIELD_SEP)
        # Use a deterministic ID for the edge so upserts work
        edge_id = compute_mdhash_id(
            f"{source_node_id}-{target_node_id}", prefix="edge-"
        )
        # Reuse the reverse-direction document if one already exists
        reverse_id = compute_mdhash_id(
            f"{target_node_id}-{source_node_id}", prefix="edge-"
        )
        try:
            if await self.client.exists(index=self._edges_index, id=reverse_id):
                edge_id = reverse_id
        except OpenSearchException:
            pass
        await self.client.index(index=self._edges_index, id=edge_id, body=doc)
        self._edges_dirty = True
    except OpenSearchException as e:
        logger.error(
            f"[{self.workspace}] Error upserting edge {source_node_id}->{target_node_id}: {e}"
        )
async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
    """Batch insert/update multiple nodes using the OpenSearch bulk API.

    Each document gets ``entity_id`` set to its node ID, and ``source_ids``
    split from a non-empty ``source_id``, exactly as in upsert_node().
    OpenSearch errors are logged, not raised.

    Args:
        nodes: List of (node_id, node_data) tuples.
    """
    if not nodes:
        return
    try:
        await self._ensure_indices_ready()
        actions = []
        for node_id, node_data in nodes:
            doc = {k: v for k, v in node_data.items() if k != "_id"}
            doc["entity_id"] = node_id
            if node_data.get("source_id", ""):
                doc["source_ids"] = node_data["source_id"].split(GRAPH_FIELD_SEP)
            actions.append(
                {
                    "_op_type": "index",
                    "_index": self._nodes_index,
                    "_id": node_id,
                    "_source": doc,
                }
            )
        try:
            await helpers.async_bulk(self.client, actions)
        finally:
            # A bulk call can index part of the batch before raising, so mark the
            # search view stale even on failure. A spurious flag only costs one
            # extra refresh.
            self._nodes_dirty = True
    except OpenSearchException as e:
        logger.error(f"[{self.workspace}] Error during batch node upsert: {e}")
async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
    """Return which of *node_ids* exist, using a single ``mget`` request.

    Args:
        node_ids: Node IDs to look up.

    Returns:
        The IDs whose documents were found. Empty for empty input, when the
        indices are known to be missing, or on any OpenSearch error.
    """
    if not node_ids or not self._indices_ready:
        return set()
    try:
        response = await self.client.mget(
            index=self._nodes_index, body={"ids": node_ids}
        )
    except OpenSearchException as err:
        if _is_missing_index_error(err):
            self._mark_indices_missing()
        return set()
    present = set()
    for doc in response.get("docs", []):
        if doc.get("found"):
            present.add(doc["_id"])
    return present
async def upsert_edges_batch(
    self, edges: list[tuple[str, str, dict[str, str]]]
) -> None:
    """Batch insert/update multiple edges using the OpenSearch bulk API.

    Replicates the bidirectional edge-ID logic of upsert_edge(): the forward
    ID is used unless a reverse-direction document already exists (or is
    written earlier in this same batch), in which case the reverse ID is
    used so the update lands on that document. The reverse-ID lookup is a
    single ``mget`` before the bulk write.

    Source nodes reported missing by has_nodes_batch() get a placeholder
    node document through bulk ``create`` operations. ``create`` never
    overwrites, so a false "missing" answer (has_nodes_batch() returns an
    empty set on lookup errors) cannot wipe real entity data. The resulting
    409 conflicts are expected and ignored. Placeholder failures are logged
    and the edges are still written. Other OpenSearch errors are logged,
    not raised.

    Args:
        edges: List of (source_node_id, target_node_id, edge_data) tuples.
    """
    if not edges:
        return
    try:
        await self._ensure_indices_ready()
        source_ids = list({src for src, _tgt, _data in edges})
        existing_sources = await self.has_nodes_batch(source_ids)
        placeholders = [
            {
                "_op_type": "create",
                "_index": self._nodes_index,
                "_id": nid,
                "_source": {"entity_id": nid},
            }
            for nid in source_ids
            if nid not in existing_sources
        ]
        if placeholders:
            try:
                _created, failures = await helpers.async_bulk(
                    self.client, placeholders, raise_on_error=False
                )
            except OpenSearchException as e:
                failures = []
                logger.error(f"[{self.workspace}] Error during batch node upsert: {e}")
            self._nodes_dirty = True
            # Failed items look like {"create": {..., "status": N}}; 409 = exists.
            unexpected = [
                item
                for item in failures
                if next(iter(item.values()), {}).get("status") != 409
            ]
            if unexpected:
                logger.error(
                    f"[{self.workspace}] Error during batch node upsert: "
                    f"{len(unexpected)} placeholder node(s) failed"
                )
        # Compute forward and reverse edge IDs, then batch-check which
        # reverse-direction docs already exist (one mget instead of N exists).
        forward_ids = [
            compute_mdhash_id(f"{src}-{tgt}", prefix="edge-")
            for src, tgt, _ in edges
        ]
        reverse_ids = [
            compute_mdhash_id(f"{tgt}-{src}", prefix="edge-")
            for src, tgt, _ in edges
        ]
        try:
            rev_response = await self.client.mget(
                index=self._edges_index, body={"ids": reverse_ids}
            )
            existing_reverse = {
                doc["_id"]
                for doc in rev_response.get("docs", [])
                if doc.get("found")
            }
        except OpenSearchException:
            existing_reverse = set()
        actions = []
        # IDs already taken, on the server or earlier in this batch.
        reserved_edge_ids = set(existing_reverse)
        for (src, tgt, edge_data), fwd_id, rev_id in zip(
            edges, forward_ids, reverse_ids
        ):
            edge_id = rev_id if rev_id in reserved_edge_ids else fwd_id
            reserved_edge_ids.add(edge_id)
            doc = {k: v for k, v in edge_data.items() if k != "_id"}
            doc["source_node_id"] = src
            doc["target_node_id"] = tgt
            if edge_data.get("source_id", ""):
                doc["source_ids"] = edge_data["source_id"].split(GRAPH_FIELD_SEP)
            actions.append(
                {
                    "_op_type": "index",
                    "_index": self._edges_index,
                    "_id": edge_id,
                    "_source": doc,
                }
            )
        try:
            await helpers.async_bulk(self.client, actions)
        finally:
            # A bulk call may index part of the batch before raising.
            self._edges_dirty = True
    except OpenSearchException as e:
        logger.error(f"[{self.workspace}] Error during batch edge upsert: {e}")
# --- Delete operations ---
async def delete_node(self, node_id: str) -> None:
    """Delete a node together with every edge that references it.

    The edges index is refreshed first (if dirty) so delete_by_query sees
    recent writes. Version conflicts are skipped (``conflicts=proceed``) and
    an already-missing node document is ignored. Both search views are then
    marked dirty, so the next search/count read refreshes lazily.
    OpenSearch errors are logged, not raised.
    """
    touching = {
        "bool": {
            "should": [
                {"term": {"source_node_id": node_id}},
                {"term": {"target_node_id": node_id}},
            ]
        }
    }
    try:
        await self._refresh_graph_indices_if_dirty(refresh_edges=True)
        await self.client.delete_by_query(
            index=self._edges_index,
            body={"query": touching},
            params={"conflicts": "proceed"},
        )
        try:
            await self.client.delete(index=self._nodes_index, id=node_id)
        except NotFoundError:
            pass
        self._nodes_dirty = True
        self._edges_dirty = True
    except OpenSearchException as err:
        logger.error(f"[{self.workspace}] Error deleting node {node_id}: {err}")
async def remove_nodes(self, nodes: list[str]) -> None:
    """Delete several nodes and every edge that references any of them.

    Edges go first via delete_by_query (``conflicts=proceed``), after
    refreshing the edges index if it is dirty. Node documents are then
    removed with a bulk delete that does not raise on per-item failures.
    Both search views are marked dirty so the next search/count read
    refreshes lazily. OpenSearch errors are logged, not raised.
    """
    if not nodes:
        return
    logger.info(f"[{self.workspace}] Deleting {len(nodes)} nodes")
    touching_any = {
        "bool": {
            "should": [
                {"terms": {"source_node_id": nodes}},
                {"terms": {"target_node_id": nodes}},
            ]
        }
    }
    try:
        await self._refresh_graph_indices_if_dirty(refresh_edges=True)
        await self.client.delete_by_query(
            index=self._edges_index,
            body={"query": touching_any},
            params={"conflicts": "proceed"},
        )
        deletions = [
            {"_op_type": "delete", "_index": self._nodes_index, "_id": doomed}
            for doomed in nodes
        ]
        await helpers.async_bulk(self.client, deletions, raise_on_error=False)
        self._nodes_dirty = True
        self._edges_dirty = True
    except OpenSearchException as err:
        logger.error(f"[{self.workspace}] Error removing nodes: {err}")
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
    """Delete several edges by deterministic ID in one bulk request.

    An edge may be stored under either the forward ID
    (``compute_mdhash_id("src-tgt", prefix="edge-")``) or the reverse ID
    (``"tgt-src"``). Both candidates are deleted for every pair, so the
    removal works whichever direction was stored. The edges search view is
    marked dirty for a lazy refresh. OpenSearch errors are logged, not
    raised.
    """
    if not edges:
        return
    logger.info(f"[{self.workspace}] Deleting {len(edges)} edges")
    try:
        operations = [
            {
                "delete": {
                    "_index": self._edges_index,
                    "_id": compute_mdhash_id(f"{first}-{second}", prefix="edge-"),
                }
            }
            for src, tgt in edges
            for first, second in ((src, tgt), (tgt, src))
        ]
        await self.client.bulk(body=operations)
        self._edges_dirty = True
    except OpenSearchException as err:
        logger.error(f"[{self.workspace}] Error removing edges: {err}")
# --- Query operations ---
async def get_all_labels(self) -> list[str]:
    """Return every node ID, sorted alphabetically.

    Pages through the nodes index with a point-in-time cursor, 10000 IDs
    per page. Returns [] when the indices are missing or on any OpenSearch
    error.
    """
    if not self._indices_ready:
        return []
    page_size = 10000
    try:
        await self._refresh_graph_indices_if_dirty(refresh_nodes=True)
        labels: list[str] = []
        pit = await self.client.create_pit(
            index=self._nodes_index, params={"keep_alive": "1m"}
        )
        pit_id = pit["pit_id"]
        try:
            cursor = None
            while True:
                page_body = {
                    "query": {"match_all": {}},
                    "_source": False,
                    "size": page_size,
                    "pit": {"id": pit_id, "keep_alive": "1m"},
                    "sort": _pit_sort_with_field("entity_id"),
                }
                if cursor:
                    page_body["search_after"] = cursor
                page = await self.client.search(body=page_body)
                hits = page["hits"]["hits"]
                if not hits:
                    break
                labels.extend(hit["_id"] for hit in hits)
                cursor = hits[-1]["sort"]
                if len(hits) < page_size:
                    break
        finally:
            try:
                await self.client.delete_pit(body={"pit_id": [pit_id]})
            except Exception:
                pass
        return sorted(labels)
    except OpenSearchException as err:
        if _is_missing_index_error(err):
            self._mark_indices_missing()
        return []
async def _collect_node_ids(
self, limit: int, exclude_ids: set[str] | None = None
) -> list[str]:
"""Collect up to `limit` node IDs, optionally skipping known IDs."""
if limit <= 0:
return []
excluded = exclude_ids or set()
if not excluded and limit <= 10000:
body = {
"query": {"match_all": {}},
"_source": False,
"size": limit,
}
resp = await self.client.search(index=self._nodes_index, body=body)
return [hit["_id"] for hit in resp["hits"]["hits"]]
node_ids: list[str] = []
pit = await self.client.create_pit(
index=self._nodes_index, params={"keep_alive": "1m"}
)
pit_id = pit["pit_id"]
try:
search_after = None
while len(node_ids) < limit:
body = {
"query": {"match_all": {}},
"_source": False,
"size": 10000,
"pit": {"id": pit_id, "keep_alive": "1m"},
"sort": _pit_sort_with_field("entity_id"),
}
if search_after:
body["search_after"] = search_after
resp = await self.client.search(body=body)
hits = resp["hits"]["hits"]
if not hits:
break
for hit in hits:
node_id = hit["_id"]
if node_id in excluded:
continue
node_ids.append(node_id)
if len(node_ids) >= limit:
break
search_after = hits[-1].get("sort")
if len(hits) < 10000:
break
finally:
try:
await self.client.delete_pit(body={"pit_id": [pit_id]})
except Exception:
pass
return node_ids
@staticmethod
def _edge_rank_key(edge: dict[str, Any]) -> tuple[int, float]:
    """Sort key for traversal edges: shallower depth first, then heavier weight.

    ``_depth`` takes precedence over ``depth`` whenever the ``_depth`` key is
    present. A depth that is missing or not convertible by ``int()`` counts
    as 0, and a weight not convertible by ``float()`` counts as 0.0. The
    weight is negated so an ascending sort puts heavier edges first.
    """

    def _coerce(value, convert, fallback):
        # Convert value, falling back on the same errors the key tolerates.
        try:
            return convert(value)
        except (TypeError, ValueError):
            return fallback

    raw_depth = edge["_depth"] if "_depth" in edge else edge.get("depth", 0)
    raw_weight = edge.get("weight", 0)
    return _coerce(raw_depth, int, 0), -_coerce(raw_weight, float, 0.0)
async def _append_edges_between_nodes(
    self, node_ids: list[str], result: KnowledgeGraph
) -> None:
    """Append to ``result.edges`` every edge whose two endpoints are in *node_ids*.

    Pages through the edges index with a point-in-time cursor (10000 hits
    per page) and deletes the PIT afterwards on a best-effort basis.
    Edges are de-duplicated on the ``(source, target)`` tuple. The emitted
    edge ID stays ``"<source>-<target>"``, but that string is not used as
    the de-dup key: with hyphenated names two different edges produce the
    same string (``"a-b" -> "c"`` and ``"a" -> "b-c"``), and the second
    would have been dropped. OpenSearch errors propagate to the caller.
    """
    if not node_ids:
        return
    edge_query = {
        "bool": {
            "must": [
                {"terms": {"source_node_id": node_ids}},
                {"terms": {"target_node_id": node_ids}},
            ]
        }
    }
    seen_pairs: set[tuple[str, str]] = set()
    pit = await self.client.create_pit(
        index=self._edges_index, params={"keep_alive": "1m"}
    )
    pit_id = pit["pit_id"]
    try:
        search_after = None
        while True:
            edge_body = {
                "query": edge_query,
                "size": 10000,
                "pit": {"id": pit_id, "keep_alive": "1m"},
                "sort": _pit_sort_with_composite_key(
                    "source_node_id", "target_node_id"
                ),
            }
            if search_after:
                edge_body["search_after"] = search_after
            edge_resp = await self.client.search(body=edge_body)
            hits = edge_resp["hits"]["hits"]
            if not hits:
                break
            for hit in hits:
                e = hit["_source"]
                pair = (e["source_node_id"], e["target_node_id"])
                if pair in seen_pairs:
                    continue
                seen_pairs.add(pair)
                eid = f"{pair[0]}-{pair[1]}"
                result.edges.append(self._construct_graph_edge(eid, e))
            search_after = hits[-1].get("sort")
            if len(hits) < 10000:
                break
    finally:
        try:
            await self.client.delete_pit(body={"pit_id": [pit_id]})
        except Exception:
            pass
def _construct_graph_node(self, node_id, node_data: dict) -> KnowledgeGraphNode:
    """Wrap a node document as a KnowledgeGraphNode.

    The node ID is used both as the ID and as the single label. Internal
    bookkeeping fields are stripped from the exposed properties.
    """
    hidden_fields = {"_id", "entity_id", "source_ids", "connected_edges", "edge_count"}
    visible = {}
    for key, value in node_data.items():
        if key not in hidden_fields:
            visible[key] = value
    return KnowledgeGraphNode(id=node_id, labels=[node_id], properties=visible)
def _construct_graph_edge(self, edge_id: str, edge: dict) -> KnowledgeGraphEdge:
    """Wrap an edge document as a KnowledgeGraphEdge.

    ``relationship`` becomes the edge type (default ""), the endpoint fields
    become source/target, and the remaining non-internal fields become
    properties.
    """
    structural_fields = {
        "_id",
        "source_node_id",
        "target_node_id",
        "relationship",
        "source_ids",
    }
    visible = {}
    for key, value in edge.items():
        if key not in structural_fields:
            visible[key] = value
    return KnowledgeGraphEdge(
        id=edge_id,
        type=edge.get("relationship", ""),
        source=edge["source_node_id"],
        target=edge["target_node_id"],
        properties=visible,
    )
async def get_knowledge_graph(
    self,
    node_label: str,
    max_depth: int = 3,
    max_nodes: int | None = None,
) -> KnowledgeGraph:
    """Retrieve a subgraph around *node_label*, or a degree-ranked whole-graph view.

    Args:
        node_label: Start node ID, or ``"*"`` for the whole graph.
        max_depth: Maximum traversal depth from the start node.
        max_nodes: Node budget. It defaults to, and is capped by, the
            ``max_graph_nodes`` global config value (1000 if unset).

    Returns:
        The subgraph. Traversal uses PPL graphLookup if it was detected as
        available, otherwise client-side BFS. An empty graph is returned when
        the indices are missing. Other OpenSearch errors are logged, and the
        (empty) result is returned.
    """
    if not self._indices_ready:
        return KnowledgeGraph()
    configured_cap = self.global_config.get("max_graph_nodes", 1000)
    if max_nodes is None:
        max_nodes = configured_cap
    else:
        max_nodes = min(max_nodes, configured_cap)
    result = KnowledgeGraph()
    start = time.perf_counter()
    try:
        await self._refresh_graph_indices_if_dirty(
            refresh_nodes=True, refresh_edges=True
        )
        if node_label == "*":
            result = await self._get_knowledge_graph_all(max_nodes)
        elif self._ppl_graphlookup_available:
            result = await self._bfs_subgraph_ppl(node_label, max_depth, max_nodes)
        else:
            result = await self._bfs_subgraph(node_label, max_depth, max_nodes)
        duration = time.perf_counter() - start
        logger.info(
            f"[{self.workspace}] Subgraph query in {duration:.4f}s | "
            f"Nodes: {len(result.nodes)} | Edges: {len(result.edges)} | Truncated: {result.is_truncated}"
        )
    except OpenSearchException as e:
        if _is_missing_index_error(e):
            self._mark_indices_missing()
            return KnowledgeGraph()
        logger.error(f"[{self.workspace}] Graph query failed: {e}")
    return result
async def _get_knowledge_graph_all(self, max_nodes: int) -> KnowledgeGraph:
    """Get up to max_nodes nodes of the whole graph plus the edges among them.

    If the nodes index holds more than ``max_nodes`` documents,
    ``is_truncated`` is set and candidates are ranked by degree. The degree
    comes from two terms aggregations, one on the source side and one on the
    target side, each limited to its top ``max_nodes`` buckets. A node's
    count on one side is therefore missing if it falls outside that side's
    top buckets, which makes the ranking approximate. A short candidate list
    is topped up from ``_collect_node_ids``.

    Node IDs taken from edges that have no node document are dropped at the
    ``mget`` step, so fewer than ``max_nodes`` nodes may be returned. A
    missing-index error marks the indices missing and returns the partial
    result. Other OpenSearch errors are logged.
    """
    result = KnowledgeGraph()
    if not self._indices_ready:
        return result
    try:
        total = (await self.client.count(index=self._nodes_index))["count"]
        result.is_truncated = total > max_nodes
        if result.is_truncated:
            # Get top nodes by degree
            body = {
                "size": 0,
                "aggs": {
                    "src": {
                        "terms": {
                            "field": "source_node_id",
                            "size": max_nodes,
                        }
                    },
                    "tgt": {
                        "terms": {
                            "field": "target_node_id",
                            "size": max_nodes,
                        }
                    },
                },
            }
            resp = await self.client.search(index=self._edges_index, body=body)
            # Sum source-side and target-side counts per node ID.
            degree_map = {}
            for bucket in resp["aggregations"]["src"]["buckets"]:
                degree_map[bucket["key"]] = (
                    degree_map.get(bucket["key"], 0) + bucket["doc_count"]
                )
            for bucket in resp["aggregations"]["tgt"]["buckets"]:
                degree_map[bucket["key"]] = (
                    degree_map.get(bucket["key"], 0) + bucket["doc_count"]
                )
            top_ids = sorted(degree_map, key=degree_map.get, reverse=True)[
                :max_nodes
            ]
            # Fewer connected nodes than the budget: fill with other node IDs.
            if len(top_ids) < max_nodes:
                top_ids.extend(
                    await self._collect_node_ids(
                        max_nodes - len(top_ids), exclude_ids=set(top_ids)
                    )
                )
        else:
            top_ids = await self._collect_node_ids(max_nodes)
        # Fetch node data
        if top_ids:
            node_resp = await self.client.mget(
                index=self._nodes_index, body={"ids": top_ids}
            )
            found_node_ids = []
            for doc in node_resp["docs"]:
                if doc.get("found"):
                    found_node_ids.append(doc["_id"])
                    result.nodes.append(
                        self._construct_graph_node(doc["_id"], doc["_source"])
                    )
            # Only edges between nodes that actually have documents.
            await self._append_edges_between_nodes(found_node_ids, result)
    except OpenSearchException as e:
        if _is_missing_index_error(e):
            self._mark_indices_missing()
            return result
        logger.error(f"[{self.workspace}] Error in get_knowledge_graph_all: {e}")
    return result
async def _bfs_subgraph_ppl(
    self, start_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
    """Server-side BFS using the PPL ``graphLookup`` command.

    Looks up the start node, then runs one PPL query that selects it from
    the nodes index. ``graphLookup`` traverses the edges index in both
    directions (``target_node_id<->source_node_id``) and reports depth in a
    ``_depth`` field (``depthField=_depth``). The returned edges are ranked
    with ``_edge_rank_key`` (shallower first, then heavier). Their endpoints
    are kept in that order until ``max_nodes`` is reached, and all edges
    among the kept nodes are then fetched separately.

    Falls back to client-side ``_bfs_subgraph`` if the PPL request fails,
    the response cannot be parsed, or rows come back as positional arrays.
    Errors from the later node ``mget`` and edge fetch are not caught here.

    Args:
        start_label: ID of the start node.
        max_depth: Traversal depth; 0 returns only the start node.
        max_nodes: Node budget, including the start node.

    Returns:
        The subgraph, empty if the start node does not exist.
        ``is_truncated`` is True when more distinct nodes were discovered
        than ``max_nodes`` allows.
    """
    result = KnowledgeGraph()
    # Verify start node exists
    start_node = await self.get_node(start_label)
    if not start_node:
        return result
    result.nodes.append(self._construct_graph_node(start_label, start_node))
    if max_depth == 0:
        return result
    # PPL maxDepth=0 means 1 hop (direct match), so max_depth-1
    ppl_depth = max(0, max_depth - 1)
    # The label is user-supplied; escape it for the single-quoted PPL literal.
    escaped = self._escape_ppl(start_label)
    ppl_query = (
        f"source = {self._nodes_index}"
        f" | where entity_id = '{escaped}'"
        f" | graphLookup {self._edges_index}"
        f" start=entity_id"
        f" edge=target_node_id<->source_node_id"
        f" maxDepth={ppl_depth}"
        f" depthField=_depth"
        f" usePIT=true"
        f" as connected_edges"
    )
    try:
        resp = await self.client.transport.perform_request(
            "POST",
            "/_plugins/_ppl",
            body={"query": ppl_query},
        )
    except Exception as e:
        logger.warning(
            f"[{self.workspace}] PPL graphlookup failed, falling back to client BFS: {e}"
        )
        return await self._bfs_subgraph(start_label, max_depth, max_nodes)
    # Parse PPL response — schema-driven to avoid fragile positional access
    try:
        datarows = resp.get("datarows", [])
        schema = [col["name"] for col in resp.get("schema", [])]
        ce_idx = (
            schema.index("connected_edges") if "connected_edges" in schema else -1
        )
        # Collect all edge rows from connected_edges arrays
        all_edge_rows = []
        for row in datarows:
            edges_arr = row[ce_idx] if ce_idx >= 0 else []
            if isinstance(edges_arr, list):
                all_edge_rows.extend(edges_arr)
        if not all_edge_rows:
            return result
        if isinstance(all_edge_rows[0], dict):
            sorted_edge_rows = sorted(all_edge_rows, key=self._edge_rank_key)
        else:
            # Positional array — column positions are unknown, fall back to client BFS
            logger.warning(
                f"[{self.workspace}] PPL returned positional arrays, falling back to client BFS"
            )
            return await self._bfs_subgraph(start_label, max_depth, max_nodes)
    except (KeyError, IndexError, TypeError, ValueError) as e:
        logger.warning(
            f"[{self.workspace}] Error parsing PPL response, falling back: {e}"
        )
        return await self._bfs_subgraph(start_label, max_depth, max_nodes)
    # Keep endpoints in rank order up to the budget; keep counting the rest so
    # is_truncated reflects whether any discovered node was left out.
    ordered_node_ids = [start_label]
    discovered_nodes = {start_label}
    for edge_row in sorted_edge_rows:
        for node_id in (
            edge_row.get("source_node_id"),
            edge_row.get("target_node_id"),
        ):
            if not node_id or node_id in discovered_nodes:
                continue
            discovered_nodes.add(node_id)
            if len(ordered_node_ids) < max_nodes:
                ordered_node_ids.append(node_id)
    result.is_truncated = len(discovered_nodes) > max_nodes
    # Batch fetch node data (start node already added)
    new_node_ids = [nid for nid in ordered_node_ids if nid != start_label]
    if new_node_ids:
        node_resp = await self.client.mget(
            index=self._nodes_index, body={"ids": new_node_ids}
        )
        for doc in node_resp["docs"]:
            if doc.get("found"):
                result.nodes.append(
                    self._construct_graph_node(doc["_id"], doc["_source"])
                )
    await self._append_edges_between_nodes(ordered_node_ids, result)
    return result
@staticmethod
def _escape_ppl(value: str) -> str:
"""Escape a string for safe inclusion in a PPL single-quoted literal."""
return value.replace("\\", "\\\\").replace("'", "\\'")
async def _bfs_subgraph(
self, start_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
"""BFS traversal from a starting node, batching neighbor lookups per level."""
result = KnowledgeGraph()
seen_nodes = set()
# Verify start node exists
start_node = await self.get_node(start_label)
if not start_node:
return result
seen_nodes.add(start_label)
result.nodes.append(self._construct_graph_node(start_label, start_node))
current_level = [start_label]
for _ in range(max_depth):
if not current_level or len(seen_nodes) >= max_nodes:
break
# Batch fetch all edges for current level
body = {
"query": {
"bool": {
"should": [
{"terms": {"source_node_id": current_level}},
{"terms": {"target_node_id": current_level}},
]
}
},
"_source": ["source_node_id", "target_node_id"],
"size": 10000,
}
try:
resp = await self.client.search(index=self._edges_index, body=body)
except OpenSearchException:
break
next_level = set()
for hit in resp["hits"]["hits"]:
src = hit["_source"]["source_node_id"]
tgt = hit["_source"]["target_node_id"]
if src not in seen_nodes:
next_level.add(src)
if tgt not in seen_nodes:
next_level.add(tgt)
# Limit to max_nodes
new_ids = []
for nid in next_level:
if len(seen_nodes) + len(new_ids) >= max_nodes:
break
new_ids.append(nid)
if new_ids:
# Batch fetch node data
node_resp = await self.client.mget(
index=self._nodes_index, body={"ids": new_ids}
)
for doc in node_resp["docs"]:
if doc.get("found"):
seen_nodes.add(doc["_id"])
result.nodes.append(
self._construct_graph_node(doc["_id"], doc["_source"])
)
current_level = new_ids
# Fetch all edges between seen nodes using PIT scrolling
all_ids = list(seen_nodes)
if all_ids:
try:
await self._append_edges_between_nodes(all_ids, result)
except OpenSearchException:
pass
result.is_truncated = len(seen_nodes) >= max_nodes
return result
async def get_all_nodes(self) -> list[dict]:
"""Get all nodes with their properties."""
if not self._indices_ready:
return []
try:
await self._refresh_graph_indices_if_dirty(refresh_nodes=True)
nodes = []
pit = await self.client.create_pit(
index=self._nodes_index, params={"keep_alive": "1m"}
)
pit_id = pit["pit_id"]
try:
search_after = None
while True:
body = {
"query": {"match_all": {}},
"size": 10000,
"pit": {"id": pit_id, "keep_alive": "1m"},
"sort": _pit_sort_with_field("entity_id"),
}
if search_after:
body["search_after"] = search_after
response = await self.client.search(body=body)
hits = response["hits"]["hits"]
if not hits:
break
for hit in hits:
node = hit["_source"]
node["id"] = hit["_id"]
nodes.append(node)
search_after = hits[-1]["sort"]
if len(hits) < 10000:
break
finally:
try:
await self.client.delete_pit(body={"pit_id": [pit_id]})
except Exception:
pass
return nodes
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_indices_missing()
return []
async def get_all_edges(self) -> list[dict]:
"""Get all edges with source/target fields added."""
if not self._indices_ready:
return []
try:
await self._refresh_graph_indices_if_dirty(refresh_edges=True)
edges = []
pit = await self.client.create_pit(
index=self._edges_index, params={"keep_alive": "1m"}
)
pit_id = pit["pit_id"]
try:
search_after = None
while True:
body = {
"query": {"match_all": {}},
"size": 10000,
"pit": {"id": pit_id, "keep_alive": "1m"},
"sort": _pit_sort_with_composite_key(
"source_node_id", "target_node_id"
),
}
if search_after:
body["search_after"] = search_after
response = await self.client.search(body=body)
hits = response["hits"]["hits"]
if not hits:
break
for hit in hits:
edge = hit["_source"]
edge["source"] = edge.get("source_node_id")
edge["target"] = edge.get("target_node_id")
edges.append(edge)
search_after = hits[-1]["sort"]
if len(hits) < 10000:
break
finally:
try:
await self.client.delete_pit(body={"pit_id": [pit_id]})
except Exception:
pass
return edges
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_indices_missing()
return []
async def get_popular_labels(self, limit: int = 300) -> list[str]:
"""Get node labels ranked by edge degree (most connected first)."""
if not self._indices_ready:
return []
try:
await self._refresh_graph_indices_if_dirty(refresh_edges=True)
body = {
"size": 0,
"aggs": {
"src": {"terms": {"field": "source_node_id", "size": limit * 2}},
"tgt": {"terms": {"field": "target_node_id", "size": limit * 2}},
},
}
response = await self.client.search(index=self._edges_index, body=body)
degree_map = {}
for bucket in response["aggregations"]["src"]["buckets"]:
degree_map[bucket["key"]] = (
degree_map.get(bucket["key"], 0) + bucket["doc_count"]
)
for bucket in response["aggregations"]["tgt"]["buckets"]:
degree_map[bucket["key"]] = (
degree_map.get(bucket["key"], 0) + bucket["doc_count"]
)
sorted_labels = sorted(degree_map, key=degree_map.get, reverse=True)[:limit]
return sorted_labels
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_indices_missing()
return []
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
"""Search node labels with wildcard and prefix matching."""
query = query.strip()
if not query:
return []
if not self._indices_ready:
return []
try:
await self._refresh_graph_indices_if_dirty(refresh_nodes=True)
body = {
"query": {
"bool": {
"should": [
{"term": {"entity_id": {"value": query, "boost": 10}}},
{
"prefix": {
"entity_id": {"value": query.lower(), "boost": 5}
}
},
{
"wildcard": {
"entity_id": {
"value": f"*{query.lower()}*",
"case_insensitive": True,
"boost": 2,
}
}
},
]
}
},
"_source": False,
"size": limit,
}
response = await self.client.search(index=self._nodes_index, body=body)
return [hit["_id"] for hit in response["hits"]["hits"]]
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_indices_missing()
return []
async def index_done_callback(self) -> None:
"""Refresh both node and edge indices."""
if not self._indices_ready:
return
try:
await self._refresh_graph_indices_if_dirty(
refresh_nodes=True, refresh_edges=True
)
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_indices_missing()
return
except Exception:
pass
async def drop(self) -> dict[str, str]:
"""Delete both node and edge indices."""
errors = []
for idx in (self._nodes_index, self._edges_index):
try:
await self.client.indices.delete(index=idx)
logger.info(f"[{self.workspace}] Dropped graph index: {idx}")
except NotFoundError:
logger.info(
f"[{self.workspace}] Graph index already missing during drop: {idx}"
)
except OpenSearchException as e:
errors.append(f"{idx}: {e}")
logger.error(
f"[{self.workspace}] Error dropping graph index {idx}: {e}"
)
except Exception as e:
errors.append(f"{idx}: {e}")
logger.error(
f"[{self.workspace}] Unexpected error dropping graph index {idx}: {e}"
)
self._mark_indices_missing()
if errors:
return {
"status": "error",
"message": "Failed to drop graph indices: " + "; ".join(errors),
}
try:
logger.info(f"[{self.workspace}] Dropped graph indices")
return {"status": "success", "message": "Graph indices dropped"}
except Exception as e:
logger.error(f"[{self.workspace}] Error finalizing graph drop: {e}")
return {"status": "error", "message": str(e)}
@final
@dataclass
class OpenSearchVectorDBStorage(BaseVectorStorage):
"""Vector storage using OpenSearch k-NN plugin with corrected cosine score handling."""
client: AsyncOpenSearch = field(default=None)
_index_name: str = field(default="", init=False)
_index_ready: bool = field(default=False, init=False)
def __init__(
self, namespace, global_config, embedding_func, workspace=None, meta_fields=None
):
super().__init__(
namespace=namespace,
workspace=workspace or "",
global_config=global_config,
embedding_func=embedding_func,
meta_fields=meta_fields or set(),
)
self.__post_init__()
def __post_init__(self):
self._validate_embedding_func()
self.workspace, self.final_namespace, self._index_name = _build_index_name(
self.workspace, self.namespace
)
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
"""Initialize client and create k-NN vector index."""
async with get_data_init_lock():
if self.client is None:
self.client = await ClientManager.get_client()
await self._create_knn_index_if_not_exists()
self._index_ready = True
logger.debug(
f"[{self.workspace}] OpenSearch Vector storage initialized: {self._index_name}"
)
async def _ensure_index_ready(self):
"""Recreate the vector index before the next write if it is missing."""
if self._index_ready:
return
async with get_data_init_lock():
if self.client is None:
self.client = await ClientManager.get_client()
if not self._index_ready:
await self._create_knn_index_if_not_exists()
self._index_ready = True
    def _mark_index_missing(self):
        """Mark the vector index as unavailable for subsequent read short-circuiting.

        Clears ``_index_ready``: read methods in this class return empty results
        while it is False, and ``_ensure_index_ready`` recreates the index before
        the next ``upsert``.
        """
        self._index_ready = False
    async def _create_knn_index_if_not_exists(self):
        """Create the k-NN vector index, or validate the dimension of an existing one.

        If the index exists, its mapped ``vector.dimension`` is compared with
        ``self.embedding_func.embedding_dim``; a mismatch raises ``ValueError``.
        If the mapping cannot be read (``KeyError``/``TypeError``), validation is
        skipped with a warning. Otherwise a new index is created with an HNSW
        ``knn_vector`` field (lucene engine, ``cosinesimil`` space) plus
        text/keyword/long metadata fields. HNSW parameters and the shard and
        replica counts come from ``OPENSEARCH_*`` environment variables, falling
        back to the ``[opensearch]`` section of config.ini and then to the
        literal defaults below.

        Raises:
            ValueError: the existing index dimension differs from the model's.
            OpenSearchException: index creation failed for a reason other than
                the index already existing (that case, e.g. a concurrent
                creator, is ignored); the error is logged and re-raised.
        """
        try:
            if await self.client.indices.exists(index=self._index_name):
                # Validate existing index dimension
                # (only the dimension is checked; engine/space type are not).
                try:
                    mapping = await self.client.indices.get_mapping(
                        index=self._index_name
                    )
                    existing_dim = (
                        mapping[self._index_name]["mappings"]["properties"]
                        .get("vector", {})
                        .get("dimension")
                    )
                    expected_dim = self.embedding_func.embedding_dim
                    if existing_dim is not None and existing_dim != expected_dim:
                        # ValueError is not caught by either handler in this
                        # method, so it propagates to the caller.
                        raise ValueError(
                            f"Vector dimension mismatch! Index '{self._index_name}' has "
                            f"dimension {existing_dim}, but current embedding model expects "
                            f"dimension {expected_dim}. Please drop the existing index or "
                            f"use an embedding model with matching dimensions."
                        )
                except (KeyError, TypeError):
                    logger.warning(
                        f"[{self.workspace}] Could not read vector mapping for index "
                        f"'{self._index_name}'; skipping dimension validation"
                    )
                return
            # HNSW build/search parameters, overridable via environment/config.
            ef_construction = int(
                _get_opensearch_env("OPENSEARCH_KNN_EF_CONSTRUCTION", "200")
            )
            m = int(_get_opensearch_env("OPENSEARCH_KNN_M", "16"))
            ef_search = int(_get_opensearch_env("OPENSEARCH_KNN_EF_SEARCH", "100"))
            body = {
                "settings": {
                    "index": {
                        "knn": True,
                        "knn.algo_param.ef_search": ef_search,
                        "number_of_shards": _get_index_number_of_shards(),
                        "number_of_replicas": _get_index_number_of_replicas(),
                    }
                },
                "mappings": {
                    "properties": {
                        "vector": {
                            "type": "knn_vector",
                            "dimension": self.embedding_func.embedding_dim,
                            "method": {
                                "name": "hnsw",
                                "space_type": "cosinesimil",
                                "engine": "lucene",
                                "parameters": {
                                    "ef_construction": ef_construction,
                                    "m": m,
                                },
                            },
                        },
                        "content": {"type": "text"},
                        "entity_name": {"type": "keyword"},
                        "src_id": {"type": "keyword"},
                        "tgt_id": {"type": "keyword"},
                        "file_path": {"type": "keyword"},
                        "created_at": {"type": "long"},
                    },
                    # Other meta fields are mapped dynamically on first write.
                    "dynamic": True,
                },
            }
            await self.client.indices.create(index=self._index_name, body=body)
            logger.info(
                f"[{self.workspace}] Created k-NN index: {self._index_name} "
                f"(dim={self.embedding_func.embedding_dim})"
            )
        except RequestError as e:
            # The index appearing between exists() and create() is tolerated.
            if "resource_already_exists_exception" not in str(e):
                logger.error(f"[{self.workspace}] Error creating k-NN index: {e}")
                raise
        except OpenSearchException as e:
            logger.error(f"[{self.workspace}] Error creating k-NN index: {e}")
            raise
async def finalize(self):
"""Release the OpenSearch client connection."""
if self.client is not None:
await ClientManager.release_client(self.client)
self.client = None
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Generate embeddings and upsert vectors in batches."""
if not data:
return
await self._ensure_index_ready()
logger.debug(
f"[{self.workspace}] Upserting {len(data)} vectors to {self.namespace}"
)
current_time = int(time.time())
list_data = []
for i, (k, v) in enumerate(data.items(), start=1):
list_data.append(
{
"_id": k,
"created_at": current_time,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
)
await _cooperative_yield(i)
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch, context="document") for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
assert len(embeddings) == len(
list_data
), f"Embedding count mismatch: expected {len(list_data)}, got {len(embeddings)}"
for i, doc in enumerate(list_data, start=1):
doc["vector"] = embeddings[i - 1].tolist()
await _cooperative_yield(i)
actions = []
for i, doc in enumerate(list_data, start=1):
actions.append(
{
"_op_type": "index",
"_index": self._index_name,
"_id": doc["_id"],
"_source": {k: v for k, v in doc.items() if k != "_id"},
}
)
await _cooperative_yield(i)
try:
# No per-operation refresh: immediate reads use ID-based mget (translog),
# k-NN search visibility is guaranteed after index_done_callback() batch refresh.
success, failed = await helpers.async_bulk(
self.client, actions, raise_on_error=False
)
if failed:
logger.warning(
f"[{self.workspace}] {len(failed)} vectors failed to upsert"
)
except OpenSearchException as e:
logger.error(f"[{self.workspace}] Error upserting vectors: {e}")
raise
async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
"""k-NN similarity search with cosine score conversion for lucene engine."""
if not self._index_ready:
return []
if query_embedding is not None:
query_vector = (
query_embedding.tolist()
if hasattr(query_embedding, "tolist")
else list(query_embedding)
)
else:
embedding = await self.embedding_func([query], context="query", _priority=5)
query_vector = embedding[0].tolist()
search_body = {
"size": top_k,
"query": {"knn": {"vector": {"vector": query_vector, "k": top_k}}},
"_source": {"excludes": ["vector"]},
}
try:
response = await self.client.search(
index=self._index_name, body=search_body
)
results = []
for hit in response["hits"]["hits"]:
# OpenSearch k-NN with lucene engine and cosinesimil space type
# returns scores that can be used directly as similarity measure.
score = hit["_score"]
if score >= self.cosine_better_than_threshold:
doc = hit["_source"]
doc["id"] = hit["_id"]
doc["distance"] = score
results.append(doc)
logger.info(
f"[{self.workspace}] Vector query on {self._index_name}: "
f"top_k={top_k}, threshold={self.cosine_better_than_threshold}, "
f"total_hits={len(response['hits']['hits'])}, "
f"passed_filter={len(results)}, "
f"score_range=[{min((h['_score'] for h in response['hits']['hits']), default=0):.4f}, "
f"{max((h['_score'] for h in response['hits']['hits']), default=0):.4f}]"
)
return results
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return []
logger.error(f"[{self.workspace}] Error querying vectors: {e}")
return []
async def index_done_callback(self) -> None:
"""Refresh index to make recently indexed vectors searchable."""
if not self._index_ready:
return
try:
await self.client.indices.refresh(index=self._index_name)
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return
except Exception:
pass
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get a vector document by ID."""
if not self._index_ready:
return None
try:
response = await _mget_optional_doc(self.client, self._index_name, id)
if response is None:
return None
doc = response["_source"]
doc["id"] = response["_id"]
return doc
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return None
logger.error(f"[{self.workspace}] Error getting vector {id}: {e}")
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector documents by IDs, preserving order."""
if not ids:
return []
if not self._index_ready:
return [None] * len(ids)
try:
response = await self.client.mget(index=self._index_name, body={"ids": ids})
doc_map = {}
for doc in response["docs"]:
if doc.get("found"):
data = doc["_source"]
data["id"] = doc["_id"]
doc_map[doc["_id"]] = data
return [doc_map.get(id) for id in ids]
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return [None] * len(ids)
logger.error(f"[{self.workspace}] Error getting vectors by ids: {e}")
return [None] * len(ids)
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get only the vector embeddings for given IDs."""
if not ids:
return {}
if not self._index_ready:
return {}
try:
response = await self.client.mget(
index=self._index_name, body={"ids": ids}, _source_includes=["vector"]
)
result = {}
for doc in response["docs"]:
if doc.get("found") and "vector" in doc.get("_source", {}):
result[doc["_id"]] = doc["_source"]["vector"]
return result
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return {}
logger.error(f"[{self.workspace}] Error getting vectors: {e}")
return {}
async def delete(self, ids: list[str]) -> None:
"""Delete vectors by their IDs."""
if not ids:
return
if not self._index_ready:
return
if isinstance(ids, set):
ids = list(ids)
try:
# No per-operation refresh: search visibility after index_done_callback().
actions = [
{"_op_type": "delete", "_index": self._index_name, "_id": doc_id}
for doc_id in ids
]
result = await helpers.async_bulk(
self.client, actions, raise_on_error=False
)
logger.debug(
f"[{self.workspace}] Deleted {result[0]} vectors from {self.namespace}"
)
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return
logger.error(f"[{self.workspace}] Error deleting vectors: {e}")
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity vector by computing its hash ID."""
if not self._index_ready:
return
try:
# No per-operation refresh: search visibility after index_done_callback().
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
try:
await self.client.delete(index=self._index_name, id=entity_id)
logger.debug(f"[{self.workspace}] Deleted entity {entity_name}")
except NotFoundError as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return
logger.debug(f"[{self.workspace}] Entity {entity_name} not found")
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relation vectors where entity appears as src or tgt."""
if not self._index_ready:
return
try:
# No per-operation refresh: search visibility after index_done_callback().
body = {
"query": {
"bool": {
"should": [
{"term": {"src_id": entity_name}},
{"term": {"tgt_id": entity_name}},
]
}
}
}
# conflicts="proceed" tolerates stale search view after refresh removal.
await self.client.delete_by_query(
index=self._index_name, body=body, params={"conflicts": "proceed"}
)
logger.debug(
f"[{self.workspace}] Deleted relations for entity {entity_name}"
)
except OpenSearchException as e:
if _is_missing_index_error(e):
self._mark_index_missing()
return
logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
)
async def drop(self) -> dict[str, str]:
"""Delete and recreate the vector index."""
try:
try:
await self.client.indices.delete(index=self._index_name)
logger.info(
f"[{self.workspace}] Dropped vector index: {self._index_name}"
)
except NotFoundError:
logger.info(
f"[{self.workspace}] Vector index already missing during drop: {self._index_name}"
)
# Recreate the index
await self._create_knn_index_if_not_exists()
self._index_ready = True
logger.info(
f"[{self.workspace}] Dropped and recreated vector index: {self._index_name}"
)
return {
"status": "success",
"message": f"Vector index {self._index_name} dropped and recreated",
}
except OpenSearchException as e:
self._mark_index_missing()
logger.error(f"[{self.workspace}] Error dropping vector index: {e}")
return {"status": "error", "message": str(e)}
except Exception as e:
self._mark_index_missing()
logger.error(
f"[{self.workspace}] Unexpected error dropping vector index: {e}"
)
return {"status": "error", "message": str(e)}