feat: 重构知识库系统,移除Hermes集成,增强RAG和同步功能

主要变更:
- 移除Hermes智能体及相关回调服务
- 新增知识库RAG、同步、调度、规范化和索引任务服务
- 重构orchestrator服务,增强运行时聊天功能
- 更新前端聊天、政策制度、设置等页面样式和逻辑
- 更新expense_claims和document_intelligence服务
- 删除llm_wiki相关服务和测试文件
- 更新docker-compose配置和启动脚本
This commit is contained in:
caoxiaozhu
2026-05-17 08:38:41 +00:00
parent 212c935308
commit 68f663f2f4
308 changed files with 83729 additions and 13588 deletions

View File

@@ -0,0 +1,161 @@
# Registry of storage roles and the backend class names allowed for each one.
# "implementations" is the allow-list checked by verify_storage_implementation();
# "required_methods" names methods associated with the role (it is not read
# anywhere in the visible code — presumably consumed elsewhere; TODO confirm).
# NOTE(review): AGEStorage and ChromaVectorDBStorage have entries in STORAGES
# but are not listed here, so they fail verification — confirm this is intended.
STORAGE_IMPLEMENTATIONS = {
    "KV_STORAGE": {
        "implementations": [
            "JsonKVStorage",
            "RedisKVStorage",
            "PGKVStorage",
            "MongoKVStorage",
            "OpenSearchKVStorage",
        ],
        "required_methods": ["get_by_id", "upsert"],
    },
    "GRAPH_STORAGE": {
        "implementations": [
            "NetworkXStorage",
            "Neo4JStorage",
            "PGGraphStorage",
            "MongoGraphStorage",
            "MemgraphStorage",
            "OpenSearchGraphStorage",
        ],
        "required_methods": ["upsert_node", "upsert_edge"],
    },
    "VECTOR_STORAGE": {
        "implementations": [
            "NanoVectorDBStorage",
            "MilvusVectorDBStorage",
            "PGVectorStorage",
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
            "MongoVectorDBStorage",
            "OpenSearchVectorDBStorage",
            # "ChromaVectorDBStorage",
        ],
        "required_methods": ["query", "upsert"],
    },
    "DOC_STATUS_STORAGE": {
        "implementations": [
            "JsonDocStatusStorage",
            "RedisDocStatusStorage",
            "PGDocStatusStorage",
            "MongoDocStatusStorage",
            "OpenSearchDocStatusStorage",
        ],
        "required_methods": ["get_docs_by_status"],
    },
}
# Environment variables that each storage implementation needs and that have
# no default value. Keys are implementation class names; an empty list means
# the implementation needs no such variable.
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
    # KV Storage Implementations
    "JsonKVStorage": [],
    "MongoKVStorage": [
        "MONGO_URI",
        "MONGO_DATABASE",
    ],
    "RedisKVStorage": ["REDIS_URI"],
    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    # Graph Storage Implementations
    "NetworkXStorage": [],
    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
    "MongoGraphStorage": [
        "MONGO_URI",
        "MONGO_DATABASE",
    ],
    "MemgraphStorage": ["MEMGRAPH_URI"],
    "AGEStorage": [
        "AGE_POSTGRES_DB",
        "AGE_POSTGRES_USER",
        "AGE_POSTGRES_PASSWORD",
    ],
    "PGGraphStorage": [
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_DATABASE",
    ],
    # Vector Storage Implementations
    "NanoVectorDBStorage": [],
    "MilvusVectorDBStorage": [
        "MILVUS_URI",
        "MILVUS_DB_NAME",
    ],
    # "ChromaVectorDBStorage": [],
    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "FaissVectorDBStorage": [],
    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
    "MongoVectorDBStorage": [
        "MONGO_URI",
        "MONGO_DATABASE",
    ],
    # Document Status Storage Implementations
    "JsonDocStatusStorage": [],
    "RedisDocStatusStorage": ["REDIS_URI"],
    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "MongoDocStatusStorage": [
        "MONGO_URI",
        "MONGO_DATABASE",
    ],
    # OpenSearch Storage Implementations (all four roles share one variable)
    "OpenSearchKVStorage": [
        "OPENSEARCH_HOSTS",
    ],
    "OpenSearchDocStatusStorage": [
        "OPENSEARCH_HOSTS",
    ],
    "OpenSearchGraphStorage": [
        "OPENSEARCH_HOSTS",
    ],
    "OpenSearchVectorDBStorage": [
        "OPENSEARCH_HOSTS",
    ],
}
# Storage implementation module mapping: implementation class name -> module
# path that defines it. Paths start with a dot, i.e. they are written relative
# to a package (presumably resolved with a relative import — TODO confirm the
# loader). Several class names share one module (mongo, redis, postgres,
# opensearch).
STORAGES = {
    "NetworkXStorage": ".kg.networkx_impl",
    "JsonKVStorage": ".kg.json_kv_impl",
    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
    "Neo4JStorage": ".kg.neo4j_impl",
    "MilvusVectorDBStorage": ".kg.milvus_impl",
    "MongoKVStorage": ".kg.mongo_impl",
    "MongoDocStatusStorage": ".kg.mongo_impl",
    "MongoGraphStorage": ".kg.mongo_impl",
    "MongoVectorDBStorage": ".kg.mongo_impl",
    "RedisKVStorage": ".kg.redis_impl",
    "RedisDocStatusStorage": ".kg.redis_impl",
    "ChromaVectorDBStorage": ".kg.chroma_impl",
    "PGKVStorage": ".kg.postgres_impl",
    "PGVectorStorage": ".kg.postgres_impl",
    "AGEStorage": ".kg.age_impl",
    "PGGraphStorage": ".kg.postgres_impl",
    "PGDocStatusStorage": ".kg.postgres_impl",
    "FaissVectorDBStorage": ".kg.faiss_impl",
    "QdrantVectorDBStorage": ".kg.qdrant_impl",
    "MemgraphStorage": ".kg.memgraph_impl",
    "OpenSearchKVStorage": ".kg.opensearch_impl",
    "OpenSearchDocStatusStorage": ".kg.opensearch_impl",
    "OpenSearchGraphStorage": ".kg.opensearch_impl",
    "OpenSearchVectorDBStorage": ".kg.opensearch_impl",
}
def verify_storage_implementation(storage_type: str, storage_name: str) -> None:
    """Check that ``storage_name`` is an allowed backend for ``storage_type``.

    Only membership in the allow-list of ``STORAGE_IMPLEMENTATIONS`` is
    checked; the ``required_methods`` entries are not inspected here.

    Args:
        storage_type: Storage role (``KV_STORAGE``, ``GRAPH_STORAGE`` etc.).
        storage_name: Storage implementation class name.

    Raises:
        ValueError: If the role is unknown, or the implementation is not
            listed as compatible with it.
    """
    role_entry = STORAGE_IMPLEMENTATIONS.get(storage_type)
    if role_entry is None:
        raise ValueError(f"Unknown storage type: {storage_type}")

    allowed = role_entry["implementations"]
    if storage_name in allowed:
        return

    raise ValueError(
        f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
        f"Compatible implementations are: {', '.join(allowed)}"
    )

View File

@@ -0,0 +1,343 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, final
import numpy as np
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
import pipmaster as pm
if not pm.is_installed("chromadb"):
pm.install("chromadb")
from chromadb import HttpClient, PersistentClient # type: ignore
from chromadb.config import Settings # type: ignore
@final
@dataclass
class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation."""
def __post_init__(self):
    """Validate settings, connect to ChromaDB and open the namespace collection.

    Options come from ``global_config["vector_db_storage_cls_kwargs"]``. A
    truthy ``local_path`` selects an on-disk ``PersistentClient``; otherwise an
    ``HttpClient`` is built from ``host``/``port`` and the auth options. Every
    failure inside the setup is logged and re-raised.

    Raises:
        ValueError: If ``cosine_better_than_threshold`` is not configured.
    """
    self._validate_embedding_func()
    try:
        cfg = self.global_config.get("vector_db_storage_cls_kwargs", {})

        threshold = cfg.get("cosine_better_than_threshold")
        if threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = threshold

        overrides = cfg.get("collection_settings", {})
        # Default HNSW index settings; user-supplied settings win on conflict.
        hnsw_defaults = {
            "hnsw:space": "cosine",  # distance metric
            "hnsw:construction_ef": 128,  # neighbours explored while indexing
            "hnsw:search_ef": 128,  # neighbours explored while searching
            "hnsw:M": 16,  # connections per graph node
            "hnsw:batch_size": 100,  # vectors per indexing batch
            "hnsw:sync_threshold": 1000,  # updates before a forced sync
        }
        collection_settings = {**hnsw_defaults, **overrides}

        storage_path = cfg.get("local_path", None)
        if not storage_path:
            # Remote server: assemble auth header/credentials for the provider.
            auth_provider = cfg.get(
                "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
            )
            auth_credentials = cfg.get("auth_token", "secret-token")
            headers = {}
            if "token_authn" in auth_provider:
                header_name = cfg.get("auth_header_name", "X-Chroma-Token")
                headers = {header_name: auth_credentials}
            elif "basic_authn" in auth_provider:
                auth_credentials = cfg.get("auth_credentials", "admin:admin")
            remote_settings = Settings(
                chroma_api_impl="rest",
                chroma_client_auth_provider=auth_provider,
                chroma_client_auth_credentials=auth_credentials,
                allow_reset=True,
                anonymized_telemetry=False,
            )
            self._client = HttpClient(
                host=cfg.get("host", "localhost"),
                port=cfg.get("port", 8000),
                headers=headers,
                settings=remote_settings,
            )
        else:
            local_settings = Settings(
                allow_reset=True,
                anonymized_telemetry=False,
            )
            self._client = PersistentClient(
                path=storage_path,
                settings=local_settings,
            )

        collection_metadata = {
            **collection_settings,
            "dimension": self.embedding_func.embedding_dim,
        }
        self._collection = self._client.get_or_create_collection(
            name=self.namespace,
            metadata=collection_metadata,
        )

        # embedding_batch_num takes precedence; the collection batch size
        # (or 32) is only the fallback.
        fallback_batch = collection_settings.get("hnsw:batch_size", 32)
        self._max_batch_size = self.global_config.get(
            "embedding_batch_num", fallback_batch
        )
    except Exception as err:
        logger.error(f"ChromaDB initialization failed: {str(err)}")
        raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    """Embed the given documents and upsert them into the collection.

    Args:
        data: Mapping of record ID to a dict that must contain ``"content"``;
            keys listed in ``self.meta_fields`` are stored as metadata
            together with a ``created_at`` timestamp.

    Returns:
        Nothing for empty input. For non-empty input the list of upserted IDs
        is returned (kept for backward compatibility although the annotation
        says ``None``).

    Raises:
        Exception: Any embedding or ChromaDB error is logged and re-raised.
    """
    logger.debug(f"Inserting {len(data)} to {self.namespace}")
    if not data:
        return
    try:
        import time

        current_time = int(time.time())
        ids = list(data.keys())
        documents = [v["content"] for v in data.values()]
        # Every metadata dict carries created_at, so it is never empty and
        # needs no placeholder fallback.
        metadatas = [
            {
                **{k: v for k, v in item.items() if k in self.meta_fields},
                "created_at": current_time,
            }
            for item in data.values()
        ]

        # Embed in batches. asyncio.gather returns results in the order the
        # awaitables were passed, so rows stay aligned with ``ids``.
        batches = [
            documents[i : i + self._max_batch_size]
            for i in range(0, len(documents), self._max_batch_size)
        ]
        embedding_tasks = [self.embedding_func(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)
        embeddings = np.concatenate(embeddings_list)

        # Upsert in batches
        for i in range(0, len(ids), self._max_batch_size):
            batch_slice = slice(i, i + self._max_batch_size)
            self._collection.upsert(
                ids=ids[batch_slice],
                embeddings=embeddings[batch_slice].tolist(),
                documents=documents[batch_slice],
                metadatas=metadatas[batch_slice],
            )
        return ids
    except Exception as e:
        logger.error(f"Error during ChromaDB upsert: {str(e)}")
        raise
async def query(
    self, query: str, top_k: int, query_embedding: list[float] | None = None
) -> list[dict[str, Any]]:
    """Return up to ``top_k`` records similar to ``query``.

    Args:
        query: Query text; embedded with ``self.embedding_func`` unless
            ``query_embedding`` is given.
        top_k: Maximum number of records to return.
        query_embedding: Optional pre-computed embedding of the query
            (same optional parameter as the Faiss storage accepts).

    Returns:
        Matching records. Note that the ``"distance"`` field holds the cosine
        *similarity* (``1 - chroma_distance``), so larger is better.

    Raises:
        Exception: Any embedding or ChromaDB error is logged and re-raised.
    """
    try:
        if query_embedding is not None:
            query_embeddings = [[float(x) for x in query_embedding]]
        else:
            embedding = await self.embedding_func(
                [query], _priority=5
            )  # higher priority for query
            query_embeddings = (
                embedding if isinstance(embedding, list) else embedding.tolist()
            )

        results = self._collection.query(
            query_embeddings=query_embeddings,
            n_results=top_k * 2,  # over-fetch so enough remain after filtering
            include=["metadatas", "distances", "documents"],
        )

        # With "hnsw:space" = "cosine" ChromaDB reports cosine distance
        # (0 = identical); 1 - distance converts it to a similarity that is
        # compared against cosine_better_than_threshold.
        hits = []
        for i in range(len(results["ids"][0])):
            similarity = 1 - results["distances"][0][i]
            if similarity >= self.cosine_better_than_threshold:
                metadata = results["metadatas"][0][i]
                hits.append(
                    {
                        "id": results["ids"][0][i],
                        "distance": similarity,
                        "content": results["documents"][0][i],
                        "created_at": metadata.get("created_at"),
                        **metadata,
                    }
                )
        return hits[:top_k]
    except Exception as e:
        logger.error(f"Error during ChromaDB query: {str(e)}")
        raise
async def index_done_callback(self) -> None:
    """Do nothing after indexing finishes."""
    # ChromaDB handles persistence automatically
    return None
async def delete_entity(self, entity_name: str) -> None:
    """Remove the record whose ID equals ``entity_name``.

    Args:
        entity_name: The ID of the entity to delete

    Raises:
        Exception: Any ChromaDB error is logged and re-raised.
    """
    target_ids = [entity_name]
    try:
        logger.info(f"Deleting entity with ID {entity_name} from {self.namespace}")
        self._collection.delete(ids=target_ids)
    except Exception as err:
        logger.error(f"Error during entity deletion: {str(err)}")
        raise
async def delete_entity_relation(self, entity_name: str) -> None:
    """Remove an entity together with its relations.

    This storage simply forwards to ``delete_entity``.

    Args:
        entity_name: The ID of the entity to delete
    """
    return await self.delete_entity(entity_name)
async def delete(self, ids: list[str]) -> None:
    """Delete vectors with specified IDs.

    Args:
        ids: List of vector IDs to be deleted

    Raises:
        Exception: Any ChromaDB error is logged and re-raised.
    """
    try:
        self._collection.delete(ids=ids)
        logger.debug(
            f"Successfully deleted {len(ids)} vectors from {self.namespace}"
        )
    except Exception as e:
        # A second, identical `except Exception` clause used to follow this
        # one; it could never run and has been dropped.
        logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
        raise
async def get_by_id(self, id: str) -> dict[str, Any] | None:
    """Fetch one record (vector, document and metadata) by ID.

    Args:
        id: The unique identifier of the vector

    Returns:
        The record, or None when it does not exist or the lookup failed
        (failures are logged, not raised).
    """
    wanted_fields = ["metadatas", "embeddings", "documents"]
    try:
        found = self._collection.get(ids=[id], include=wanted_fields)
        if not found or not found["ids"]:
            return None

        metadata = found["metadatas"][0]
        # Metadata keys are spread last, so they win on a name clash.
        return {
            "id": found["ids"][0],
            "vector": found["embeddings"][0],
            "content": found["documents"][0],
            "created_at": metadata.get("created_at"),
            **metadata,
        }
    except Exception as err:
        logger.error(f"Error retrieving vector data for ID {id}: {err}")
        return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Fetch several records by ID, aligned with the requested order.

    Args:
        ids: List of unique identifiers

    Returns:
        One entry per requested ID (None for IDs that were not found), or an
        empty list when ``ids`` is empty, nothing was found, or the lookup
        failed (failures are logged, not raised).
    """
    if not ids:
        return []

    wanted_fields = ["metadatas", "embeddings", "documents"]
    try:
        found = self._collection.get(ids=ids, include=wanted_fields)
        if not found or not found["ids"]:
            return []

        # Index the returned rows by stringified ID ...
        by_id: dict[str, dict[str, Any]] = {
            str(row_id): {
                "id": row_id,
                "vector": found["embeddings"][pos],
                "content": found["documents"][pos],
                "created_at": found["metadatas"][pos].get("created_at"),
                **found["metadatas"][pos],
            }
            for pos, row_id in enumerate(found["ids"])
        }
        # ... then emit them in the caller's order.
        return [by_id.get(str(wanted)) for wanted in ids]
    except Exception as err:
        logger.error(f"Error retrieving vector data for IDs {ids}: {err}")
        return []
async def drop(self) -> dict[str, str]:
    """Delete every document stored in the ChromaDB collection.

    Returns:
        dict[str, str]: Operation status and message
        - On success: {"status": "success", "message": "data dropped"}
        - On failure: {"status": "error", "message": "<error details>"}
    """
    try:
        # Ask only for IDs (no payload fields), then delete them in one call.
        listing = self._collection.get(include=[])
        all_ids = listing["ids"] if listing else None
        if all_ids and len(all_ids) > 0:
            self._collection.delete(ids=all_ids)
        logger.info(
            f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
        )
    except Exception as err:
        logger.error(f"Error dropping ChromaDB collection {self.namespace}: {err}")
        return {"status": "error", "message": str(err)}
    return {"status": "success", "message": "data dropped"}

View File

@@ -0,0 +1,585 @@
import os
import time
import asyncio
from typing import Any, final
import json
import numpy as np
from dataclasses import dataclass
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseVectorStorage
from .shared_storage import (
get_namespace_lock,
get_update_flag,
set_all_update_flags,
)
# You must manually install faiss-cpu or faiss-gpu before using FAISS vector db
import faiss # type: ignore
@final
@dataclass
class FaissVectorDBStorage(BaseVectorStorage):
"""
A Faiss-based Vector DB Storage for LightRAG.
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
"""
def __post_init__(self):
    """Read configuration, derive file paths and load any saved index.

    Raises:
        ValueError: If ``cosine_better_than_threshold`` is missing from
            ``vector_db_storage_cls_kwargs``.
    """
    self._validate_embedding_func()

    storage_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
    threshold = storage_kwargs.get("cosine_better_than_threshold")
    if threshold is None:
        raise ValueError(
            "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
        )
    self.cosine_better_than_threshold = threshold

    # Files live under working_dir, in a per-workspace subdirectory when a
    # workspace is set (data isolation).
    base_dir = self.global_config["working_dir"]
    if not self.workspace:
        self.workspace = ""
        target_dir = base_dir
    else:
        target_dir = os.path.join(base_dir, self.workspace)
    os.makedirs(target_dir, exist_ok=True)

    index_name = f"faiss_index_{self.namespace}.index"
    self._faiss_index_file = os.path.join(target_dir, index_name)
    self._meta_file = self._faiss_index_file + ".meta.json"
    self._max_batch_size = self.global_config["embedding_batch_num"]

    # Must match the embedding function's output dimension.
    self._dim = self.embedding_func.embedding_dim
    # Flat inner-product index; on L2-normalized vectors the inner product
    # equals cosine similarity.
    self._index = faiss.IndexFlatIP(self._dim)
    # <int faiss id> -> metadata dict (including the caller's original ID).
    self._id_to_meta = {}
    self._load_faiss_index()
async def initialize(self):
    """Fetch the shared update flag and the namespace lock for this storage."""
    ns, ws = self.namespace, self.workspace
    # Flag used for cross-process update notification.
    self.storage_updated = await get_update_flag(ns, workspace=ws)
    # Lock used by the other methods of this storage.
    self._storage_lock = get_namespace_lock(ns, workspace=ws)
async def _get_index(self):
    """Return the index, reloading it from disk first if it was flagged as updated."""
    # Hold the lock so a reload cannot interleave with other readers/writers.
    async with self._storage_lock:
        if not self.storage_updated.value:
            return self._index

        logger.info(
            f"[{self.workspace}] Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
        )
        # Start from an empty index/metadata map, then load what is on disk.
        self._index = faiss.IndexFlatIP(self._dim)
        self._id_to_meta = {}
        self._load_faiss_index()
        self.storage_updated.value = False
        return self._index
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    """Insert or update vectors in the Faiss index.

    Args:
        data: Mapping of custom ID to a dict that must contain ``"content"``
            (the text to embed); keys listed in ``self.meta_fields`` are kept
            as metadata, e.g.::

                {"custom_id_1": {"content": <text>, ...metadata...}, ...}

    Returns:
        Nothing for empty input. Otherwise the list of upserted custom IDs, or
        an empty list when the embedding count does not match the input
        (kept for backward compatibility although the annotation says
        ``None``).
    """
    logger.debug(
        f"[{self.workspace}] FAISS: Inserting {len(data)} to {self.namespace}"
    )
    if not data:
        return

    current_time = int(time.time())

    # Prepare metadata rows and the texts to embed, in matching order.
    list_data = []
    contents = []
    for k, v in data.items():
        # Store only known meta fields
        meta = {mf: v[mf] for mf in self.meta_fields if mf in v}
        meta["__id__"] = k
        meta["__created_at__"] = current_time
        list_data.append(meta)
        contents.append(v["content"])

    # Embed in batches; gather keeps the batch order.
    batches = [
        contents[i : i + self._max_batch_size]
        for i in range(0, len(contents), self._max_batch_size)
    ]
    embedding_tasks = [
        self.embedding_func(batch, context="document") for batch in batches
    ]
    embeddings_list = await asyncio.gather(*embedding_tasks)

    # Flatten the list of arrays
    embeddings = np.concatenate(embeddings_list, axis=0)
    if len(embeddings) != len(list_data):
        logger.error(
            f"[{self.workspace}] Embedding size mismatch. Embeddings: {len(embeddings)}, Data: {len(list_data)}"
        )
        return []

    # Convert to float32 and normalize in place so that inner product
    # equals cosine similarity.
    embeddings = embeddings.astype(np.float32)
    faiss.normalize_L2(embeddings)

    # Step 1: drop vectors that already exist under the same custom IDs.
    # Build the custom-id -> faiss-id map once instead of scanning all
    # metadata per upserted item; setdefault keeps the first match, like
    # _find_faiss_id_by_custom_id does.
    fid_by_custom_id: dict[Any, int] = {}
    for fid, stored_meta in self._id_to_meta.items():
        fid_by_custom_id.setdefault(stored_meta.get("__id__"), fid)
    existing_ids_to_remove = [
        fid_by_custom_id[meta["__id__"]]
        for meta in list_data
        if meta["__id__"] in fid_by_custom_id
    ]
    if existing_ids_to_remove:
        await self._remove_faiss_ids(existing_ids_to_remove)

    # Step 2: add the new vectors at the end of the index.
    index = await self._get_index()
    start_idx = index.ntotal
    index.add(embeddings)

    # Step 3: record metadata + raw vector under each new Faiss ID, so the
    # index can be rebuilt when something is removed later.
    for i, meta in enumerate(list_data):
        meta["__vector__"] = embeddings[i].tolist()
        self._id_to_meta[start_idx + i] = meta

    logger.debug(
        f"[{self.workspace}] Upserted {len(list_data)} vectors into Faiss index."
    )
    return [m["__id__"] for m in list_data]
async def query(
    self, query: str, top_k: int, query_embedding: list[float] | None = None
) -> list[dict[str, Any]]:
    """Search by a textual query and return up to ``top_k`` matches.

    Args:
        query: Query text; embedded with ``self.embedding_func`` unless
            ``query_embedding`` is given.
        top_k: Maximum number of neighbours requested from the index.
        query_embedding: Optional pre-computed embedding of the query.

    Returns:
        Stored metadata (without ``__vector__``) for every neighbour whose
        score reaches ``cosine_better_than_threshold``, plus ``id``,
        ``created_at`` and ``distance``. ``distance`` is the inner product of
        the normalized vectors, i.e. a cosine similarity (larger is better).
    """
    if query_embedding is not None:
        embedding = np.array([query_embedding], dtype=np.float32)
    else:
        embedding = await self.embedding_func(
            [query], context="query", _priority=5
        )  # higher priority for query
        # embedding is shape (1, dim)
        embedding = np.array(embedding, dtype=np.float32)
    faiss.normalize_L2(embedding)  # in-place normalization

    # Perform the similarity search
    index = await self._get_index()
    distances, indices = index.search(embedding, top_k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if idx == -1:
            # Faiss returns -1 if no neighbor
            continue
        # Cosine similarity threshold
        if dist < self.cosine_better_than_threshold:
            continue
        # idx is a numpy integer; look it up as a plain int key.
        meta = self._id_to_meta.get(int(idx), {})
        # Leave out __vector__ to avoid returning large vector data
        filtered_meta = {k: v for k, v in meta.items() if k != "__vector__"}
        results.append(
            {
                **filtered_meta,
                "id": meta.get("__id__"),
                "distance": float(dist),
                "created_at": meta.get("__created_at__"),
            }
        )
    return results
@property
def client_storage(self):
    """Expose all stored metadata rows as ``{"data": [...]}`` (e.g. for debugging)."""
    return {"data": [*self._id_to_meta.values()]}
async def delete(self, ids: list[str]):
    """Delete the vectors stored under the given custom IDs.

    Important notes:
    1. Changes will be persisted to disk during the next index_done_callback
    2. Only one process should be updating the storage at a time before
       index_done_callback; KG-storage-log should be used to avoid data
       corruption

    Args:
        ids: Custom IDs to delete; IDs that are not stored are ignored.
    """
    logger.debug(
        f"[{self.workspace}] Deleting {len(ids)} vectors from {self.namespace}"
    )
    # Translate custom IDs to Faiss IDs, dropping the unknown ones.
    lookups = (self._find_faiss_id_by_custom_id(cid) for cid in ids)
    to_remove = [fid for fid in lookups if fid is not None]

    if to_remove:
        await self._remove_faiss_ids(to_remove)
    logger.debug(
        f"[{self.workspace}] Successfully deleted {len(to_remove)} vectors from {self.namespace}"
    )
async def delete_entity(self, entity_name: str) -> None:
    """Delete the vector of one entity, addressed by entity name.

    The vector ID is derived from the name with ``compute_mdhash_id`` and the
    ``"ent-"`` prefix.

    Important notes:
    1. Changes will be persisted to disk during the next index_done_callback
    2. Only one process should be updating the storage at a time before
       index_done_callback; KG-storage-log should be used to avoid data
       corruption
    """
    hashed_id = compute_mdhash_id(entity_name, prefix="ent-")
    logger.debug(
        f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {hashed_id}"
    )
    await self.delete([hashed_id])
async def delete_entity_relation(self, entity_name: str) -> None:
    """Delete every stored row whose ``src_id`` or ``tgt_id`` is ``entity_name``.

    Important notes:
    1. Changes will be persisted to disk during the next index_done_callback
    2. Only one process should be updating the storage at a time before
       index_done_callback; KG-storage-log should be used to avoid data
       corruption
    """
    logger.debug(f"[{self.workspace}] Searching relations for entity {entity_name}")
    matching_fids = [
        fid
        for fid, meta in self._id_to_meta.items()
        if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name
    ]
    logger.debug(
        f"[{self.workspace}] Found {len(matching_fids)} relations for {entity_name}"
    )
    if not matching_fids:
        return

    await self._remove_faiss_ids(matching_fids)
    logger.debug(
        f"[{self.workspace}] Deleted {len(matching_fids)} relations for {entity_name}"
    )
# --------------------------------------------------------------------------------
# Internal helper methods
# --------------------------------------------------------------------------------
def _find_faiss_id_by_custom_id(self, custom_id: str):
    """Look up the Faiss internal ID stored for ``custom_id``.

    Scans the metadata map and returns the first internal ID whose
    ``__id__`` equals ``custom_id``, or None when there is no such row.
    """
    matches = (
        fid
        for fid, meta in self._id_to_meta.items()
        if meta.get("__id__") == custom_id
    )
    return next(matches, None)
async def _remove_faiss_ids(self, fid_list):
    """Remove the given internal Faiss IDs by rebuilding the index.

    IndexFlatIP is rebuilt from the vectors that are kept, which renumbers
    the surviving rows from 0 in their current iteration order.

    Args:
        fid_list: Internal Faiss IDs to drop.
    """
    # Set membership keeps the filter linear in the number of stored rows.
    doomed = set(fid_list)
    keep_fids = [fid for fid in self._id_to_meta if fid not in doomed]

    # Collect the surviving vectors and their metadata under new IDs.
    vectors_to_keep = []
    new_id_to_meta = {}
    for old_fid in keep_fids:
        vec_meta = self._id_to_meta[old_fid]
        if "__vector__" in vec_meta:
            vec = vec_meta["__vector__"]
        elif old_fid < self._index.ntotal:
            # No cached vector: recover it from the index and cache it.
            vec = self._index.reconstruct(old_fid).tolist()
            vec_meta["__vector__"] = vec
        else:
            logger.warning(
                f"[{self.workspace}] Skipping fid={old_fid} during rebuild: "
                f"no vector and fid exceeds index size ({self._index.ntotal})"
            )
            continue
        new_fid = len(vectors_to_keep)
        vectors_to_keep.append(vec)
        new_id_to_meta[new_fid] = vec_meta

    async with self._storage_lock:
        # Swap in the rebuilt index and metadata map together.
        self._index = faiss.IndexFlatIP(self._dim)
        if vectors_to_keep:
            arr = np.array(vectors_to_keep, dtype=np.float32)
            self._index.add(arr)
        self._id_to_meta = new_id_to_meta
def _save_faiss_index(self):
    """Write the Faiss index and its metadata map to disk."""
    faiss.write_index(self._index, self._faiss_index_file)

    # __vector__ is left out of the JSON: the vectors are already in the
    # index file and are reconstructed on load.
    persisted = {
        str(fid): {key: val for key, val in meta.items() if key != "__vector__"}
        for fid, meta in self._id_to_meta.items()
    }

    # Write to a temp file and rename it over the target, to reduce the
    # risk of index/meta mismatch on a crash.
    staging_path = self._meta_file + ".tmp"
    with open(staging_path, "w", encoding="utf-8") as handle:
        json.dump(persisted, handle)
    os.replace(staging_path, self._meta_file)
def _load_faiss_index(self):
    """Load the Faiss index and metadata from disk, if an index file exists.

    A missing index file only logs a warning. Any load failure falls back to
    an empty index, except a dimension mismatch, which is re-raised.

    Raises:
        ValueError: If the stored index dimension differs from ``self._dim``.
    """
    index_path = self._faiss_index_file
    if not os.path.exists(index_path):
        logger.warning(
            f"[{self.workspace}] No existing Faiss index file found for {self.namespace}"
        )
        return

    # Marks the one failure that must escape the catch-all handler below.
    dim_mismatch = False
    try:
        self._index = faiss.read_index(index_path)

        # The stored index must match the embedding function's dimension.
        if self._index.d != self._dim:
            error_msg = (
                f"Dimension mismatch: loaded Faiss index has dimension {self._index.d}, "
                f"but embedding function expects dimension {self._dim}. "
                f"Please ensure the embedding model matches the stored index or rebuild the index."
            )
            logger.error(error_msg)
            dim_mismatch = True
            raise ValueError(error_msg)

        with open(self._meta_file, "r", encoding="utf-8") as handle:
            stored = json.load(handle)

        # JSON keys are strings: convert back to int and restore vectors
        # from the index where the metadata has none.
        restored = {}
        for key, meta in stored.items():
            fid = int(key)
            if fid < self._index.ntotal:
                if "__vector__" not in meta:
                    meta["__vector__"] = self._index.reconstruct(fid).tolist()
                restored[fid] = meta
            else:
                logger.warning(
                    f"[{self.workspace}] Skipping metadata row fid={fid}: "
                    f"exceeds index size ({self._index.ntotal})"
                )
        self._id_to_meta = restored

        logger.info(
            f"[{self.workspace}] Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
        )
    except Exception as err:
        if dim_mismatch:
            raise
        logger.error(
            f"[{self.workspace}] Failed to load Faiss index or metadata: {err}"
        )
        logger.warning(f"[{self.workspace}] Starting with an empty Faiss index.")
        self._index = faiss.IndexFlatIP(self._dim)
        self._id_to_meta = {}
async def index_done_callback(self) -> None:
    """Persist the index, unless it was flagged as updated elsewhere.

    Returns:
        True when the index was saved. False when the storage had been
        updated by another process (in-memory data is reloaded instead of
        saved) or when saving failed. (The annotation says ``None``, but
        these booleans are what the code returns.)
    """
    async with self._storage_lock:
        updated_elsewhere = self.storage_updated.value
        if updated_elsewhere:
            # Reload rather than overwrite the other process's data.
            logger.warning(
                f"[{self.workspace}] Storage for FAISS {self.namespace} was updated by another process, reloading..."
            )
            self._index = faiss.IndexFlatIP(self._dim)
            self._id_to_meta = {}
            self._load_faiss_index()
            self.storage_updated.value = False
            return False

    # Take the lock again for the actual persistence step.
    async with self._storage_lock:
        try:
            self._save_faiss_index()
            # Tell other processes the data changed ...
            await set_all_update_flags(self.namespace, workspace=self.workspace)
            # ... and clear our own flag to avoid self-reloading.
            self.storage_updated.value = False
        except Exception as err:
            logger.error(
                f"[{self.workspace}] Error saving FAISS index for {self.namespace}: {err}"
            )
            return False

    return True
async def get_by_id(self, id: str) -> dict[str, Any] | None:
    """Return the stored metadata for one custom ID.

    Args:
        id: The unique identifier of the vector

    Returns:
        The metadata without ``__vector__``, plus ``id`` and ``created_at``;
        None when the ID is unknown or has no metadata.
    """
    fid = self._find_faiss_id_by_custom_id(id)
    metadata = None if fid is None else self._id_to_meta.get(fid, {})
    if not metadata:
        return None

    # Leave out __vector__ to avoid returning large vector data.
    record = {key: val for key, val in metadata.items() if key != "__vector__"}
    record["id"] = metadata.get("__id__")
    record["created_at"] = metadata.get("__created_at__")
    return record
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Return stored metadata for several custom IDs, in request order.

    Args:
        ids: List of unique identifiers

    Returns:
        One entry per requested ID: the metadata without ``__vector__`` plus
        ``id`` and ``created_at``, or None when the ID is unknown.
    """
    if not ids:
        return []

    results: list[dict[str, Any] | None] = []
    for custom_id in ids:
        fid = self._find_faiss_id_by_custom_id(custom_id)
        metadata = self._id_to_meta.get(fid) if fid is not None else None
        if not metadata:
            results.append(None)
            continue
        # Leave out __vector__ to avoid returning large vector data.
        record = {key: val for key, val in metadata.items() if key != "__vector__"}
        record["id"] = metadata.get("__id__")
        record["created_at"] = metadata.get("__created_at__")
        results.append(record)
    return results
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
    """Return only the stored vectors for the given custom IDs.

    Args:
        ids: List of unique identifiers

    Returns:
        ``{id: [vector_values], ...}`` for every ID that is stored and has a
        cached ``__vector__``; other IDs are omitted.
    """
    if not ids:
        return {}

    vectors = {}
    for custom_id in ids:
        fid = self._find_faiss_id_by_custom_id(custom_id)
        if fid is None or fid not in self._id_to_meta:
            continue
        meta = self._id_to_meta[fid]
        if "__vector__" in meta:
            vectors[custom_id] = meta["__vector__"]
    return vectors
async def drop(self) -> dict[str, str]:
    """Remove all vectors and delete the storage files.

    Under the storage lock this resets the in-memory index and metadata,
    deletes the index and metadata files if present, runs the loader again,
    and sets the update flags for the namespace.

    Returns:
        dict[str, str]: Operation status and message
        - On success: {"status": "success", "message": "data dropped"}
        - On failure: {"status": "error", "message": "<error details>"}
    """
    try:
        async with self._storage_lock:
            # Reset in-memory state.
            self._index = faiss.IndexFlatIP(self._dim)
            self._id_to_meta = {}

            # Remove the files on disk, index first, then metadata.
            for path in (self._faiss_index_file, self._meta_file):
                if os.path.exists(path):
                    os.remove(path)

            self._load_faiss_index()

            # Notify other processes, then clear our own flag.
            await set_all_update_flags(self.namespace, workspace=self.workspace)
            self.storage_updated.value = False

        logger.info(
            f"[{self.workspace}] Process {os.getpid()} drop FAISS index {self.namespace}"
        )
        return {"status": "success", "message": "data dropped"}
    except Exception as err:
        logger.error(
            f"[{self.workspace}] Error dropping FAISS index {self.namespace}: {err}"
        )
        return {"status": "error", "message": str(err)}

View File

@@ -0,0 +1,422 @@
from dataclasses import dataclass
import os
from typing import Any, Union, final
from lightrag.base import (
DocProcessingStatus,
DocStatus,
DocStatusStorage,
)
from lightrag.utils import (
_cooperative_yield,
load_json,
logger,
write_json,
get_pinyin_sort_key,
)
from lightrag.exceptions import StorageNotInitializedError
from .shared_storage import (
get_namespace_data,
get_namespace_lock,
get_data_init_lock,
get_update_flag,
set_all_update_flags,
clear_all_update_flags,
try_initialize_namespace,
)
@final
@dataclass
class JsonDocStatusStorage(DocStatusStorage):
"""JSON implementation of document status storage"""
def __post_init__(self):
    """Resolve the backing JSON file path and reset runtime state.

    The file lives in ``<working_dir>/<workspace>/`` when a workspace is set,
    otherwise directly in ``<working_dir>/``. The directory is created if it
    does not exist. Data, lock and update flag stay ``None`` until
    ``initialize()`` is called.
    """
    base_dir = self.global_config["working_dir"]
    if not self.workspace:
        # Default behavior when workspace is empty: normalize it to ""
        self.workspace = ""
        target_dir = base_dir
    else:
        # Include workspace in the file path for data isolation
        target_dir = os.path.join(base_dir, self.workspace)
    os.makedirs(target_dir, exist_ok=True)
    self._file_name = os.path.join(target_dir, f"kv_store_{self.namespace}.json")
    self._data = None
    self._storage_lock = None
    self.storage_updated = None
async def initialize(self):
    """Initialize storage data.

    Fetches the namespace lock and the update flag for this
    namespace/workspace, then, while holding the global data-init lock, binds
    ``self._data`` to the namespace data object. When
    ``try_initialize_namespace`` reports that initialization is needed, the
    JSON file is loaded and merged into that data under the namespace lock.
    """
    self._storage_lock = get_namespace_lock(
        self.namespace, workspace=self.workspace
    )
    self.storage_updated = await get_update_flag(
        self.namespace, workspace=self.workspace
    )
    async with get_data_init_lock():
        # check need_init must before get_namespace_data
        need_init = await try_initialize_namespace(
            self.namespace, workspace=self.workspace
        )
        self._data = await get_namespace_data(
            self.namespace, workspace=self.workspace
        )
        if need_init:
            # A falsy load result is treated as an empty store
            loaded_data = load_json(self._file_name) or {}
            async with self._storage_lock:
                self._data.update(loaded_data)
                logger.info(
                    f"[{self.workspace}] Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
                )
async def filter_keys(self, keys: set[str]) -> set[str]:
    """Return the subset of ``keys`` that has no record in this storage.

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")
    async with self._storage_lock:
        stored_keys = set(self._data.keys())
        return set(keys).difference(stored_keys)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Fetch records for ``ids``, preserving the requested order.

    Each found record is returned as a shallow copy; positions whose id is
    unknown (or whose stored record is empty) hold ``None``.

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")
    async with self._storage_lock:
        stored = [self._data.get(doc_id, None) for doc_id in ids]
        return [record.copy() if record else None for record in stored]
async def get_status_counts(self) -> dict[str, int]:
    """Count stored documents per status.

    Returns:
        Mapping from every ``DocStatus`` value to the number of documents
        currently in that status (zero when none).

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    tally = dict.fromkeys((status.value for status in DocStatus), 0)
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")
    async with self._storage_lock:
        for record in self._data.values():
            tally[record["status"]] += 1
    return tally
async def get_docs_by_status(
    self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
    """Return all documents whose status equals ``status``.

    Thin wrapper that delegates to ``get_docs_by_statuses`` with a
    single-element list.
    """
    matching = await self.get_docs_by_statuses([status])
    return matching
async def get_docs_by_statuses(
    self, statuses: list[DocStatus]
) -> dict[str, DocProcessingStatus]:
    """Get all documents matching any of the given statuses in a single pass.

    Acquires the storage lock once and scans the in-memory dict once,
    filtering against a set of status values. More efficient than N separate
    get_docs_by_status() calls, which would acquire the lock N times and scan
    the data N times.

    Args:
        statuses: Statuses to match; an empty list yields an empty result.

    Returns:
        Mapping of document id to ``DocProcessingStatus``. Records that
        cannot be converted are logged and skipped.

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    if not statuses:
        return {}
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")
    status_values = {s.value for s in statuses}
    result = {}
    async with self._storage_lock:
        for k, v in self._data.items():
            # .get(): a record without "status" is skipped, not fatal
            if v.get("status") not in status_values:
                continue
            try:
                # Work on a copy so the stored record is never modified
                data = v.copy()
                # Remove deprecated content field if it exists
                data.pop("content", None)
                # Normalize missing or null file_path
                if not data.get("file_path"):
                    data["file_path"] = "no-file-path"
                # Ensure newer fields exist with default values
                data.setdefault("metadata", {})
                data.setdefault("error_msg", None)
                result[k] = DocProcessingStatus(**data)
            except (KeyError, TypeError) as e:
                logger.error(
                    f"[{self.workspace}] Missing required field for document {k}: {e}"
                )
                continue
    return result
async def get_docs_by_track_id(
    self, track_id: str
) -> dict[str, DocProcessingStatus]:
    """Get all documents with a specific track_id.

    Args:
        track_id: Value compared against each record's ``track_id`` field.

    Returns:
        Mapping of document id to ``DocProcessingStatus``. Records that
        cannot be converted are logged and skipped.

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")
    result = {}
    async with self._storage_lock:
        for k, v in self._data.items():
            if v.get("track_id") != track_id:
                continue
            try:
                # Make a copy of the data to avoid modifying the original
                data = v.copy()
                # Remove deprecated content field if it exists
                data.pop("content", None)
                # Normalize missing or null file_path
                if not data.get("file_path"):
                    data["file_path"] = "no-file-path"
                # Ensure new fields exist with default values
                data.setdefault("metadata", {})
                data.setdefault("error_msg", None)
                result[k] = DocProcessingStatus(**data)
            except (KeyError, TypeError) as e:
                # A constructor called with missing or unexpected keyword
                # arguments raises TypeError, so it must be caught here for
                # one bad record not to abort the whole scan.
                logger.error(
                    f"[{self.workspace}] Missing required field for document {k}: {e}"
                )
                continue
    return result
async def index_done_callback(self) -> None:
    """Persist the in-memory data to the JSON file if the update flag is set.

    Under the storage lock: nothing happens when the flag is clear. Otherwise
    the data is written with ``write_json``; when that call reports that the
    data must be reloaded, the file is read back and replaces the in-memory
    content. Finally all update flags for the namespace are cleared.
    """
    async with self._storage_lock:
        if not self.storage_updated.value:
            return
        # Objects exposing "_getvalue" are converted to a plain dict first
        if hasattr(self._data, "_getvalue"):
            data_dict = dict(self._data)
        else:
            data_dict = self._data
        logger.debug(
            f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
        )
        # Write JSON and check if sanitization was applied
        needs_reload = write_json(data_dict, self._file_name)
        if needs_reload:
            # Data was sanitized: reload cleaned data to update shared memory
            logger.info(
                f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
            )
            cleaned_data = load_json(self._file_name)
            if cleaned_data is not None:
                self._data.clear()
                self._data.update(cleaned_data)
        await clear_all_update_flags(self.namespace, workspace=self.workspace)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    """Insert or replace document status records and persist them.

    Every incoming record gets an empty ``chunks_list`` if it has none. The
    records are merged into the in-memory data under the storage lock, the
    update flags are raised, and ``index_done_callback`` is then invoked.

    Args:
        data: Mapping of document id to record; an empty mapping is a no-op.

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    if not data:
        return
    logger.debug(
        f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
    )
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")
    # Prepare data outside the lock: this only touches the caller-supplied
    # records, not shared storage state.
    for position, record in enumerate(data.values(), start=1):
        record.setdefault("chunks_list", [])
        await _cooperative_yield(position)
    async with self._storage_lock:
        self._data.update(data)
        await set_all_update_flags(self.namespace, workspace=self.workspace)
    # Called after the lock is released; the callback takes the lock itself.
    await self.index_done_callback()
async def is_empty(self) -> bool:
    """Report whether the storage holds no records.

    Returns:
        bool: True if storage is empty, False otherwise

    Raises:
        StorageNotInitializedError: If storage is not initialized
    """
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")
    async with self._storage_lock:
        record_count = len(self._data)
        return record_count == 0
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
    """Return the stored record for ``id``, or ``None`` if it is unknown.

    The stored object itself is returned (no copy is made).
    """
    async with self._storage_lock:
        record = self._data.get(id)
    return record
async def get_docs_paginated(
    self,
    status_filter: DocStatus | None = None,
    page: int = 1,
    page_size: int = 50,
    sort_field: str = "updated_at",
    sort_direction: str = "desc",
) -> tuple[list[tuple[str, DocProcessingStatus]], int]:
    """Get documents with pagination support.

    Args:
        status_filter: Filter by document status, None for all statuses
        page: Page number (1-based); values below 1 are treated as 1
        page_size: Number of documents per page, clamped to 10-200
        sort_field: Field to sort by ('created_at', 'updated_at', 'id',
            'file_path'); anything else falls back to 'updated_at'
        sort_direction: Sort direction ('asc' or 'desc'); anything else
            falls back to 'desc'

    Returns:
        Tuple of (list of (doc_id, DocProcessingStatus) tuples, total_count)

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")

    # Validate parameters
    page = max(page, 1)
    page_size = min(max(page_size, 10), 200)
    if sort_field not in ("created_at", "updated_at", "id", "file_path"):
        sort_field = "updated_at"
    if sort_direction.lower() not in ("asc", "desc"):
        sort_direction = "desc"
    reverse_sort = sort_direction.lower() == "desc"

    # For JSON storage, we load all data and sort/filter in memory.
    # Entries are (sort_key, doc_id, doc_status) so that no temporary
    # attribute has to be attached to the status objects.
    keyed_docs = []
    async with self._storage_lock:
        for doc_id, doc_data in self._data.items():
            # Apply status filter
            if (
                status_filter is not None
                and doc_data.get("status") != status_filter.value
            ):
                continue
            try:
                # Prepare document data on a copy of the stored record
                data = doc_data.copy()
                data.pop("content", None)
                if not data.get("file_path"):
                    data["file_path"] = "no-file-path"
                data.setdefault("metadata", {})
                data.setdefault("error_msg", None)
                doc_status = DocProcessingStatus(**data)
                if sort_field == "id":
                    sort_key = doc_id
                elif sort_field == "file_path":
                    # Use pinyin sorting for file_path field to support Chinese characters
                    sort_key = get_pinyin_sort_key(
                        getattr(doc_status, sort_field, "")
                    )
                else:
                    sort_key = getattr(doc_status, sort_field, "")
                keyed_docs.append((sort_key, doc_id, doc_status))
            except (KeyError, TypeError) as e:
                # TypeError covers missing/unexpected constructor fields
                logger.error(
                    f"[{self.workspace}] Error processing document {doc_id}: {e}"
                )
                continue

    # Sort documents (stable, so ties keep storage order)
    keyed_docs.sort(key=lambda entry: entry[0], reverse=reverse_sort)
    all_docs = [(doc_id, doc_status) for _, doc_id, doc_status in keyed_docs]
    total_count = len(all_docs)

    # Apply pagination
    start_idx = (page - 1) * page_size
    paginated_docs = all_docs[start_idx : start_idx + page_size]
    return paginated_docs, total_count
async def get_all_status_counts(self) -> dict[str, int]:
    """Get per-status document counts plus an overall total.

    Returns:
        The mapping produced by ``get_status_counts`` with an extra ``"all"``
        entry holding the sum of all counts.
    """
    counts = await self.get_status_counts()
    # Sum is computed before "all" is inserted, so it is not counted twice
    counts["all"] = sum(counts.values())
    return counts
async def delete(self, doc_ids: list[str]) -> None:
    """Delete specific records from storage by their IDs.

    The records are removed from memory only; the update flags are raised
    (when at least one record was actually removed) so that the change is
    written out by a later ``index_done_callback``.

    Args:
        doc_ids (list[str]): List of document IDs to be deleted from storage

    Returns:
        None
    """
    async with self._storage_lock:
        removed = [self._data.pop(doc_id, None) for doc_id in doc_ids]
        if any(entry is not None for entry in removed):
            await set_all_update_flags(self.namespace, workspace=self.workspace)
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
    """Get the first document whose ``file_path`` equals ``file_path``.

    Args:
        file_path: The file path to search for

    Returns:
        Union[dict[str, Any], None]: The complete stored record if found
        (the stored object itself, not a copy), None otherwise

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonDocStatusStorage")
    async with self._storage_lock:
        matches = (
            record
            for record in self._data.values()
            if record.get("file_path") == file_path
        )
        return next(matches, None)
async def drop(self) -> dict[str, str]:
    """Drop all document status data from storage.

    Steps performed:
    1. Clear all document status data from memory (under the storage lock)
    2. Raise the update flags to notify other processes
    3. Call ``index_done_callback`` to save the empty state

    Returns:
        dict[str, str]: Operation status and message
        - On success: {"status": "success", "message": "data dropped"}
        - On failure: {"status": "error", "message": "<error details>"}
    """
    outcome = {"status": "success", "message": "data dropped"}
    try:
        async with self._storage_lock:
            self._data.clear()
            await set_all_update_flags(self.namespace, workspace=self.workspace)
        # Persist only after the lock has been released
        await self.index_done_callback()
        logger.info(
            f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}"
        )
    except Exception as e:
        logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
        outcome = {"status": "error", "message": str(e)}
    return outcome

View File

@@ -0,0 +1,307 @@
import os
from dataclasses import dataclass
from typing import Any, final
from lightrag.base import (
BaseKVStorage,
)
from lightrag.utils import (
_cooperative_yield,
load_json,
logger,
write_json,
)
from lightrag.exceptions import StorageNotInitializedError
from .shared_storage import (
get_namespace_data,
get_namespace_lock,
get_data_init_lock,
get_update_flag,
set_all_update_flags,
clear_all_update_flags,
try_initialize_namespace,
)
@final
@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
    """Resolve the backing JSON file path and reset runtime state.

    The file lives in ``<working_dir>/<workspace>/`` when a workspace is set,
    otherwise directly in ``<working_dir>/``. The directory is created if it
    does not exist. Data, lock and update flag stay ``None`` until
    ``initialize()`` is called.
    """
    base_dir = self.global_config["working_dir"]
    if not self.workspace:
        # Default behavior when workspace is empty: normalize it to ""
        self.workspace = ""
        target_dir = base_dir
    else:
        # Include workspace in the file path for data isolation
        target_dir = os.path.join(base_dir, self.workspace)
    os.makedirs(target_dir, exist_ok=True)
    self._file_name = os.path.join(target_dir, f"kv_store_{self.namespace}.json")
    self._data = None
    self._storage_lock = None
    self.storage_updated = None
async def initialize(self):
    """Initialize storage data.

    Fetches the namespace lock and the update flag for this
    namespace/workspace, then, while holding the global data-init lock, binds
    ``self._data`` to the namespace data object. When
    ``try_initialize_namespace`` reports that initialization is needed, the
    JSON file is loaded (and, for ``*_cache`` namespaces, migrated from the
    legacy nested layout) and merged into that data under the namespace lock.
    """
    self._storage_lock = get_namespace_lock(
        self.namespace, workspace=self.workspace
    )
    self.storage_updated = await get_update_flag(
        self.namespace, workspace=self.workspace
    )
    async with get_data_init_lock():
        # check need_init must before get_namespace_data
        need_init = await try_initialize_namespace(
            self.namespace, workspace=self.workspace
        )
        self._data = await get_namespace_data(
            self.namespace, workspace=self.workspace
        )
        if need_init:
            # A falsy load result is treated as an empty store
            loaded_data = load_json(self._file_name) or {}
            async with self._storage_lock:
                # Migrate legacy cache structure if needed
                if self.namespace.endswith("_cache"):
                    loaded_data = await self._migrate_legacy_cache_structure(
                        loaded_data
                    )
                self._data.update(loaded_data)
                data_count = len(loaded_data)
                logger.info(
                    f"[{self.workspace}] Process {os.getpid()} KV load {self.namespace} with {data_count} records"
                )
async def index_done_callback(self) -> None:
    """Persist the in-memory data to the JSON file if the update flag is set.

    Under the storage lock: nothing happens when the flag is clear. Otherwise
    the data is written with ``write_json``; when that call reports that the
    data must be reloaded, the file is read back and replaces the in-memory
    content. Finally all update flags for the namespace are cleared.
    """
    async with self._storage_lock:
        if not self.storage_updated.value:
            return
        # Objects exposing "_getvalue" are converted to a plain dict first
        if hasattr(self._data, "_getvalue"):
            data_dict = dict(self._data)
        else:
            data_dict = self._data
        # Calculate data count - all data is now flattened
        data_count = len(data_dict)
        logger.debug(
            f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
        )
        # Write JSON and check if sanitization was applied
        needs_reload = write_json(data_dict, self._file_name)
        if needs_reload:
            # Data was sanitized: reload cleaned data to update shared memory
            logger.info(
                f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
            )
            cleaned_data = load_json(self._file_name)
            if cleaned_data is not None:
                self._data.clear()
                self._data.update(cleaned_data)
        await clear_all_update_flags(self.namespace, workspace=self.workspace)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._storage_lock:
result = self._data.get(id)
if result:
# Create a copy to avoid modifying the original data
result = dict(result)
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
# Ensure _id field contains the clean ID
result["_id"] = id
return result
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Return copies of the records for ``ids``, in the requested order.

    Each found record carries ``create_time`` and ``update_time`` (0 when the
    stored record has none) and an ``_id`` field equal to its id. Positions
    whose id is unknown (or whose stored record is empty) hold ``None``.
    """
    async with self._storage_lock:
        results = []
        for record_id in ids:
            stored = self._data.get(record_id, None)
            if not stored:
                results.append(None)
                continue
            # Create a copy to avoid modifying the original data
            record = dict(stored.items())
            # Provide default time fields for old data
            record.setdefault("create_time", 0)
            record.setdefault("update_time", 0)
            # Ensure _id field contains the clean ID
            record["_id"] = record_id
            results.append(record)
        return results
async def filter_keys(self, keys: set[str]) -> set[str]:
    """Return the subset of ``keys`` that has no record in this storage."""
    async with self._storage_lock:
        stored_keys = set(self._data.keys())
        return set(keys).difference(stored_keys)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    """Insert or replace records in the in-memory store.

    Important notes for in-memory storage:
    1. Changes will be persisted to disk during the next index_done_callback
    2. Update flags are raised to notify other processes that data
       persistence is needed

    Each record gets ``_id`` and ``update_time``; new keys also get
    ``create_time``. When an existing key is replaced and the incoming record
    has no ``create_time``, the stored one is carried over so that the
    original creation time survives the replacement.

    Args:
        data: Mapping of record id to record; an empty mapping is a no-op.
            The record dicts are modified in place.

    Raises:
        StorageNotInitializedError: If ``initialize()`` has not been called.
    """
    if not data:
        return
    import time

    current_time = int(time.time())  # Get current Unix timestamp
    logger.debug(
        f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
    )
    if self._storage_lock is None:
        raise StorageNotInitializedError("JsonKVStorage")
    # Loop-invariant: only text_chunks records need llm_cache_list
    is_text_chunks = self.namespace.endswith("text_chunks")
    async with self._storage_lock:
        # The loop reads self._data, so it must stay inside the lock.
        # _cooperative_yield is safe here: NamespaceLock is non-reentrant, so
        # other coroutines waiting on this lock will block until we release
        # it; the yield only benefits unrelated coroutines.
        for i, (k, v) in enumerate(data.items(), start=1):
            # For text_chunks namespace, ensure llm_cache_list field exists
            if is_text_chunks and "llm_cache_list" not in v:
                v["llm_cache_list"] = []
            if k in self._data:
                # Key exists: refresh update_time. The whole record is
                # replaced below, so keep the stored create_time unless the
                # caller supplied one.
                existing = self._data[k]
                if (
                    "create_time" not in v
                    and isinstance(existing, dict)
                    and "create_time" in existing
                ):
                    v["create_time"] = existing["create_time"]
                v["update_time"] = current_time
            else:
                # New key, set both create_time and update_time
                v["create_time"] = current_time
                v["update_time"] = current_time
            v["_id"] = k
            await _cooperative_yield(i)
        self._data.update(data)
        await set_all_update_flags(self.namespace, workspace=self.workspace)
async def delete(self, ids: list[str]) -> None:
    """Delete specific records from storage by their IDs.

    The records are removed from memory only; the update flags are raised
    (when at least one record was actually removed) so that the change is
    written out by a later ``index_done_callback``.

    Args:
        ids (list[str]): List of document IDs to be deleted from storage

    Returns:
        None
    """
    async with self._storage_lock:
        removed = [self._data.pop(doc_id, None) for doc_id in ids]
        if any(entry is not None for entry in removed):
            await set_all_update_flags(self.namespace, workspace=self.workspace)
async def is_empty(self) -> bool:
    """Report whether the storage holds no records.

    Returns:
        bool: True if storage contains no data, False otherwise
    """
    async with self._storage_lock:
        record_count = len(self._data)
        return record_count == 0
async def drop(self) -> dict[str, str]:
    """Drop all data from storage and persist the empty state immediately.

    Steps performed:
    1. Clear all data from memory (under the storage lock)
    2. Raise the update flags to notify other processes
    3. Call ``index_done_callback`` to save the empty state

    Returns:
        dict[str, str]: Operation status and message
        - On success: {"status": "success", "message": "data dropped"}
        - On failure: {"status": "error", "message": "<error details>"}
    """
    outcome = {"status": "success", "message": "data dropped"}
    try:
        async with self._storage_lock:
            self._data.clear()
            await set_all_update_flags(self.namespace, workspace=self.workspace)
        # Persist only after the lock has been released
        await self.index_done_callback()
        logger.info(
            f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}"
        )
    except Exception as e:
        logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
        outcome = {"status": "error", "message": str(e)}
    return outcome
async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
    """Migrate legacy nested cache structure to flattened structure.

    A top-level value counts as legacy when it is a dict whose values are all
    dicts containing a ``"return"`` key; each of its entries is re-keyed with
    ``generate_cache_key(mode, cache_type, cache_hash)``. When anything was
    migrated, the result is written to the JSON file straight away.

    Args:
        data: Original data dictionary that may contain legacy structure

    Returns:
        Migrated data dictionary with flattened cache keys (the reloaded file
        content if ``write_json`` reported that a reload is needed)
    """
    from lightrag.utils import generate_cache_key

    # Early return if data is empty
    if not data:
        return data
    # Only the first key is probed: exactly two ":" separators (three parts)
    # means the data is already in the flattened format.
    probe_key = next(iter(data.keys()))
    if probe_key.count(":") == 2:
        return data

    migrated_data = {}
    migration_count = 0
    for key, value in data.items():
        looks_legacy = isinstance(value, dict) and all(
            isinstance(entry, dict) and "return" in entry
            for entry in value.values()
        )
        if not looks_legacy:
            # Keep non-cache data or already flattened cache data as-is
            migrated_data[key] = value
            continue
        # Legacy layout: the top-level key is the cache mode
        for cache_hash, cache_entry in value.items():
            cache_type = cache_entry.get("cache_type", "extract")
            migrated_data[generate_cache_key(key, cache_type, cache_hash)] = cache_entry
            migration_count += 1

    if migration_count == 0:
        return migrated_data

    logger.info(
        f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure"
    )
    # Persist migrated data immediately and check if sanitization was applied
    needs_reload = write_json(migrated_data, self._file_name)
    if needs_reload:
        # Data was sanitized during write: reload cleaned data
        logger.info(
            f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}"
        )
        cleaned_data = load_json(self._file_name)
        if cleaned_data is not None:
            return cleaned_data  # Return cleaned data to update shared memory
    return migrated_data
async def finalize(self):
    """Finalize storage resources.

    For namespaces ending in ``_cache`` this calls ``index_done_callback`` so
    that cache data is persisted before exiting; other namespaces do nothing.
    """
    if not self.namespace.endswith("_cache"):
        return
    await self.index_done_callback()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,430 @@
import asyncio
import base64
import os
import zlib
from typing import Any, final
from dataclasses import dataclass
import numpy as np
import time
from lightrag.utils import (
logger,
compute_mdhash_id,
)
from lightrag.base import BaseVectorStorage
from nano_vectordb import NanoVectorDB
from .shared_storage import (
get_namespace_lock,
get_update_flag,
set_all_update_flags,
)
@final
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self):
    """Validate configuration, resolve the storage file and create the client.

    Raises:
        ValueError: If ``cosine_better_than_threshold`` is missing from
            ``vector_db_storage_cls_kwargs`` in the global config.
    """
    self._validate_embedding_func()
    # Initialize basic attributes
    self._client = None
    self._storage_lock = None
    self.storage_updated = None

    # The similarity threshold has no default and must come from the config
    storage_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
    threshold = storage_kwargs.get("cosine_better_than_threshold")
    if threshold is None:
        raise ValueError(
            "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
        )
    self.cosine_better_than_threshold = threshold

    base_dir = self.global_config["working_dir"]
    if not self.workspace:
        # Default behavior when workspace is empty
        self.final_namespace = self.namespace
        self.workspace = ""
        target_dir = base_dir
    else:
        # Include workspace in the file path for data isolation
        target_dir = os.path.join(base_dir, self.workspace)
        self.final_namespace = f"{self.workspace}_{self.namespace}"
    os.makedirs(target_dir, exist_ok=True)

    self._client_file_name = os.path.join(target_dir, f"vdb_{self.namespace}.json")
    self._max_batch_size = self.global_config["embedding_batch_num"]
    self._client = NanoVectorDB(
        self.embedding_func.embedding_dim,
        storage_file=self._client_file_name,
    )
async def initialize(self):
    """Bind the update flag and the namespace lock for this storage."""
    scope = {"workspace": self.workspace}
    # Update flag used for cross-process update notification
    self.storage_updated = await get_update_flag(self.namespace, **scope)
    # Storage lock used by the other methods
    self._storage_lock = get_namespace_lock(self.namespace, **scope)
async def _get_client(self):
    """Return the vector DB client, reloading it first if the update flag is set.

    Runs under the storage lock to prevent concurrent read and write. When
    the flag is set, a fresh ``NanoVectorDB`` is created from the storage
    file and the flag is cleared.
    """
    async with self._storage_lock:
        if not self.storage_updated.value:
            return self._client
        logger.info(
            f"[{self.workspace}] Process {os.getpid()} reloading {self.namespace} due to update by another process"
        )
        # Reload data
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim,
            storage_file=self._client_file_name,
        )
        # Reset update flag
        self.storage_updated.value = False
        return self._client
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    """Embed the given records and upsert them into the vector DB client.

    Important notes:
    1. Changes will be persisted to disk during the next index_done_callback
    2. Only one process should be updating the storage at a time before
       index_done_callback; KG-storage-log should be used to avoid data
       corruption

    Args:
        data: Mapping of record id to record. Each record must contain
            ``"content"``; only keys listed in ``self.meta_fields`` are stored
            as metadata. An empty mapping is a no-op.

    Returns:
        The result of the client's ``upsert`` call, or ``None`` when the
        number of embeddings does not match the number of records (the
        mismatch is only logged).
    """
    if not data:
        return
    current_time = int(time.time())
    list_data = []
    for record_id, payload in data.items():
        entry = {"__id__": record_id, "__created_at__": current_time}
        entry.update(
            {field: value for field, value in payload.items() if field in self.meta_fields}
        )
        list_data.append(entry)

    contents = [payload["content"] for payload in data.values()]
    step = self._max_batch_size
    batches = [contents[start : start + step] for start in range(0, len(contents), step)]
    # Execute embedding outside of lock to avoid long lock times
    embeddings_list = await asyncio.gather(
        *[self.embedding_func(batch, context="document") for batch in batches]
    )
    embeddings = np.concatenate(embeddings_list)

    if len(embeddings) != len(list_data):
        # sometimes the embedding is not returned correctly. just log it.
        logger.error(
            f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
        )
        return

    for entry, vector in zip(list_data, embeddings):
        # Compress vector using Float16 + zlib + Base64 for storage optimization
        packed = zlib.compress(vector.astype(np.float16).tobytes())
        entry["vector"] = base64.b64encode(packed).decode("utf-8")
        entry["__vector__"] = vector
    client = await self._get_client()
    return client.upsert(datas=list_data)
async def query(
    self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
    """Search the vector DB for the records most similar to ``query``.

    Args:
        query: Query text; embedded only when ``query_embedding`` is None.
        top_k: Maximum number of results requested from the client.
        query_embedding: Optional precomputed embedding of the query.

    Returns:
        One dict per hit: the stored fields without ``"vector"``, plus
        ``"id"``, ``"distance"`` and ``"created_at"``.
    """
    if query_embedding is None:
        # Execute embedding outside of lock; higher priority for query
        embedded = await self.embedding_func(
            [query], context="query", _priority=5
        )
        embedding = embedded[0]
    else:
        # Use the provided embedding as-is
        embedding = query_embedding

    client = await self._get_client()
    hits = client.query(
        query=embedding,
        top_k=top_k,
        better_than_threshold=self.cosine_better_than_threshold,
    )
    formatted = []
    for hit in hits:
        row = {field: value for field, value in hit.items() if field != "vector"}
        row["id"] = hit["__id__"]
        row["distance"] = hit["__metrics__"]
        row["created_at"] = hit.get("__created_at__")
        formatted.append(row)
    return formatted
@property
async def client_storage(self):
    """Awaitable property yielding the client's private ``__storage`` attribute."""
    vdb = await self._get_client()
    # Reach the name-mangled private attribute of NanoVectorDB
    storage = getattr(vdb, "_NanoVectorDB__storage")
    return storage
async def delete(self, ids: list[str]):
    """Delete vectors with specified IDs.

    Important notes:
    1. Changes will be persisted to disk during the next index_done_callback
    2. Only one process should be updating the storage at a time before
       index_done_callback; KG-storage-log should be used to avoid data
       corruption

    Errors are logged and not re-raised.

    Args:
        ids: List of vector IDs to be deleted
    """
    try:
        client = await self._get_client()
        # Record count before deletion
        size_before = len(client)
        client.delete(ids)
        # Actual deleted count is the difference in client size
        deleted_count = size_before - len(client)
        logger.debug(
            f"[{self.workspace}] Successfully deleted {deleted_count} vectors from {self.namespace}"
        )
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
        )
async def delete_entity(self, entity_name: str) -> None:
    """Delete the vector of the entity named ``entity_name`` if it exists.

    The vector id is derived with ``compute_mdhash_id(entity_name,
    prefix="ent-")``. Errors are logged and not re-raised.

    Important notes:
    1. Changes will be persisted to disk during the next index_done_callback
    2. Only one process should be updating the storage at a time before
       index_done_callback; KG-storage-log should be used to avoid data
       corruption
    """
    try:
        entity_id = compute_mdhash_id(entity_name, prefix="ent-")
        logger.debug(
            f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
        )
        client = await self._get_client()
        # Check if the entity exists
        if not client.get([entity_id]):
            logger.debug(
                f"[{self.workspace}] Entity {entity_name} not found in storage"
            )
            return
        client.delete([entity_id])
        logger.debug(
            f"[{self.workspace}] Successfully deleted entity {entity_name}"
        )
    except Exception as e:
        logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
    """Delete every stored relation whose source or target is ``entity_name``.

    Scans the client's stored records and removes those whose ``src_id`` or
    ``tgt_id`` equals ``entity_name``. Errors are logged and not re-raised.

    Important notes:
    1. Changes will be persisted to disk during the next index_done_callback
    2. Only one process should be updating the storage at a time before
       index_done_callback; KG-storage-log should be used to avoid data
       corruption
    """
    try:
        client = await self._get_client()
        storage = getattr(client, "_NanoVectorDB__storage")
        relations = []
        for dp in storage["data"]:
            if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name:
                relations.append(dp)
        logger.debug(
            f"[{self.workspace}] Found {len(relations)} relations for entity {entity_name}"
        )
        ids_to_delete = [relation["__id__"] for relation in relations]
        if not ids_to_delete:
            logger.debug(
                f"[{self.workspace}] No relations found for entity {entity_name}"
            )
            return
        # The client is fetched again right before deleting
        client = await self._get_client()
        client.delete(ids_to_delete)
        logger.debug(
            f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
        )
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
        )
async def index_done_callback(self) -> bool:
    """Save data to disk.

    The whole operation runs under a single acquisition of the storage lock,
    so the "updated by another process" check and the save cannot be
    separated by another coroutine taking the lock in between.

    Returns:
        bool: True when the data was saved. False when the storage had been
        updated by another process (the client is reloaded from disk instead
        of being saved) or when saving failed (the error is logged).
    """
    async with self._storage_lock:
        # Check if storage was updated by another process
        if self.storage_updated.value:
            # Storage was updated by another process, reload data instead of saving
            logger.warning(
                f"[{self.workspace}] Storage for {self.namespace} was updated by another process, reloading..."
            )
            self._client = NanoVectorDB(
                self.embedding_func.embedding_dim,
                storage_file=self._client_file_name,
            )
            # Reset update flag
            self.storage_updated.value = False
            return False  # Return error

        try:
            # Save data to disk
            self._client.save()
            # Notify other processes that data has been updated
            await set_all_update_flags(self.namespace, workspace=self.workspace)
            # Reset own update flag to avoid self-reloading
            self.storage_updated.value = False
            return True  # Return success
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error saving data for {self.namespace}: {e}"
            )
            return False  # Return error
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID
Args:
id: The unique identifier of the vector
Returns:
The vector data if found, or None if not found
"""
client = await self._get_client()
result = client.get([id])
if result:
dp = result[0]
return {
**{k: v for k, v in dp.items() if k != "vector"},
"id": dp.get("__id__"),
"created_at": dp.get("__created_at__"),
}
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Get multiple vector data by their IDs.

    Args:
        ids: List of unique identifiers

    Returns:
        A list aligned with ``ids``: each position holds the stored fields
        without ``"vector"`` plus ``"id"`` and ``"created_at"``, or ``None``
        when no record was returned for that id. Empty ``ids`` yields ``[]``.
    """
    if not ids:
        return []
    client = await self._get_client()
    by_id: dict[str, dict[str, Any]] = {}
    for dp in client.get(ids):
        if not dp:
            continue
        record = {field: value for field, value in dp.items() if field != "vector"}
        record["id"] = dp.get("__id__")
        record["created_at"] = dp.get("__created_at__")
        # Records without an id cannot be matched back to a request
        if record["id"] is not None:
            by_id[str(record["id"])] = record
    return [by_id.get(str(requested_id)) for requested_id in ids]
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
    """Get vectors by their IDs, returning only ID and vector data.

    Args:
        ids: List of unique identifiers

    Returns:
        Dictionary mapping IDs to their vector embeddings
        Format: {id: [vector_values], ...}
        Records lacking ``"vector"`` or ``"__id__"`` are left out.
    """
    if not ids:
        return {}
    client = await self._get_client()
    vectors_dict = {}
    for record in client.get(ids):
        if not (record and "vector" in record and "__id__" in record):
            continue
        # Reverse the storage encoding: Base64 -> zlib -> float16 buffer
        raw_bytes = zlib.decompress(base64.b64decode(record["vector"]))
        as_f16 = np.frombuffer(raw_bytes, dtype=np.float16)
        vectors_dict[record["__id__"]] = as_f16.astype(np.float32).tolist()
    return vectors_dict
async def drop(self) -> dict[str, str]:
    """Drop all vector data from storage and clean up resources.

    Intended for scenarios where all data needs to be removed. Under the
    storage lock this method will:
    1. Remove the vector database storage file if it exists
    2. Reinitialize the vector database client
    3. Raise the update flags to notify other processes, then clear its own

    Returns:
        dict[str, str]: Operation status and message
        - On success: {"status": "success", "message": "data dropped"}
        - On failure: {"status": "error", "message": "<error details>"}
    """
    outcome = {"status": "success", "message": "data dropped"}
    try:
        async with self._storage_lock:
            storage_path = self._client_file_name
            if os.path.exists(storage_path):
                os.remove(storage_path)
            self._client = NanoVectorDB(
                self.embedding_func.embedding_dim,
                storage_file=self._client_file_name,
            )
            # Notify other processes that data has been updated
            await set_all_update_flags(self.namespace, workspace=self.workspace)
            # Reset own update flag to avoid self-reloading
            self.storage_updated.value = False
            logger.info(
                f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
            )
    except Exception as e:
        logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
        outcome = {"status": "error", "message": str(e)}
    return outcome

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,606 @@
import os
from collections import deque
from dataclasses import dataclass
from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
import networkx as nx
from .shared_storage import (
get_namespace_lock,
get_update_flag,
set_all_update_flags,
)
from dotenv import load_dotenv
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
@final
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name) -> nx.Graph | None:
    """Read a graph from a GraphML file, if the file exists.

    Args:
        file_name: Path of the GraphML file to read.

    Returns:
        The graph produced by ``nx.read_graphml``, or ``None`` when nothing
        exists at ``file_name``. Callers must handle the ``None`` case.
    """
    if not os.path.exists(file_name):
        return None
    return nx.read_graphml(file_name)
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name, workspace="_"):
    """Write ``graph`` to ``file_name`` in GraphML format.

    Args:
        graph: Graph to serialize.
        file_name: Destination path of the GraphML file.
        workspace: Label used only as the prefix of the log message.
    """
    summary = f"[{workspace}] Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
    logger.info(summary)
    nx.write_graphml(graph, file_name)
def __post_init__(self):
    """Resolve the GraphML file location and load any existing graph.

    The file lives in ``<working_dir>/<workspace>`` when a workspace is set,
    otherwise directly in ``<working_dir>``; the directory is created if it
    is missing. The lock and update flag are left as ``None`` here and are
    filled in by ``initialize()``.
    """
    base_dir = self.global_config["working_dir"]
    if not self.workspace:
        # No workspace: store under working_dir and normalise to "".
        workspace_dir = base_dir
        self.workspace = ""
    else:
        # A workspace gets its own sub-directory for data isolation.
        workspace_dir = os.path.join(base_dir, self.workspace)

    os.makedirs(workspace_dir, exist_ok=True)
    graph_file_name = f"graph_{self.namespace}.graphml"
    self._graphml_xml_file = os.path.join(workspace_dir, graph_file_name)

    self._storage_lock = None
    self.storage_updated = None
    self._graph = None

    # Try to pick up a graph persisted by an earlier run.
    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
    if preloaded_graph is None:
        logger.info(
            f"[{self.workspace}] Created new empty graph file: {self._graphml_xml_file}"
        )
    else:
        logger.info(
            f"[{self.workspace}] Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
        )
    # A falsy result (None or an empty graph) is replaced by a new nx.Graph.
    self._graph = preloaded_graph or nx.Graph()
async def initialize(self):
    """Attach the shared update flag and the namespace lock to this storage.

    Both are looked up by ``(namespace, workspace)``. The flag is used for
    cross-process update notification; the lock guards the other methods.
    """
    scope = {"workspace": self.workspace}
    # Flag that signals when the storage was updated.
    self.storage_updated = await get_update_flag(self.namespace, **scope)
    # Lock used by the methods that read or write the graph.
    self._storage_lock = get_namespace_lock(self.namespace, **scope)
async def _get_graph(self):
    """Return the in-memory graph, reloading it from disk first if flagged.

    Under the storage lock, a set update flag causes the graph to be re-read
    from the GraphML file (falling back to an empty ``nx.Graph``) and the
    flag to be cleared before the graph is returned.
    """
    # Hold the lock so that reads and writes cannot interleave.
    async with self._storage_lock:
        if not self.storage_updated.value:
            return self._graph

        logger.info(
            f"[{self.workspace}] Process {os.getpid()} reloading graph {self._graphml_xml_file} due to modifications by another process"
        )
        reloaded = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        self._graph = reloaded or nx.Graph()
        # The reload has been honoured; clear the flag.
        self.storage_updated.value = False
        return self._graph
async def has_node(self, node_id: str) -> bool:
    """Report whether ``node_id`` is present in the graph."""
    return (await self._get_graph()).has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Report whether an edge joins ``source_node_id`` and ``target_node_id``."""
    return (await self._get_graph()).has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> dict[str, str] | None:
graph = await self._get_graph()
return graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
    """Return the degree of ``node_id``; 0 when the node is not in the graph."""
    graph = await self._get_graph()
    return graph.degree(node_id) if graph.has_node(node_id) else 0
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return the sum of the degrees of both endpoints.

    An endpoint that is not in the graph contributes 0.
    """
    graph = await self._get_graph()
    return sum(
        graph.degree(endpoint)
        for endpoint in (src_id, tgt_id)
        if graph.has_node(endpoint)
    )
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
graph = await self._get_graph()
return graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
graph = await self._get_graph()
if graph.has_node(source_node_id):
return list(graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    """Add ``node_id`` to the in-memory graph or update its attributes.

    Important notes:
        1. Changes are persisted to disk during the next index_done_callback.
        2. Only one process should be updating the storage at a time before
           index_done_callback; the KG-storage-log should be used to avoid
           data corruption.
    """
    (await self._get_graph()).add_node(node_id, **node_data)
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
    """Add an edge to the in-memory graph or update its attributes.

    Important notes:
        1. Changes are persisted to disk during the next index_done_callback.
        2. Only one process should be updating the storage at a time before
           index_done_callback; the KG-storage-log should be used to avoid
           data corruption.
    """
    (await self._get_graph()).add_edge(source_node_id, target_node_id, **edge_data)
async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
    """Insert or update several nodes with a single graph lookup.

    The graph is fetched once and every node is added to it, instead of
    going through ``upsert_node()`` (and its await) once per node.

    Args:
        nodes: List of (node_id, node_data) tuples.
    """
    target = await self._get_graph()
    for entry in nodes:
        identifier, attributes = entry
        target.add_node(identifier, **attributes)
async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
    """Check which of several nodes exist, using a single graph lookup.

    Args:
        node_ids: Node IDs to test.

    Returns:
        Set of the given node_ids that exist in the graph.
    """
    graph = await self._get_graph()
    return set(filter(graph.has_node, node_ids))
async def upsert_edges_batch(
    self, edges: list[tuple[str, str, dict[str, str]]]
) -> None:
    """Insert or update several edges with a single graph lookup.

    Args:
        edges: List of (source_id, target_id, edge_data) tuples.
    """
    target = await self._get_graph()
    for entry in edges:
        source_id, target_id, attributes = entry
        target.add_edge(source_id, target_id, **attributes)
async def delete_node(self, node_id: str) -> None:
    """Remove ``node_id`` from the in-memory graph, warning if it is absent.

    Important notes:
        1. Changes are persisted to disk during the next index_done_callback.
        2. Only one process should be updating the storage at a time before
           index_done_callback; the KG-storage-log should be used to avoid
           data corruption.
    """
    graph = await self._get_graph()
    if not graph.has_node(node_id):
        logger.warning(
            f"[{self.workspace}] Node {node_id} not found in the graph for deletion"
        )
        return
    graph.remove_node(node_id)
    logger.debug(f"[{self.workspace}] Node {node_id} deleted from the graph")
async def remove_nodes(self, nodes: list[str]):
    """Delete multiple nodes, silently skipping IDs that are not in the graph.

    Important notes:
        1. Changes are persisted to disk during the next index_done_callback.
        2. Only one process should be updating the storage at a time before
           index_done_callback; the KG-storage-log should be used to avoid
           data corruption.

    Args:
        nodes: List of node IDs to be deleted.
    """
    graph = await self._get_graph()
    # filter() is lazy, so each membership test runs right before its removal.
    for existing_id in filter(graph.has_node, nodes):
        graph.remove_node(existing_id)
async def remove_edges(self, edges: list[tuple[str, str]]):
    """Delete multiple edges, silently skipping edges that are not in the graph.

    Important notes:
        1. Changes are persisted to disk during the next index_done_callback.
        2. Only one process should be updating the storage at a time before
           index_done_callback; the KG-storage-log should be used to avoid
           data corruption.

    Args:
        edges: List of edges to be deleted, each edge is a (source, target) tuple.
    """
    graph = await self._get_graph()
    for pair in edges:
        source, target = pair
        if not graph.has_edge(source, target):
            continue
        graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
    """Return every node label (entity name) in the graph.

    Node IDs are converted with ``str`` and de-duplicated.

    Returns:
        [label1, label2, ...]  # Alphabetically sorted label list
    """
    graph = await self._get_graph()
    return sorted({str(node) for node in graph.nodes()})
async def get_popular_labels(self, limit: int = 300) -> list[str]:
    """Return the labels (entity names) of the most connected nodes.

    Args:
        limit: Maximum number of labels to return.

    Returns:
        Labels ordered by node degree, highest first. Nodes with equal
        degree keep the order in which the graph reports them.
    """
    graph = await self._get_graph()

    degree_by_node = dict(graph.degree())
    # A stable sort on the keys gives the same order as sorting the items.
    ranked_nodes = sorted(degree_by_node, key=degree_by_node.get, reverse=True)

    popular_labels = [str(node) for node in ranked_nodes[:limit]]
    logger.debug(
        f"[{self.workspace}] Retrieved {len(popular_labels)} popular labels (limit: {limit})"
    )
    return popular_labels
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
    """Search labels (entity names) by case-insensitive substring match.

    Scoring of a matching label:
        - 1000 when it equals the query,
        - 500 when it starts with the query,
        - otherwise ``100 - len(label)``, plus 50 when the query appears
          right after a space or an underscore.

    Args:
        query: Search query string; surrounding whitespace is ignored.
        limit: Maximum number of results to return.

    Returns:
        Matching labels ordered by score (highest first), then
        alphabetically. An empty list when the query is blank.
    """
    graph = await self._get_graph()
    query_lower = query.lower().strip()
    if not query_lower:
        return []

    # Forms of the query that indicate a word-boundary hit inside a label.
    boundary_forms = (f" {query_lower}", f"_{query_lower}")

    scored = []
    for node in graph.nodes():
        node_str = str(node)
        node_lower = node_str.lower()
        if query_lower not in node_lower:
            continue

        if node_lower == query_lower:
            score = 1000
        elif node_lower.startswith(query_lower):
            score = 500
        else:
            # Shorter labels rank higher among plain "contains" matches.
            score = 100 - len(node_str)
            if any(form in node_lower for form in boundary_forms):
                score += 50
        scored.append((node_str, score))

    scored.sort(key=lambda pair: (-pair[1], pair[0]))
    search_results = [label for label, _ in scored[:limit]]
    logger.debug(
        f"[{self.workspace}] Search query '{query}' returned {len(search_results)} results (limit: {limit})"
    )
    return search_results
async def get_knowledge_graph(
    self,
    node_label: str,
    max_depth: int = 3,
    max_nodes: int | None = None,
) -> KnowledgeGraph:
    """Retrieve a subgraph around ``node_label`` as a ``KnowledgeGraph``.

    Args:
        node_label: ID of the starting node; ``"*"`` selects the
            highest-degree nodes of the whole graph instead.
        max_depth: Maximum BFS depth from the starting node. Defaults to 3.
            Not used for ``"*"``.
        max_nodes: Maximum number of nodes to return. When omitted, the
            ``max_graph_nodes`` entry of ``global_config`` is used (1000 if
            that entry is missing); an explicit value is capped at that
            same limit.

    Returns:
        KnowledgeGraph containing the selected nodes and the edges between
        them. ``is_truncated`` is set when nodes were left out because of
        the ``max_nodes`` limit. Stopping only because of ``max_depth`` is
        logged but does not set the flag. An empty KnowledgeGraph is
        returned when ``node_label`` is not in the graph.
    """
    node_cap = self.global_config.get("max_graph_nodes", 1000)
    max_nodes = node_cap if max_nodes is None else min(max_nodes, node_cap)

    graph = await self._get_graph()
    result = KnowledgeGraph()

    if node_label == "*":
        # Whole-graph view: keep the max_nodes best-connected nodes.
        degrees = dict(graph.degree())
        sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)

        if len(sorted_nodes) > max_nodes:
            result.is_truncated = True
            logger.info(
                f"[{self.workspace}] Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
            )

        limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
        subgraph = graph.subgraph(limited_nodes)
    else:
        if node_label not in graph:
            logger.warning(
                f"[{self.workspace}] Node {node_label} not found in the graph"
            )
            return KnowledgeGraph()  # Return empty graph

        # Level-by-level BFS; within a level, higher-degree nodes go first.
        bfs_nodes = []
        visited = set()
        # Queue entries are (node, depth, degree).
        queue = deque([(node_label, 0, graph.degree(node_label))])
        # Set when a node at max_depth still has unvisited neighbors.
        has_unexplored_neighbors = False
        # Set when the max_nodes cut-off leaves behind nodes of the current
        # level that were already dequeued. Checking `queue` alone misses
        # them, e.g. for a star graph whose leaves all sit on one level.
        dropped_by_node_limit = False

        while queue and len(bfs_nodes) < max_nodes:
            current_depth = queue[0][1]

            # Pull every queued node that belongs to the current depth.
            current_level_nodes = []
            while queue and queue[0][1] == current_depth:
                current_level_nodes.append(queue.popleft())

            current_level_nodes.sort(key=lambda x: x[2], reverse=True)

            for position, (current_node, depth, _degree) in enumerate(
                current_level_nodes
            ):
                if current_node in visited:
                    continue
                visited.add(current_node)
                bfs_nodes.append(current_node)

                unvisited_neighbors = [
                    n for n in graph.neighbors(current_node) if n not in visited
                ]
                if depth < max_depth:
                    # Schedule the neighbors for the next level.
                    for neighbor in unvisited_neighbors:
                        queue.append((neighbor, depth + 1, graph.degree(neighbor)))
                elif unvisited_neighbors:
                    # Neighbors exist but lie beyond the depth limit.
                    has_unexplored_neighbors = True

                if len(bfs_nodes) >= max_nodes:
                    remaining = current_level_nodes[position + 1 :]
                    dropped_by_node_limit = any(
                        pending not in visited for pending, _, _ in remaining
                    )
                    break

        reached_node_limit = len(bfs_nodes) >= max_nodes
        if (
            reached_node_limit and (queue or dropped_by_node_limit)
        ) or has_unexplored_neighbors:
            if reached_node_limit:
                result.is_truncated = True
                logger.info(
                    f"[{self.workspace}] Graph truncated: max_nodes limit {max_nodes} reached"
                )
            else:
                logger.info(
                    f"[{self.workspace}] Graph truncated: found {len(bfs_nodes)} nodes within max_depth {max_depth}"
                )

        subgraph = graph.subgraph(bfs_nodes)

    # Convert nodes; IDs are compared as strings to avoid duplicates.
    seen_nodes = set()
    seen_edges = set()
    for node in subgraph.nodes():
        node_id = str(node)
        if node_id in seen_nodes:
            continue
        result.nodes.append(
            KnowledgeGraphNode(
                id=node_id,
                labels=[node_id],
                properties=dict(subgraph.nodes[node]),
            )
        )
        seen_nodes.add(node_id)

    # Convert edges.
    for edge in subgraph.edges():
        source, target = edge
        # Order the endpoints so an undirected edge gets one unique id.
        if str(source) > str(target):
            source, target = target, source
        edge_id = f"{source}-{target}"
        if edge_id in seen_edges:
            continue
        edge_data = dict(subgraph.edges[edge])
        # NOTE(review): the type is reported as "DIRECTED" although the
        # endpoints were just reordered as for an undirected edge - confirm
        # that consumers expect this value.
        result.edges.append(
            KnowledgeGraphEdge(
                id=edge_id,
                type="DIRECTED",
                source=str(source),
                target=str(target),
                properties=edge_data,
            )
        )
        seen_edges.add(edge_id)

    logger.info(
        f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
    )
    return result
async def get_all_nodes(self) -> list[dict]:
    """Get all nodes in the graph.

    Returns:
        One dictionary per node: a copy of the node's properties with the
        node ID stored under the ``"id"`` key.
    """
    graph = await self._get_graph()
    return [
        {**properties, "id": node_id}
        for node_id, properties in graph.nodes(data=True)
    ]
async def get_all_edges(self) -> list[dict]:
    """Get all edges in the graph.

    Returns:
        One dictionary per edge: a copy of the edge's properties with the
        endpoints stored under the ``"source"`` and ``"target"`` keys.
    """
    graph = await self._get_graph()
    return [
        {**properties, "source": start, "target": end}
        for start, end, properties in graph.edges(data=True)
    ]
async def index_done_callback(self) -> bool:
    """Save the in-memory graph to disk unless another process changed it.

    Returns:
        True when the graph was written to disk. False when the storage had
        been updated by another process (the graph is reloaded from disk
        instead of being saved), or when writing failed.
    """
    # The staleness check and the write share one critical section. Releasing
    # the lock in between would let another process update the file after
    # the check and have its changes overwritten by the write below.
    async with self._storage_lock:
        if self.storage_updated.value:
            # Another process wrote first: reload instead of saving.
            logger.info(
                f"[{self.workspace}] Graph was updated by another process, reloading..."
            )
            self._graph = (
                NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
            )
            self.storage_updated.value = False
            return False

        try:
            NetworkXStorage.write_nx_graph(
                self._graph, self._graphml_xml_file, self.workspace
            )
            # Notify other processes that data has been updated.
            await set_all_update_flags(self.namespace, workspace=self.workspace)
            # Reset own update flag to avoid self-reloading.
            self.storage_updated.value = False
            return True
        except Exception as e:
            # Best effort: report the failure to the caller via the result.
            logger.error(f"[{self.workspace}] Error saving graph: {e}")
            return False
async def drop(self) -> dict[str, str]:
    """Remove all graph data and start again from an empty graph.

    While holding the storage lock this method:
        1. Deletes the GraphML file if one exists.
        2. Replaces the in-memory graph with an empty ``nx.Graph``.
        3. Raises the update flags so other processes notice the change.
        4. Clears this instance's own flag so it does not reload itself.

    Returns:
        dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
    """
    try:
        async with self._storage_lock:
            # Remove the persisted GraphML file, when there is one.
            if os.path.exists(self._graphml_xml_file):
                os.remove(self._graphml_xml_file)
            self._graph = nx.Graph()

            # Tell the other processes that the data changed ...
            await set_all_update_flags(self.namespace, workspace=self.workspace)
            # ... but do not trigger a reload in this one.
            self.storage_updated.value = False

            logger.info(
                f"[{self.workspace}] Process {os.getpid()} drop graph file:{self._graphml_xml_file}"
            )
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error dropping graph file:{self._graphml_xml_file}: {e}"
        )
        return {"status": "error", "message": str(e)}
    else:
        return {"status": "success", "message": "data dropped"}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff