feat: 重构知识库系统,移除Hermes集成,增强RAG和同步功能
主要变更: - 移除Hermes智能体及相关回调服务 - 新增知识库RAG、同步、调度、规范化和索引任务服务 - 重构orchestrator服务,增强运行时聊天功能 - 更新前端聊天、政策制度、设置等页面样式和逻辑 - 更新expense_claims和document_intelligence服务 - 删除llm_wiki相关服务和测试文件 - 更新docker-compose配置和启动脚本
This commit is contained in:
161
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/__init__.py
Normal file
161
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/__init__.py
Normal file
@@ -0,0 +1,161 @@
|
||||
STORAGE_IMPLEMENTATIONS = {
|
||||
"KV_STORAGE": {
|
||||
"implementations": [
|
||||
"JsonKVStorage",
|
||||
"RedisKVStorage",
|
||||
"PGKVStorage",
|
||||
"MongoKVStorage",
|
||||
"OpenSearchKVStorage",
|
||||
],
|
||||
"required_methods": ["get_by_id", "upsert"],
|
||||
},
|
||||
"GRAPH_STORAGE": {
|
||||
"implementations": [
|
||||
"NetworkXStorage",
|
||||
"Neo4JStorage",
|
||||
"PGGraphStorage",
|
||||
"MongoGraphStorage",
|
||||
"MemgraphStorage",
|
||||
"OpenSearchGraphStorage",
|
||||
],
|
||||
"required_methods": ["upsert_node", "upsert_edge"],
|
||||
},
|
||||
"VECTOR_STORAGE": {
|
||||
"implementations": [
|
||||
"NanoVectorDBStorage",
|
||||
"MilvusVectorDBStorage",
|
||||
"PGVectorStorage",
|
||||
"FaissVectorDBStorage",
|
||||
"QdrantVectorDBStorage",
|
||||
"MongoVectorDBStorage",
|
||||
"OpenSearchVectorDBStorage",
|
||||
# "ChromaVectorDBStorage",
|
||||
],
|
||||
"required_methods": ["query", "upsert"],
|
||||
},
|
||||
"DOC_STATUS_STORAGE": {
|
||||
"implementations": [
|
||||
"JsonDocStatusStorage",
|
||||
"RedisDocStatusStorage",
|
||||
"PGDocStatusStorage",
|
||||
"MongoDocStatusStorage",
|
||||
"OpenSearchDocStatusStorage",
|
||||
],
|
||||
"required_methods": ["get_docs_by_status"],
|
||||
},
|
||||
}
|
||||
|
||||
# Storage implementation environment variable without default value
|
||||
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
|
||||
# KV Storage Implementations
|
||||
"JsonKVStorage": [],
|
||||
"MongoKVStorage": [
|
||||
"MONGO_URI",
|
||||
"MONGO_DATABASE",
|
||||
],
|
||||
"RedisKVStorage": ["REDIS_URI"],
|
||||
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
# Graph Storage Implementations
|
||||
"NetworkXStorage": [],
|
||||
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
|
||||
"MongoGraphStorage": [
|
||||
"MONGO_URI",
|
||||
"MONGO_DATABASE",
|
||||
],
|
||||
"MemgraphStorage": ["MEMGRAPH_URI"],
|
||||
"AGEStorage": [
|
||||
"AGE_POSTGRES_DB",
|
||||
"AGE_POSTGRES_USER",
|
||||
"AGE_POSTGRES_PASSWORD",
|
||||
],
|
||||
"PGGraphStorage": [
|
||||
"POSTGRES_USER",
|
||||
"POSTGRES_PASSWORD",
|
||||
"POSTGRES_DATABASE",
|
||||
],
|
||||
# Vector Storage Implementations
|
||||
"NanoVectorDBStorage": [],
|
||||
"MilvusVectorDBStorage": [
|
||||
"MILVUS_URI",
|
||||
"MILVUS_DB_NAME",
|
||||
],
|
||||
# "ChromaVectorDBStorage": [],
|
||||
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"FaissVectorDBStorage": [],
|
||||
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
|
||||
"MongoVectorDBStorage": [
|
||||
"MONGO_URI",
|
||||
"MONGO_DATABASE",
|
||||
],
|
||||
# Document Status Storage Implementations
|
||||
"JsonDocStatusStorage": [],
|
||||
"RedisDocStatusStorage": ["REDIS_URI"],
|
||||
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"MongoDocStatusStorage": [
|
||||
"MONGO_URI",
|
||||
"MONGO_DATABASE",
|
||||
],
|
||||
# OpenSearch Storage Implementations
|
||||
"OpenSearchKVStorage": [
|
||||
"OPENSEARCH_HOSTS",
|
||||
],
|
||||
"OpenSearchDocStatusStorage": [
|
||||
"OPENSEARCH_HOSTS",
|
||||
],
|
||||
"OpenSearchGraphStorage": [
|
||||
"OPENSEARCH_HOSTS",
|
||||
],
|
||||
"OpenSearchVectorDBStorage": [
|
||||
"OPENSEARCH_HOSTS",
|
||||
],
|
||||
}
|
||||
|
||||
# Storage implementation module mapping
|
||||
STORAGES = {
|
||||
"NetworkXStorage": ".kg.networkx_impl",
|
||||
"JsonKVStorage": ".kg.json_kv_impl",
|
||||
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
|
||||
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
|
||||
"Neo4JStorage": ".kg.neo4j_impl",
|
||||
"MilvusVectorDBStorage": ".kg.milvus_impl",
|
||||
"MongoKVStorage": ".kg.mongo_impl",
|
||||
"MongoDocStatusStorage": ".kg.mongo_impl",
|
||||
"MongoGraphStorage": ".kg.mongo_impl",
|
||||
"MongoVectorDBStorage": ".kg.mongo_impl",
|
||||
"RedisKVStorage": ".kg.redis_impl",
|
||||
"RedisDocStatusStorage": ".kg.redis_impl",
|
||||
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
||||
"PGKVStorage": ".kg.postgres_impl",
|
||||
"PGVectorStorage": ".kg.postgres_impl",
|
||||
"AGEStorage": ".kg.age_impl",
|
||||
"PGGraphStorage": ".kg.postgres_impl",
|
||||
"PGDocStatusStorage": ".kg.postgres_impl",
|
||||
"FaissVectorDBStorage": ".kg.faiss_impl",
|
||||
"QdrantVectorDBStorage": ".kg.qdrant_impl",
|
||||
"MemgraphStorage": ".kg.memgraph_impl",
|
||||
"OpenSearchKVStorage": ".kg.opensearch_impl",
|
||||
"OpenSearchDocStatusStorage": ".kg.opensearch_impl",
|
||||
"OpenSearchGraphStorage": ".kg.opensearch_impl",
|
||||
"OpenSearchVectorDBStorage": ".kg.opensearch_impl",
|
||||
}
|
||||
|
||||
|
||||
def verify_storage_implementation(storage_type: str, storage_name: str) -> None:
|
||||
"""Verify if storage implementation is compatible with specified storage type
|
||||
|
||||
Args:
|
||||
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
|
||||
storage_name: Storage implementation name
|
||||
|
||||
Raises:
|
||||
ValueError: If storage implementation is incompatible or missing required methods
|
||||
"""
|
||||
if storage_type not in STORAGE_IMPLEMENTATIONS:
|
||||
raise ValueError(f"Unknown storage type: {storage_type}")
|
||||
|
||||
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
|
||||
if storage_name not in storage_info["implementations"]:
|
||||
raise ValueError(
|
||||
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
|
||||
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
|
||||
)
|
||||
@@ -0,0 +1,343 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
import numpy as np
|
||||
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from lightrag.utils import logger
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("chromadb"):
|
||||
pm.install("chromadb")
|
||||
|
||||
from chromadb import HttpClient, PersistentClient # type: ignore
|
||||
from chromadb.config import Settings # type: ignore
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
"""ChromaDB vector storage implementation."""
|
||||
|
||||
def __post_init__(self):
|
||||
self._validate_embedding_func()
|
||||
try:
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
user_collection_settings = config.get("collection_settings", {})
|
||||
# Default HNSW index settings for ChromaDB
|
||||
default_collection_settings = {
|
||||
# Distance metric used for similarity search (cosine similarity)
|
||||
"hnsw:space": "cosine",
|
||||
# Number of nearest neighbors to explore during index construction
|
||||
# Higher values = better recall but slower indexing
|
||||
"hnsw:construction_ef": 128,
|
||||
# Number of nearest neighbors to explore during search
|
||||
# Higher values = better recall but slower search
|
||||
"hnsw:search_ef": 128,
|
||||
# Number of connections per node in the HNSW graph
|
||||
# Higher values = better recall but more memory usage
|
||||
"hnsw:M": 16,
|
||||
# Number of vectors to process in one batch during indexing
|
||||
"hnsw:batch_size": 100,
|
||||
# Number of updates before forcing index synchronization
|
||||
# Lower values = more frequent syncs but slower indexing
|
||||
"hnsw:sync_threshold": 1000,
|
||||
}
|
||||
collection_settings = {
|
||||
**default_collection_settings,
|
||||
**user_collection_settings,
|
||||
}
|
||||
|
||||
local_path = config.get("local_path", None)
|
||||
if local_path:
|
||||
self._client = PersistentClient(
|
||||
path=local_path,
|
||||
settings=Settings(
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
auth_provider = config.get(
|
||||
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
|
||||
)
|
||||
auth_credentials = config.get("auth_token", "secret-token")
|
||||
headers = {}
|
||||
|
||||
if "token_authn" in auth_provider:
|
||||
headers = {
|
||||
config.get(
|
||||
"auth_header_name", "X-Chroma-Token"
|
||||
): auth_credentials
|
||||
}
|
||||
elif "basic_authn" in auth_provider:
|
||||
auth_credentials = config.get("auth_credentials", "admin:admin")
|
||||
|
||||
self._client = HttpClient(
|
||||
host=config.get("host", "localhost"),
|
||||
port=config.get("port", 8000),
|
||||
headers=headers,
|
||||
settings=Settings(
|
||||
chroma_api_impl="rest",
|
||||
chroma_client_auth_provider=auth_provider,
|
||||
chroma_client_auth_credentials=auth_credentials,
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.namespace,
|
||||
metadata={
|
||||
**collection_settings,
|
||||
"dimension": self.embedding_func.embedding_dim,
|
||||
},
|
||||
)
|
||||
# Use batch size from collection settings if specified
|
||||
self._max_batch_size = self.global_config.get(
|
||||
"embedding_batch_num", collection_settings.get("hnsw:batch_size", 32)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"ChromaDB initialization failed: {str(e)}")
|
||||
raise
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
return
|
||||
|
||||
try:
|
||||
import time
|
||||
|
||||
current_time = int(time.time())
|
||||
|
||||
ids = list(data.keys())
|
||||
documents = [v["content"] for v in data.values()]
|
||||
metadatas = [
|
||||
{
|
||||
**{k: v for k, v in item.items() if k in self.meta_fields},
|
||||
"created_at": current_time,
|
||||
}
|
||||
or {"_default": "true", "created_at": current_time}
|
||||
for item in data.values()
|
||||
]
|
||||
|
||||
# Process in batches
|
||||
batches = [
|
||||
documents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(documents), self._max_batch_size)
|
||||
]
|
||||
|
||||
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||
embeddings_list = []
|
||||
|
||||
# Pre-allocate embeddings_list with known size
|
||||
embeddings_list = [None] * len(embedding_tasks)
|
||||
|
||||
# Use asyncio.gather instead of as_completed if order doesn't matter
|
||||
embeddings_results = await asyncio.gather(*embedding_tasks)
|
||||
embeddings_list = list(embeddings_results)
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
|
||||
# Upsert in batches
|
||||
for i in range(0, len(ids), self._max_batch_size):
|
||||
batch_slice = slice(i, i + self._max_batch_size)
|
||||
|
||||
self._collection.upsert(
|
||||
ids=ids[batch_slice],
|
||||
embeddings=embeddings[batch_slice].tolist(),
|
||||
documents=documents[batch_slice],
|
||||
metadatas=metadatas[batch_slice],
|
||||
)
|
||||
|
||||
return ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ChromaDB upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
try:
|
||||
embedding = await self.embedding_func(
|
||||
[query], _priority=5
|
||||
) # higher priority for query
|
||||
|
||||
results = self._collection.query(
|
||||
query_embeddings=embedding.tolist()
|
||||
if not isinstance(embedding, list)
|
||||
else embedding,
|
||||
n_results=top_k * 2, # Request more results to allow for filtering
|
||||
include=["metadatas", "distances", "documents"],
|
||||
)
|
||||
|
||||
# Filter results by cosine similarity threshold and take top k
|
||||
# We request 2x results initially to have enough after filtering
|
||||
# ChromaDB returns cosine similarity (1 = identical, 0 = orthogonal)
|
||||
# We convert to distance (0 = identical, 1 = orthogonal) via (1 - similarity)
|
||||
# Only keep results with distance below threshold, then take top k
|
||||
return [
|
||||
{
|
||||
"id": results["ids"][0][i],
|
||||
"distance": 1 - results["distances"][0][i],
|
||||
"content": results["documents"][0][i],
|
||||
"created_at": results["metadatas"][0][i].get("created_at"),
|
||||
**results["metadatas"][0][i],
|
||||
}
|
||||
for i in range(len(results["ids"][0]))
|
||||
if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold
|
||||
][:top_k]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ChromaDB query: {str(e)}")
|
||||
raise
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# ChromaDB handles persistence automatically
|
||||
pass
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete an entity by its ID.
|
||||
|
||||
Args:
|
||||
entity_name: The ID of the entity to delete
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Deleting entity with ID {entity_name} from {self.namespace}")
|
||||
self._collection.delete(ids=[entity_name])
|
||||
except Exception as e:
|
||||
logger.error(f"Error during entity deletion: {str(e)}")
|
||||
raise
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete an entity and its relations by ID.
|
||||
In vector DB context, this is equivalent to delete_entity.
|
||||
|
||||
Args:
|
||||
entity_name: The ID of the entity to delete
|
||||
"""
|
||||
await self.delete_entity(entity_name)
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete vectors with specified IDs
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
try:
|
||||
self._collection.delete(ids=ids)
|
||||
logger.debug(
|
||||
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Query the collection for a single vector by ID
|
||||
result = self._collection.get(
|
||||
ids=[id], include=["metadatas", "embeddings", "documents"]
|
||||
)
|
||||
|
||||
if not result or not result["ids"] or len(result["ids"]) == 0:
|
||||
return None
|
||||
|
||||
# Format the result to match the expected structure
|
||||
return {
|
||||
"id": result["ids"][0],
|
||||
"vector": result["embeddings"][0],
|
||||
"content": result["documents"][0],
|
||||
"created_at": result["metadatas"][0].get("created_at"),
|
||||
**result["metadatas"][0],
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Query the collection for multiple vectors by IDs
|
||||
result = self._collection.get(
|
||||
ids=ids, include=["metadatas", "embeddings", "documents"]
|
||||
)
|
||||
|
||||
if not result or not result["ids"] or len(result["ids"]) == 0:
|
||||
return []
|
||||
|
||||
# Format the results to match the expected structure and preserve ordering
|
||||
formatted_map: dict[str, dict[str, Any]] = {}
|
||||
for i, result_id in enumerate(result["ids"]):
|
||||
record = {
|
||||
"id": result_id,
|
||||
"vector": result["embeddings"][i],
|
||||
"content": result["documents"][i],
|
||||
"created_at": result["metadatas"][i].get("created_at"),
|
||||
**result["metadatas"][i],
|
||||
}
|
||||
formatted_map[str(result_id)] = record
|
||||
|
||||
ordered_results: list[dict[str, Any] | None] = []
|
||||
for requested_id in ids:
|
||||
ordered_results.append(formatted_map.get(str(requested_id)))
|
||||
|
||||
return ordered_results
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will delete all documents from the ChromaDB collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
# Get all IDs in the collection
|
||||
result = self._collection.get(include=[])
|
||||
if result and result["ids"] and len(result["ids"]) > 0:
|
||||
# Delete all documents
|
||||
self._collection.delete(ids=result["ids"])
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
585
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/faiss_impl.py
Normal file
585
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/faiss_impl.py
Normal file
@@ -0,0 +1,585 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Any, final
|
||||
import json
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lightrag.utils import logger, compute_mdhash_id
|
||||
from lightrag.base import BaseVectorStorage
|
||||
|
||||
from .shared_storage import (
|
||||
get_namespace_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
)
|
||||
|
||||
# You must manually install faiss-cpu or faiss-gpu before using FAISS vector db
|
||||
import faiss # type: ignore
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class FaissVectorDBStorage(BaseVectorStorage):
|
||||
"""
|
||||
A Faiss-based Vector DB Storage for LightRAG.
|
||||
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
self._validate_embedding_func()
|
||||
# Grab config values if available
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
# Where to save index file if you want persistent storage
|
||||
working_dir = self.global_config["working_dir"]
|
||||
if self.workspace:
|
||||
# Include workspace in the file path for data isolation
|
||||
workspace_dir = os.path.join(working_dir, self.workspace)
|
||||
|
||||
else:
|
||||
# Default behavior when workspace is empty
|
||||
workspace_dir = working_dir
|
||||
self.workspace = ""
|
||||
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
self._faiss_index_file = os.path.join(
|
||||
workspace_dir, f"faiss_index_{self.namespace}.index"
|
||||
)
|
||||
self._meta_file = self._faiss_index_file + ".meta.json"
|
||||
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# Embedding dimension (e.g. 768) must match your embedding function
|
||||
self._dim = self.embedding_func.embedding_dim
|
||||
|
||||
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
||||
# If you have a large number of vectors, you might want IVF or other indexes.
|
||||
# For demonstration, we use a simple IndexFlatIP.
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
# Keep a local store for metadata, IDs, etc.
|
||||
# Maps <int faiss_id> → metadata (including your original ID).
|
||||
self._id_to_meta = {}
|
||||
|
||||
self._load_faiss_index()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# Get the update flag for cross-process update notification
|
||||
self.storage_updated = await get_update_flag(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
# Get the storage lock for use in other methods
|
||||
self._storage_lock = get_namespace_lock(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
|
||||
async def _get_index(self):
|
||||
"""Check if the shtorage should be reloaded"""
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if self.storage_updated.value:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
|
||||
)
|
||||
# Reload data
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
self.storage_updated.value = False
|
||||
return self._index
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Insert or update vectors in the Faiss index.
|
||||
|
||||
data: {
|
||||
"custom_id_1": {
|
||||
"content": <text>,
|
||||
...metadata...
|
||||
},
|
||||
"custom_id_2": {
|
||||
"content": <text>,
|
||||
...metadata...
|
||||
},
|
||||
...
|
||||
}
|
||||
"""
|
||||
logger.debug(
|
||||
f"[{self.workspace}] FAISS: Inserting {len(data)} to {self.namespace}"
|
||||
)
|
||||
if not data:
|
||||
return
|
||||
|
||||
current_time = int(time.time())
|
||||
|
||||
# Prepare data for embedding
|
||||
list_data = []
|
||||
contents = []
|
||||
for k, v in data.items():
|
||||
# Store only known meta fields if needed
|
||||
meta = {mf: v[mf] for mf in self.meta_fields if mf in v}
|
||||
meta["__id__"] = k
|
||||
meta["__created_at__"] = current_time
|
||||
list_data.append(meta)
|
||||
contents.append(v["content"])
|
||||
|
||||
# Split into batches for embedding if needed
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
embedding_tasks = [
|
||||
self.embedding_func(batch, context="document") for batch in batches
|
||||
]
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
# Flatten the list of arrays
|
||||
embeddings = np.concatenate(embeddings_list, axis=0)
|
||||
if len(embeddings) != len(list_data):
|
||||
logger.error(
|
||||
f"[{self.workspace}] Embedding size mismatch. Embeddings: {len(embeddings)}, Data: {len(list_data)}"
|
||||
)
|
||||
return []
|
||||
|
||||
# Convert to float32 and normalize embeddings for cosine similarity (in-place)
|
||||
embeddings = embeddings.astype(np.float32)
|
||||
faiss.normalize_L2(embeddings)
|
||||
|
||||
# Upsert logic:
|
||||
# 1. Identify which vectors to remove if they exist
|
||||
# 2. Remove them
|
||||
# 3. Add the new vectors
|
||||
existing_ids_to_remove = []
|
||||
for meta, emb in zip(list_data, embeddings):
|
||||
faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
|
||||
if faiss_internal_id is not None:
|
||||
existing_ids_to_remove.append(faiss_internal_id)
|
||||
|
||||
if existing_ids_to_remove:
|
||||
await self._remove_faiss_ids(existing_ids_to_remove)
|
||||
|
||||
# Step 2: Add new vectors
|
||||
index = await self._get_index()
|
||||
start_idx = index.ntotal
|
||||
index.add(embeddings)
|
||||
|
||||
# Step 3: Store metadata + vector for each new ID
|
||||
for i, meta in enumerate(list_data):
|
||||
fid = start_idx + i
|
||||
# Store the raw vector so we can rebuild if something is removed
|
||||
meta["__vector__"] = embeddings[i].tolist()
|
||||
self._id_to_meta.update({fid: meta})
|
||||
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Upserted {len(list_data)} vectors into Faiss index."
|
||||
)
|
||||
return [m["__id__"] for m in list_data]
|
||||
|
||||
async def query(
|
||||
self, query: str, top_k: int, query_embedding: list[float] = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Search by a textual query; returns top_k results with their metadata + similarity distance.
|
||||
"""
|
||||
if query_embedding is not None:
|
||||
embedding = np.array([query_embedding], dtype=np.float32)
|
||||
else:
|
||||
embedding = await self.embedding_func(
|
||||
[query], context="query", _priority=5
|
||||
) # higher priority for query
|
||||
# embedding is shape (1, dim)
|
||||
embedding = np.array(embedding, dtype=np.float32)
|
||||
|
||||
faiss.normalize_L2(embedding) # we do in-place normalization
|
||||
|
||||
# Perform the similarity search
|
||||
index = await self._get_index()
|
||||
distances, indices = index.search(embedding, top_k)
|
||||
|
||||
distances = distances[0]
|
||||
indices = indices[0]
|
||||
|
||||
results = []
|
||||
for dist, idx in zip(distances, indices):
|
||||
if idx == -1:
|
||||
# Faiss returns -1 if no neighbor
|
||||
continue
|
||||
|
||||
# Cosine similarity threshold
|
||||
if dist < self.cosine_better_than_threshold:
|
||||
continue
|
||||
|
||||
meta = self._id_to_meta.get(idx, {})
|
||||
# Filter out __vector__ from query results to avoid returning large vector data
|
||||
filtered_meta = {k: v for k, v in meta.items() if k != "__vector__"}
|
||||
results.append(
|
||||
{
|
||||
**filtered_meta,
|
||||
"id": meta.get("__id__"),
|
||||
"distance": float(dist),
|
||||
"created_at": meta.get("__created_at__"),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def client_storage(self):
|
||||
# Return whatever structure LightRAG might need for debugging
|
||||
return {"data": list(self._id_to_meta.values())}
|
||||
|
||||
async def delete(self, ids: list[str]):
|
||||
"""
|
||||
Delete vectors for the provided custom IDs.
|
||||
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Deleting {len(ids)} vectors from {self.namespace}"
|
||||
)
|
||||
to_remove = []
|
||||
for cid in ids:
|
||||
fid = self._find_faiss_id_by_custom_id(cid)
|
||||
if fid is not None:
|
||||
to_remove.append(fid)
|
||||
|
||||
if to_remove:
|
||||
await self._remove_faiss_ids(to_remove)
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Successfully deleted {len(to_remove)} vectors from {self.namespace}"
|
||||
)
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
|
||||
)
|
||||
await self.delete([entity_id])
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
logger.debug(f"[{self.workspace}] Searching relations for entity {entity_name}")
|
||||
relations = []
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
|
||||
relations.append(fid)
|
||||
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Found {len(relations)} relations for {entity_name}"
|
||||
)
|
||||
if relations:
|
||||
await self._remove_faiss_ids(relations)
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Deleted {len(relations)} relations for {entity_name}"
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# Internal helper methods
|
||||
# --------------------------------------------------------------------------------
|
||||
|
||||
def _find_faiss_id_by_custom_id(self, custom_id: str):
|
||||
"""
|
||||
Return the Faiss internal ID for a given custom ID, or None if not found.
|
||||
"""
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
if meta.get("__id__") == custom_id:
|
||||
return fid
|
||||
return None
|
||||
|
||||
async def _remove_faiss_ids(self, fid_list):
|
||||
"""
|
||||
Remove a list of internal Faiss IDs from the index.
|
||||
Because IndexFlatIP doesn't support 'removals',
|
||||
we rebuild the index excluding those vectors.
|
||||
"""
|
||||
keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
|
||||
|
||||
# Rebuild the index
|
||||
vectors_to_keep = []
|
||||
new_id_to_meta = {}
|
||||
for old_fid in keep_fids:
|
||||
vec_meta = self._id_to_meta[old_fid]
|
||||
if "__vector__" in vec_meta:
|
||||
vec = vec_meta["__vector__"]
|
||||
elif old_fid < self._index.ntotal:
|
||||
vec = self._index.reconstruct(old_fid).tolist()
|
||||
vec_meta["__vector__"] = vec
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Skipping fid={old_fid} during rebuild: "
|
||||
f"no vector and fid exceeds index size ({self._index.ntotal})"
|
||||
)
|
||||
continue
|
||||
new_fid = len(vectors_to_keep)
|
||||
vectors_to_keep.append(vec)
|
||||
new_id_to_meta[new_fid] = vec_meta
|
||||
|
||||
async with self._storage_lock:
|
||||
# Re-init index
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
if vectors_to_keep:
|
||||
arr = np.array(vectors_to_keep, dtype=np.float32)
|
||||
self._index.add(arr)
|
||||
|
||||
self._id_to_meta = new_id_to_meta
|
||||
|
||||
def _save_faiss_index(self):
|
||||
"""
|
||||
Save the current Faiss index + metadata to disk so it can persist across runs.
|
||||
"""
|
||||
faiss.write_index(self._index, self._faiss_index_file)
|
||||
|
||||
# Save metadata dict to JSON, excluding __vector__ since vectors are
|
||||
# already stored in the Faiss index file and can be reconstructed on load.
|
||||
serializable_dict = {}
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
filtered_meta = {k: v for k, v in meta.items() if k != "__vector__"}
|
||||
serializable_dict[str(fid)] = filtered_meta
|
||||
|
||||
# Atomic write: write to temp file first, then rename to reduce
|
||||
# mismatch risk between index and meta files on crash.
|
||||
tmp_meta_file = self._meta_file + ".tmp"
|
||||
with open(tmp_meta_file, "w", encoding="utf-8") as f:
|
||||
json.dump(serializable_dict, f)
|
||||
os.replace(tmp_meta_file, self._meta_file)
|
||||
|
||||
def _load_faiss_index(self):
|
||||
"""
|
||||
Load the Faiss index + metadata from disk if it exists,
|
||||
and rebuild in-memory structures so we can query.
|
||||
"""
|
||||
if not os.path.exists(self._faiss_index_file):
|
||||
logger.warning(
|
||||
f"[{self.workspace}] No existing Faiss index file found for {self.namespace}"
|
||||
)
|
||||
return
|
||||
|
||||
dim_mismatch = False
|
||||
try:
|
||||
# Load the Faiss index
|
||||
self._index = faiss.read_index(self._faiss_index_file)
|
||||
|
||||
# Verify dimension consistency between loaded index and embedding function
|
||||
if self._index.d != self._dim:
|
||||
error_msg = (
|
||||
f"Dimension mismatch: loaded Faiss index has dimension {self._index.d}, "
|
||||
f"but embedding function expects dimension {self._dim}. "
|
||||
f"Please ensure the embedding model matches the stored index or rebuild the index."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
dim_mismatch = True
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Load metadata
|
||||
with open(self._meta_file, "r", encoding="utf-8") as f:
|
||||
stored_dict = json.load(f)
|
||||
|
||||
# Convert string keys back to int and reconstruct vectors from index
|
||||
self._id_to_meta = {}
|
||||
for fid_str, meta in stored_dict.items():
|
||||
fid = int(fid_str)
|
||||
if fid >= self._index.ntotal:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Skipping metadata row fid={fid}: "
|
||||
f"exceeds index size ({self._index.ntotal})"
|
||||
)
|
||||
continue
|
||||
if "__vector__" not in meta:
|
||||
meta["__vector__"] = self._index.reconstruct(fid).tolist()
|
||||
self._id_to_meta[fid] = meta
|
||||
|
||||
logger.info(
|
||||
f"[{self.workspace}] Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
|
||||
)
|
||||
except Exception as e:
|
||||
if dim_mismatch:
|
||||
raise
|
||||
logger.error(
|
||||
f"[{self.workspace}] Failed to load Faiss index or metadata: {e}"
|
||||
)
|
||||
logger.warning(f"[{self.workspace}] Starting with an empty Faiss index.")
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Storage for FAISS {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
try:
|
||||
# Save data to disk
|
||||
self._save_faiss_index()
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
self.storage_updated.value = False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error saving FAISS index for {self.namespace}: {e}"
|
||||
)
|
||||
return False # Return error
|
||||
|
||||
return True # Return success
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
# Find the Faiss internal ID for the custom ID
|
||||
fid = self._find_faiss_id_by_custom_id(id)
|
||||
if fid is None:
|
||||
return None
|
||||
|
||||
# Get the metadata for the found ID
|
||||
metadata = self._id_to_meta.get(fid, {})
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
# Filter out __vector__ from metadata to avoid returning large vector data
|
||||
filtered_metadata = {k: v for k, v in metadata.items() if k != "__vector__"}
|
||||
return {
|
||||
**filtered_metadata,
|
||||
"id": metadata.get("__id__"),
|
||||
"created_at": metadata.get("__created_at__"),
|
||||
}
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
results: list[dict[str, Any] | None] = []
|
||||
for id in ids:
|
||||
record = None
|
||||
fid = self._find_faiss_id_by_custom_id(id)
|
||||
if fid is not None:
|
||||
metadata = self._id_to_meta.get(fid)
|
||||
if metadata:
|
||||
# Filter out __vector__ from metadata to avoid returning large vector data
|
||||
filtered_metadata = {
|
||||
k: v for k, v in metadata.items() if k != "__vector__"
|
||||
}
|
||||
record = {
|
||||
**filtered_metadata,
|
||||
"id": metadata.get("__id__"),
|
||||
"created_at": metadata.get("__created_at__"),
|
||||
}
|
||||
results.append(record)
|
||||
|
||||
return results
|
||||
|
||||
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
Dictionary mapping IDs to their vector embeddings
|
||||
Format: {id: [vector_values], ...}
|
||||
"""
|
||||
if not ids:
|
||||
return {}
|
||||
|
||||
vectors_dict = {}
|
||||
for id in ids:
|
||||
# Find the Faiss internal ID for the custom ID
|
||||
fid = self._find_faiss_id_by_custom_id(id)
|
||||
if fid is not None and fid in self._id_to_meta:
|
||||
metadata = self._id_to_meta[fid]
|
||||
# Get the stored vector from metadata
|
||||
if "__vector__" in metadata:
|
||||
vectors_dict[id] = metadata["__vector__"]
|
||||
|
||||
return vectors_dict
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Remove the vector database storage file if it exists
|
||||
2. Reinitialize the vector database client
|
||||
3. Update flags to notify other processes
|
||||
4. Changes is persisted to disk immediately
|
||||
|
||||
This method will remove all vectors from the Faiss index and delete the storage files.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
# Reset the index
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
|
||||
# Remove storage files if they exist
|
||||
if os.path.exists(self._faiss_index_file):
|
||||
os.remove(self._faiss_index_file)
|
||||
if os.path.exists(self._meta_file):
|
||||
os.remove(self._meta_file)
|
||||
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
|
||||
# Notify other processes
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
self.storage_updated.value = False
|
||||
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} drop FAISS index {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error dropping FAISS index {self.namespace}: {e}"
|
||||
)
|
||||
return {"status": "error", "message": str(e)}
|
||||
@@ -0,0 +1,422 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from typing import Any, Union, final
|
||||
|
||||
from lightrag.base import (
|
||||
DocProcessingStatus,
|
||||
DocStatus,
|
||||
DocStatusStorage,
|
||||
)
|
||||
from lightrag.utils import (
|
||||
_cooperative_yield,
|
||||
load_json,
|
||||
logger,
|
||||
write_json,
|
||||
get_pinyin_sort_key,
|
||||
)
|
||||
from lightrag.exceptions import StorageNotInitializedError
|
||||
from .shared_storage import (
|
||||
get_namespace_data,
|
||||
get_namespace_lock,
|
||||
get_data_init_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
clear_all_update_flags,
|
||||
try_initialize_namespace,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class JsonDocStatusStorage(DocStatusStorage):
|
||||
"""JSON implementation of document status storage"""
|
||||
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
if self.workspace:
|
||||
# Include workspace in the file path for data isolation
|
||||
workspace_dir = os.path.join(working_dir, self.workspace)
|
||||
else:
|
||||
# Default behavior when workspace is empty
|
||||
workspace_dir = working_dir
|
||||
self.workspace = ""
|
||||
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data = None
|
||||
self._storage_lock = None
|
||||
self.storage_updated = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
self._storage_lock = get_namespace_lock(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
self.storage_updated = await get_update_flag(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
async with get_data_init_lock():
|
||||
# check need_init must before get_namespace_data
|
||||
need_init = await try_initialize_namespace(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
self._data = await get_namespace_data(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
if need_init:
|
||||
loaded_data = load_json(self._file_name) or {}
|
||||
async with self._storage_lock:
|
||||
self._data.update(loaded_data)
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
|
||||
)
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||
if self._storage_lock is None:
|
||||
raise StorageNotInitializedError("JsonDocStatusStorage")
|
||||
async with self._storage_lock:
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
ordered_results: list[dict[str, Any] | None] = []
|
||||
if self._storage_lock is None:
|
||||
raise StorageNotInitializedError("JsonDocStatusStorage")
|
||||
async with self._storage_lock:
|
||||
for id in ids:
|
||||
data = self._data.get(id, None)
|
||||
if data:
|
||||
ordered_results.append(data.copy())
|
||||
else:
|
||||
ordered_results.append(None)
|
||||
return ordered_results
|
||||
|
||||
async def get_status_counts(self) -> dict[str, int]:
|
||||
"""Get counts of documents in each status"""
|
||||
counts = {status.value: 0 for status in DocStatus}
|
||||
if self._storage_lock is None:
|
||||
raise StorageNotInitializedError("JsonDocStatusStorage")
|
||||
async with self._storage_lock:
|
||||
for doc in self._data.values():
|
||||
counts[doc["status"]] += 1
|
||||
return counts
|
||||
|
||||
async def get_docs_by_status(
|
||||
self, status: DocStatus
|
||||
) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all documents with a specific status"""
|
||||
return await self.get_docs_by_statuses([status])
|
||||
|
||||
async def get_docs_by_statuses(
|
||||
self, statuses: list[DocStatus]
|
||||
) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all documents matching any of the given statuses in a single pass.
|
||||
|
||||
Acquires the storage lock once and scans the in-memory dict once,
|
||||
filtering against a set of status values. More efficient than N separate
|
||||
get_docs_by_status() calls, which would acquire the lock N times and scan
|
||||
the data N times.
|
||||
"""
|
||||
if not statuses:
|
||||
return {}
|
||||
status_values = {s.value for s in statuses}
|
||||
result = {}
|
||||
async with self._storage_lock:
|
||||
for k, v in self._data.items():
|
||||
if v["status"] not in status_values:
|
||||
continue
|
||||
try:
|
||||
data = v.copy()
|
||||
data.pop("content", None)
|
||||
if not data.get("file_path"):
|
||||
data["file_path"] = "no-file-path"
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
if "error_msg" not in data:
|
||||
data["error_msg"] = None
|
||||
result[k] = DocProcessingStatus(**data)
|
||||
except (KeyError, TypeError) as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Missing required field for document {k}: {e}"
|
||||
)
|
||||
continue
|
||||
return result
|
||||
|
||||
async def get_docs_by_track_id(
|
||||
self, track_id: str
|
||||
) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all documents with a specific track_id"""
|
||||
result = {}
|
||||
async with self._storage_lock:
|
||||
for k, v in self._data.items():
|
||||
if v.get("track_id") == track_id:
|
||||
try:
|
||||
# Make a copy of the data to avoid modifying the original
|
||||
data = v.copy()
|
||||
# Remove deprecated content field if it exists
|
||||
data.pop("content", None)
|
||||
# Normalize missing or null file_path
|
||||
if not data.get("file_path"):
|
||||
data["file_path"] = "no-file-path"
|
||||
# Ensure new fields exist with default values
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
if "error_msg" not in data:
|
||||
data["error_msg"] = None
|
||||
result[k] = DocProcessingStatus(**data)
|
||||
except KeyError as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Missing required field for document {k}: {e}"
|
||||
)
|
||||
continue
|
||||
return result
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
async with self._storage_lock:
|
||||
if self.storage_updated.value:
|
||||
data_dict = (
|
||||
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
||||
)
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
|
||||
)
|
||||
|
||||
# Write JSON and check if sanitization was applied
|
||||
needs_reload = write_json(data_dict, self._file_name)
|
||||
|
||||
# If data was sanitized, reload cleaned data to update shared memory
|
||||
if needs_reload:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
|
||||
)
|
||||
cleaned_data = load_json(self._file_name)
|
||||
if cleaned_data is not None:
|
||||
self._data.clear()
|
||||
self._data.update(cleaned_data)
|
||||
|
||||
await clear_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Importance notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
|
||||
)
|
||||
if self._storage_lock is None:
|
||||
raise StorageNotInitializedError("JsonDocStatusStorage")
|
||||
# Prepare data outside the lock: this only mutates the caller-supplied
|
||||
# dict values, not shared storage state, so no lock needed here.
|
||||
for i, (doc_id, doc_data) in enumerate(data.items(), start=1):
|
||||
if "chunks_list" not in doc_data:
|
||||
doc_data["chunks_list"] = []
|
||||
await _cooperative_yield(i)
|
||||
async with self._storage_lock:
|
||||
self._data.update(data)
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
|
||||
await self.index_done_callback()
|
||||
|
||||
async def is_empty(self) -> bool:
|
||||
"""Check if the storage is empty
|
||||
|
||||
Returns:
|
||||
bool: True if storage is empty, False otherwise
|
||||
|
||||
Raises:
|
||||
StorageNotInitializedError: If storage is not initialized
|
||||
"""
|
||||
if self._storage_lock is None:
|
||||
raise StorageNotInitializedError("JsonDocStatusStorage")
|
||||
async with self._storage_lock:
|
||||
return len(self._data) == 0
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
async with self._storage_lock:
|
||||
return self._data.get(id)
|
||||
|
||||
async def get_docs_paginated(
|
||||
self,
|
||||
status_filter: DocStatus | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_field: str = "updated_at",
|
||||
sort_direction: str = "desc",
|
||||
) -> tuple[list[tuple[str, DocProcessingStatus]], int]:
|
||||
"""Get documents with pagination support
|
||||
|
||||
Args:
|
||||
status_filter: Filter by document status, None for all statuses
|
||||
page: Page number (1-based)
|
||||
page_size: Number of documents per page (10-200)
|
||||
sort_field: Field to sort by ('created_at', 'updated_at', 'id')
|
||||
sort_direction: Sort direction ('asc' or 'desc')
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (doc_id, DocProcessingStatus) tuples, total_count)
|
||||
"""
|
||||
# Validate parameters
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 10:
|
||||
page_size = 10
|
||||
elif page_size > 200:
|
||||
page_size = 200
|
||||
|
||||
if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
|
||||
sort_field = "updated_at"
|
||||
|
||||
if sort_direction.lower() not in ["asc", "desc"]:
|
||||
sort_direction = "desc"
|
||||
|
||||
# For JSON storage, we load all data and sort/filter in memory
|
||||
all_docs = []
|
||||
|
||||
async with self._storage_lock:
|
||||
for doc_id, doc_data in self._data.items():
|
||||
# Apply status filter
|
||||
if (
|
||||
status_filter is not None
|
||||
and doc_data.get("status") != status_filter.value
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
# Prepare document data
|
||||
data = doc_data.copy()
|
||||
data.pop("content", None)
|
||||
if not data.get("file_path"):
|
||||
data["file_path"] = "no-file-path"
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
if "error_msg" not in data:
|
||||
data["error_msg"] = None
|
||||
|
||||
doc_status = DocProcessingStatus(**data)
|
||||
|
||||
# Add sort key for sorting
|
||||
if sort_field == "id":
|
||||
doc_status._sort_key = doc_id
|
||||
elif sort_field == "file_path":
|
||||
# Use pinyin sorting for file_path field to support Chinese characters
|
||||
file_path_value = getattr(doc_status, sort_field, "")
|
||||
doc_status._sort_key = get_pinyin_sort_key(file_path_value)
|
||||
else:
|
||||
doc_status._sort_key = getattr(doc_status, sort_field, "")
|
||||
|
||||
all_docs.append((doc_id, doc_status))
|
||||
|
||||
except KeyError as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error processing document {doc_id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Sort documents
|
||||
reverse_sort = sort_direction.lower() == "desc"
|
||||
all_docs.sort(
|
||||
key=lambda x: getattr(x[1], "_sort_key", ""), reverse=reverse_sort
|
||||
)
|
||||
|
||||
# Remove sort key from documents
|
||||
for doc_id, doc in all_docs:
|
||||
if hasattr(doc, "_sort_key"):
|
||||
delattr(doc, "_sort_key")
|
||||
|
||||
total_count = len(all_docs)
|
||||
|
||||
# Apply pagination
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
paginated_docs = all_docs[start_idx:end_idx]
|
||||
|
||||
return paginated_docs, total_count
|
||||
|
||||
async def get_all_status_counts(self) -> dict[str, int]:
|
||||
"""Get counts of documents in each status for all documents
|
||||
|
||||
Returns:
|
||||
Dictionary mapping status names to counts, including 'all' field
|
||||
"""
|
||||
counts = await self.get_status_counts()
|
||||
|
||||
# Add 'all' field with total count
|
||||
total_count = sum(counts.values())
|
||||
counts["all"] = total_count
|
||||
|
||||
return counts
|
||||
|
||||
async def delete(self, doc_ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
Importance notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
ids (list[str]): List of document IDs to be deleted from storage
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
async with self._storage_lock:
|
||||
any_deleted = False
|
||||
for doc_id in doc_ids:
|
||||
result = self._data.pop(doc_id, None)
|
||||
if result is not None:
|
||||
any_deleted = True
|
||||
|
||||
if any_deleted:
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
|
||||
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
|
||||
"""Get document by file path
|
||||
|
||||
Args:
|
||||
file_path: The file path to search for
|
||||
|
||||
Returns:
|
||||
Union[dict[str, Any], None]: Document data if found, None otherwise
|
||||
Returns the same format as get_by_ids method
|
||||
"""
|
||||
if self._storage_lock is None:
|
||||
raise StorageNotInitializedError("JsonDocStatusStorage")
|
||||
|
||||
async with self._storage_lock:
|
||||
for doc_id, doc_data in self._data.items():
|
||||
if doc_data.get("file_path") == file_path:
|
||||
# Return complete document data, consistent with get_by_ids method
|
||||
return doc_data
|
||||
|
||||
return None
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all document status data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Clear all document status data from memory
|
||||
2. Update flags to notify other processes
|
||||
3. Trigger index_done_callback to save the empty state
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
self._data.clear()
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
|
||||
await self.index_done_callback()
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
307
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/json_kv_impl.py
Normal file
307
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/json_kv_impl.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
|
||||
from lightrag.base import (
|
||||
BaseKVStorage,
|
||||
)
|
||||
from lightrag.utils import (
|
||||
_cooperative_yield,
|
||||
load_json,
|
||||
logger,
|
||||
write_json,
|
||||
)
|
||||
from lightrag.exceptions import StorageNotInitializedError
|
||||
from .shared_storage import (
|
||||
get_namespace_data,
|
||||
get_namespace_lock,
|
||||
get_data_init_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
clear_all_update_flags,
|
||||
try_initialize_namespace,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class JsonKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
if self.workspace:
|
||||
# Include workspace in the file path for data isolation
|
||||
workspace_dir = os.path.join(working_dir, self.workspace)
|
||||
else:
|
||||
# Default behavior when workspace is empty
|
||||
workspace_dir = working_dir
|
||||
self.workspace = ""
|
||||
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data = None
|
||||
self._storage_lock = None
|
||||
self.storage_updated = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
self._storage_lock = get_namespace_lock(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
self.storage_updated = await get_update_flag(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
async with get_data_init_lock():
|
||||
# check need_init must before get_namespace_data
|
||||
need_init = await try_initialize_namespace(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
self._data = await get_namespace_data(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
if need_init:
|
||||
loaded_data = load_json(self._file_name) or {}
|
||||
async with self._storage_lock:
|
||||
# Migrate legacy cache structure if needed
|
||||
if self.namespace.endswith("_cache"):
|
||||
loaded_data = await self._migrate_legacy_cache_structure(
|
||||
loaded_data
|
||||
)
|
||||
|
||||
self._data.update(loaded_data)
|
||||
data_count = len(loaded_data)
|
||||
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} KV load {self.namespace} with {data_count} records"
|
||||
)
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
async with self._storage_lock:
|
||||
if self.storage_updated.value:
|
||||
data_dict = (
|
||||
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
||||
)
|
||||
|
||||
# Calculate data count - all data is now flattened
|
||||
data_count = len(data_dict)
|
||||
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
|
||||
)
|
||||
|
||||
# Write JSON and check if sanitization was applied
|
||||
needs_reload = write_json(data_dict, self._file_name)
|
||||
|
||||
# If data was sanitized, reload cleaned data to update shared memory
|
||||
if needs_reload:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
|
||||
)
|
||||
cleaned_data = load_json(self._file_name)
|
||||
if cleaned_data is not None:
|
||||
self._data.clear()
|
||||
self._data.update(cleaned_data)
|
||||
|
||||
await clear_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
async with self._storage_lock:
|
||||
result = self._data.get(id)
|
||||
if result:
|
||||
# Create a copy to avoid modifying the original data
|
||||
result = dict(result)
|
||||
# Ensure time fields are present, provide default values for old data
|
||||
result.setdefault("create_time", 0)
|
||||
result.setdefault("update_time", 0)
|
||||
# Ensure _id field contains the clean ID
|
||||
result["_id"] = id
|
||||
return result
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
async with self._storage_lock:
|
||||
results = []
|
||||
for id in ids:
|
||||
data = self._data.get(id, None)
|
||||
if data:
|
||||
# Create a copy to avoid modifying the original data
|
||||
result = {k: v for k, v in data.items()}
|
||||
# Ensure time fields are present, provide default values for old data
|
||||
result.setdefault("create_time", 0)
|
||||
result.setdefault("update_time", 0)
|
||||
# Ensure _id field contains the clean ID
|
||||
result["_id"] = id
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(None)
|
||||
return results
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
async with self._storage_lock:
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Importance notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
import time
|
||||
|
||||
current_time = int(time.time()) # Get current Unix timestamp
|
||||
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
|
||||
)
|
||||
if self._storage_lock is None:
|
||||
raise StorageNotInitializedError("JsonKVStorage")
|
||||
async with self._storage_lock:
|
||||
# Add timestamps to data based on whether key exists.
|
||||
# The loop reads self._data (k in self._data) so it must stay inside
|
||||
# the lock. _cooperative_yield is safe here: NamespaceLock is
|
||||
# non-reentrant, so other coroutines waiting on this lock will block
|
||||
# until we release it; the yield only benefits unrelated coroutines.
|
||||
for i, (k, v) in enumerate(data.items(), start=1):
|
||||
# For text_chunks namespace, ensure llm_cache_list field exists
|
||||
if self.namespace.endswith("text_chunks"):
|
||||
if "llm_cache_list" not in v:
|
||||
v["llm_cache_list"] = []
|
||||
|
||||
# Add timestamps based on whether key exists
|
||||
if k in self._data: # Key exists, only update update_time
|
||||
v["update_time"] = current_time
|
||||
else: # New key, set both create_time and update_time
|
||||
v["create_time"] = current_time
|
||||
v["update_time"] = current_time
|
||||
|
||||
v["_id"] = k
|
||||
await _cooperative_yield(i)
|
||||
|
||||
self._data.update(data)
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
Importance notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
ids (list[str]): List of document IDs to be deleted from storage
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
async with self._storage_lock:
|
||||
any_deleted = False
|
||||
for doc_id in ids:
|
||||
result = self._data.pop(doc_id, None)
|
||||
if result is not None:
|
||||
any_deleted = True
|
||||
|
||||
if any_deleted:
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
|
||||
async def is_empty(self) -> bool:
|
||||
"""Check if the storage is empty
|
||||
|
||||
Returns:
|
||||
bool: True if storage contains no data, False otherwise
|
||||
"""
|
||||
async with self._storage_lock:
|
||||
return len(self._data) == 0
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all data from storage and clean up resources
|
||||
This action will persistent the data to disk immediately.
|
||||
|
||||
This method will:
|
||||
1. Clear all data from memory
|
||||
2. Update flags to notify other processes
|
||||
3. Trigger index_done_callback to save the empty state
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
self._data.clear()
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
|
||||
await self.index_done_callback()
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
|
||||
"""Migrate legacy nested cache structure to flattened structure
|
||||
|
||||
Args:
|
||||
data: Original data dictionary that may contain legacy structure
|
||||
|
||||
Returns:
|
||||
Migrated data dictionary with flattened cache keys (sanitized if needed)
|
||||
"""
|
||||
from lightrag.utils import generate_cache_key
|
||||
|
||||
# Early return if data is empty
|
||||
if not data:
|
||||
return data
|
||||
|
||||
# Check first entry to see if it's already in new format
|
||||
first_key = next(iter(data.keys()))
|
||||
if ":" in first_key and len(first_key.split(":")) == 3:
|
||||
# Already in flattened format, return as-is
|
||||
return data
|
||||
|
||||
migrated_data = {}
|
||||
migration_count = 0
|
||||
|
||||
for key, value in data.items():
|
||||
# Check if this is a legacy nested cache structure
|
||||
if isinstance(value, dict) and all(
|
||||
isinstance(v, dict) and "return" in v for v in value.values()
|
||||
):
|
||||
# This looks like a legacy cache mode with nested structure
|
||||
mode = key
|
||||
for cache_hash, cache_entry in value.items():
|
||||
cache_type = cache_entry.get("cache_type", "extract")
|
||||
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
|
||||
migrated_data[flattened_key] = cache_entry
|
||||
migration_count += 1
|
||||
else:
|
||||
# Keep non-cache data or already flattened cache data as-is
|
||||
migrated_data[key] = value
|
||||
|
||||
if migration_count > 0:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure"
|
||||
)
|
||||
# Persist migrated data immediately and check if sanitization was applied
|
||||
needs_reload = write_json(migrated_data, self._file_name)
|
||||
|
||||
# If data was sanitized during write, reload cleaned data
|
||||
if needs_reload:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}"
|
||||
)
|
||||
cleaned_data = load_json(self._file_name)
|
||||
if cleaned_data is not None:
|
||||
return cleaned_data # Return cleaned data to update shared memory
|
||||
|
||||
return migrated_data
|
||||
|
||||
async def finalize(self):
|
||||
"""Finalize storage resources
|
||||
Persistence cache data to disk before exiting
|
||||
"""
|
||||
if self.namespace.endswith("_cache"):
|
||||
await self.index_done_callback()
|
||||
1346
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/memgraph_impl.py
Normal file
1346
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/memgraph_impl.py
Normal file
File diff suppressed because it is too large
Load Diff
1803
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/milvus_impl.py
Normal file
1803
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/milvus_impl.py
Normal file
File diff suppressed because it is too large
Load Diff
2640
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/mongo_impl.py
Normal file
2640
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/mongo_impl.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,430 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import zlib
|
||||
from typing import Any, final
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
compute_mdhash_id,
|
||||
)
|
||||
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from nano_vectordb import NanoVectorDB
|
||||
from .shared_storage import (
|
||||
get_namespace_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
def __post_init__(self):
|
||||
self._validate_embedding_func()
|
||||
# Initialize basic attributes
|
||||
self._client = None
|
||||
self._storage_lock = None
|
||||
self.storage_updated = None
|
||||
|
||||
# Use global config value if specified, otherwise use default
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
working_dir = self.global_config["working_dir"]
|
||||
if self.workspace:
|
||||
# Include workspace in the file path for data isolation
|
||||
workspace_dir = os.path.join(working_dir, self.workspace)
|
||||
self.final_namespace = f"{self.workspace}_{self.namespace}"
|
||||
else:
|
||||
# Default behavior when workspace is empty
|
||||
self.final_namespace = self.namespace
|
||||
self.workspace = ""
|
||||
workspace_dir = working_dir
|
||||
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
self._client_file_name = os.path.join(
|
||||
workspace_dir, f"vdb_{self.namespace}.json"
|
||||
)
|
||||
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# Get the update flag for cross-process update notification
|
||||
self.storage_updated = await get_update_flag(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
# Get the storage lock for use in other methods
|
||||
self._storage_lock = get_namespace_lock(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
|
||||
async def _get_client(self):
|
||||
"""Check if the storage should be reloaded"""
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if data needs to be reloaded
|
||||
if self.storage_updated.value:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} reloading {self.namespace} due to update by another process"
|
||||
)
|
||||
# Reload data
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
|
||||
return self._client
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
# logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
return
|
||||
|
||||
current_time = int(time.time())
|
||||
list_data = [
|
||||
{
|
||||
"__id__": k,
|
||||
"__created_at__": current_time,
|
||||
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
# Execute embedding outside of lock to avoid long lock times
|
||||
embedding_tasks = [
|
||||
self.embedding_func(batch, context="document") for batch in batches
|
||||
]
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
if len(embeddings) == len(list_data):
|
||||
for i, d in enumerate(list_data):
|
||||
# Compress vector using Float16 + zlib + Base64 for storage optimization
|
||||
vector_f16 = embeddings[i].astype(np.float16)
|
||||
compressed_vector = zlib.compress(vector_f16.tobytes())
|
||||
encoded_vector = base64.b64encode(compressed_vector).decode("utf-8")
|
||||
d["vector"] = encoded_vector
|
||||
d["__vector__"] = embeddings[i]
|
||||
client = await self._get_client()
|
||||
results = client.upsert(datas=list_data)
|
||||
return results
|
||||
else:
|
||||
# sometimes the embedding is not returned correctly. just log it.
|
||||
logger.error(
|
||||
f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
||||
)
|
||||
|
||||
async def query(
|
||||
self, query: str, top_k: int, query_embedding: list[float] = None
|
||||
) -> list[dict[str, Any]]:
|
||||
# Use provided embedding or compute it
|
||||
if query_embedding is not None:
|
||||
embedding = query_embedding
|
||||
else:
|
||||
# Execute embedding outside of lock to avoid improve cocurrent
|
||||
embedding = await self.embedding_func(
|
||||
[query], context="query", _priority=5
|
||||
) # higher priority for query
|
||||
embedding = embedding[0]
|
||||
|
||||
client = await self._get_client()
|
||||
results = client.query(
|
||||
query=embedding,
|
||||
top_k=top_k,
|
||||
better_than_threshold=self.cosine_better_than_threshold,
|
||||
)
|
||||
results = [
|
||||
{
|
||||
**{k: v for k, v in dp.items() if k != "vector"},
|
||||
"id": dp["__id__"],
|
||||
"distance": dp["__metrics__"],
|
||||
"created_at": dp.get("__created_at__"),
|
||||
}
|
||||
for dp in results
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
async def client_storage(self):
|
||||
client = await self._get_client()
|
||||
return getattr(client, "_NanoVectorDB__storage")
|
||||
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete vectors with specified IDs
|
||||
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
# Record count before deletion
|
||||
before_count = len(client)
|
||||
|
||||
client.delete(ids)
|
||||
|
||||
# Calculate actual deleted count
|
||||
after_count = len(client)
|
||||
deleted_count = before_count - after_count
|
||||
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Successfully deleted {deleted_count} vectors from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
|
||||
)
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
try:
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
|
||||
)
|
||||
|
||||
# Check if the entity exists
|
||||
client = await self._get_client()
|
||||
if client.get([entity_id]):
|
||||
client.delete([entity_id])
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Successfully deleted entity {entity_name}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Entity {entity_name} not found in storage"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
storage = getattr(client, "_NanoVectorDB__storage")
|
||||
relations = [
|
||||
dp
|
||||
for dp in storage["data"]
|
||||
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
|
||||
]
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Found {len(relations)} relations for entity {entity_name}"
|
||||
)
|
||||
ids_to_delete = [relation["__id__"] for relation in relations]
|
||||
|
||||
if ids_to_delete:
|
||||
client = await self._get_client()
|
||||
client.delete(ids_to_delete)
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"[{self.workspace}] No relations found for entity {entity_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
|
||||
)
|
||||
|
||||
async def index_done_callback(self) -> bool:
|
||||
"""Save data to disk"""
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Storage for {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
try:
|
||||
# Save data to disk
|
||||
self._client.save()
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
self.storage_updated.value = False
|
||||
return True # Return success
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error saving data for {self.namespace}: {e}"
|
||||
)
|
||||
return False # Return error
|
||||
|
||||
return True # Return success
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
client = await self._get_client()
|
||||
result = client.get([id])
|
||||
if result:
|
||||
dp = result[0]
|
||||
return {
|
||||
**{k: v for k, v in dp.items() if k != "vector"},
|
||||
"id": dp.get("__id__"),
|
||||
"created_at": dp.get("__created_at__"),
|
||||
}
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
client = await self._get_client()
|
||||
results = client.get(ids)
|
||||
result_map: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for dp in results:
|
||||
if not dp:
|
||||
continue
|
||||
record = {
|
||||
**{k: v for k, v in dp.items() if k != "vector"},
|
||||
"id": dp.get("__id__"),
|
||||
"created_at": dp.get("__created_at__"),
|
||||
}
|
||||
key = record.get("id")
|
||||
if key is not None:
|
||||
result_map[str(key)] = record
|
||||
|
||||
ordered_results: list[dict[str, Any] | None] = []
|
||||
for requested_id in ids:
|
||||
ordered_results.append(result_map.get(str(requested_id)))
|
||||
|
||||
return ordered_results
|
||||
|
||||
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
Dictionary mapping IDs to their vector embeddings
|
||||
Format: {id: [vector_values], ...}
|
||||
"""
|
||||
if not ids:
|
||||
return {}
|
||||
|
||||
client = await self._get_client()
|
||||
results = client.get(ids)
|
||||
|
||||
vectors_dict = {}
|
||||
for result in results:
|
||||
if result and "vector" in result and "__id__" in result:
|
||||
# Decompress vector data (Base64 + zlib + Float16 compressed)
|
||||
decoded = base64.b64decode(result["vector"])
|
||||
decompressed = zlib.decompress(decoded)
|
||||
vector_f16 = np.frombuffer(decompressed, dtype=np.float16)
|
||||
vector_f32 = vector_f16.astype(np.float32).tolist()
|
||||
vectors_dict[result["__id__"]] = vector_f32
|
||||
|
||||
return vectors_dict
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Remove the vector database storage file if it exists
|
||||
2. Reinitialize the vector database client
|
||||
3. Update flags to notify other processes
|
||||
4. Changes is persisted to disk immediately
|
||||
|
||||
This method is intended for use in scenarios where all data needs to be removed,
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
# delete _client_file_name
|
||||
if os.path.exists(self._client_file_name):
|
||||
os.remove(self._client_file_name)
|
||||
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
self.storage_updated.value = False
|
||||
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
2019
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/neo4j_impl.py
Normal file
2019
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/neo4j_impl.py
Normal file
File diff suppressed because it is too large
Load Diff
606
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/networkx_impl.py
Normal file
606
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/networkx_impl.py
Normal file
@@ -0,0 +1,606 @@
|
||||
import os
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import final
|
||||
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from lightrag.utils import logger
|
||||
from lightrag.base import BaseGraphStorage
|
||||
import networkx as nx
|
||||
from .shared_storage import (
|
||||
get_namespace_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
# allows to use different .env file for each lightrag instance
|
||||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class NetworkXStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name) -> nx.Graph:
|
||||
if os.path.exists(file_name):
|
||||
return nx.read_graphml(file_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def write_nx_graph(graph: nx.Graph, file_name, workspace="_"):
|
||||
logger.info(
|
||||
f"[{workspace}] Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
|
||||
)
|
||||
nx.write_graphml(graph, file_name)
|
||||
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
if self.workspace:
|
||||
# Include workspace in the file path for data isolation
|
||||
workspace_dir = os.path.join(working_dir, self.workspace)
|
||||
else:
|
||||
# Default behavior when workspace is empty
|
||||
workspace_dir = working_dir
|
||||
self.workspace = ""
|
||||
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
self._graphml_xml_file = os.path.join(
|
||||
workspace_dir, f"graph_{self.namespace}.graphml"
|
||||
)
|
||||
self._storage_lock = None
|
||||
self.storage_updated = None
|
||||
self._graph = None
|
||||
|
||||
# Load initial graph
|
||||
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
||||
if preloaded_graph is not None:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Created new empty graph file: {self._graphml_xml_file}"
|
||||
)
|
||||
self._graph = preloaded_graph or nx.Graph()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# Get the update flag for cross-process update notification
|
||||
self.storage_updated = await get_update_flag(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
# Get the storage lock for use in other methods
|
||||
self._storage_lock = get_namespace_lock(
|
||||
self.namespace, workspace=self.workspace
|
||||
)
|
||||
|
||||
async def _get_graph(self):
|
||||
"""Check if the storage should be reloaded"""
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if data needs to be reloaded
|
||||
if self.storage_updated.value:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} reloading graph {self._graphml_xml_file} due to modifications by another process"
|
||||
)
|
||||
# Reload data
|
||||
self._graph = (
|
||||
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
|
||||
return self._graph
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
graph = await self._get_graph()
|
||||
return graph.has_node(node_id)
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
graph = await self._get_graph()
|
||||
return graph.has_edge(source_node_id, target_node_id)
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
graph = await self._get_graph()
|
||||
return graph.nodes.get(node_id)
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
graph = await self._get_graph()
|
||||
if graph.has_node(node_id):
|
||||
return graph.degree(node_id)
|
||||
return 0
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
graph = await self._get_graph()
|
||||
src_degree = graph.degree(src_id) if graph.has_node(src_id) else 0
|
||||
tgt_degree = graph.degree(tgt_id) if graph.has_node(tgt_id) else 0
|
||||
return src_degree + tgt_degree
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
graph = await self._get_graph()
|
||||
return graph.edges.get((source_node_id, target_node_id))
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
graph = await self._get_graph()
|
||||
if graph.has_node(source_node_id):
|
||||
return list(graph.edges(source_node_id))
|
||||
return None
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
|
||||
"""Batch insert/update multiple nodes in a single call.
|
||||
|
||||
Much faster than calling upsert_node() in a loop for large imports
|
||||
because it avoids per-call async event loop overhead.
|
||||
|
||||
Args:
|
||||
nodes: List of (node_id, node_data) tuples.
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
for node_id, node_data in nodes:
|
||||
graph.add_node(node_id, **node_data)
|
||||
|
||||
async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
|
||||
"""Check existence of multiple nodes in a single call.
|
||||
|
||||
Returns:
|
||||
Set of node_ids that exist in the graph.
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
return {nid for nid in node_ids if graph.has_node(nid)}
|
||||
|
||||
async def upsert_edges_batch(
|
||||
self, edges: list[tuple[str, str, dict[str, str]]]
|
||||
) -> None:
|
||||
"""Batch insert/update multiple edges in a single call.
|
||||
|
||||
Args:
|
||||
edges: List of (source_id, target_id, edge_data) tuples.
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
for src, tgt, edge_data in edges:
|
||||
graph.add_edge(src, tgt, **edge_data)
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
if graph.has_node(node_id):
|
||||
graph.remove_node(node_id)
|
||||
logger.debug(f"[{self.workspace}] Node {node_id} deleted from the graph")
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Node {node_id} not found in the graph for deletion"
|
||||
)
|
||||
|
||||
async def remove_nodes(self, nodes: list[str]):
|
||||
"""Delete multiple nodes
|
||||
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
nodes: List of node IDs to be deleted
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
for node in nodes:
|
||||
if graph.has_node(node):
|
||||
graph.remove_node(node)
|
||||
|
||||
async def remove_edges(self, edges: list[tuple[str, str]]):
|
||||
"""Delete multiple edges
|
||||
|
||||
Importance notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
for source, target in edges:
|
||||
if graph.has_edge(source, target):
|
||||
graph.remove_edge(source, target)
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
Get all node labels(entity names) in the graph
|
||||
Returns:
|
||||
[label1, label2, ...] # Alphabetically sorted label list
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
labels = set()
|
||||
for node in graph.nodes():
|
||||
labels.add(str(node)) # Add node id as a label
|
||||
|
||||
# Return sorted list
|
||||
return sorted(list(labels))
|
||||
|
||||
async def get_popular_labels(self, limit: int = 300) -> list[str]:
|
||||
"""
|
||||
Get popular labels(entity names) by node degree (most connected entities)
|
||||
|
||||
Args:
|
||||
limit: Maximum number of labels to return
|
||||
|
||||
Returns:
|
||||
List of labels sorted by degree (highest first)
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
|
||||
# Get degrees of all nodes and sort by degree descending
|
||||
degrees = dict(graph.degree())
|
||||
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Return top labels limited by the specified limit
|
||||
popular_labels = [str(node) for node, _ in sorted_nodes[:limit]]
|
||||
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Retrieved {len(popular_labels)} popular labels (limit: {limit})"
|
||||
)
|
||||
|
||||
return popular_labels
|
||||
|
||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||
"""
|
||||
Search labels(entity names) with fuzzy matching
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching labels sorted by relevance
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
query_lower = query.lower().strip()
|
||||
|
||||
if not query_lower:
|
||||
return []
|
||||
|
||||
# Collect matching nodes with relevance scores
|
||||
matches = []
|
||||
for node in graph.nodes():
|
||||
node_str = str(node)
|
||||
node_lower = node_str.lower()
|
||||
|
||||
# Skip if no match
|
||||
if query_lower not in node_lower:
|
||||
continue
|
||||
|
||||
# Calculate relevance score
|
||||
# Exact match gets highest score
|
||||
if node_lower == query_lower:
|
||||
score = 1000
|
||||
# Prefix match gets high score
|
||||
elif node_lower.startswith(query_lower):
|
||||
score = 500
|
||||
# Contains match gets base score, with bonus for shorter strings
|
||||
else:
|
||||
# Shorter strings with matches are more relevant
|
||||
score = 100 - len(node_str)
|
||||
# Bonus for word boundary matches
|
||||
if f" {query_lower}" in node_lower or f"_{query_lower}" in node_lower:
|
||||
score += 50
|
||||
|
||||
matches.append((node_str, score))
|
||||
|
||||
# Sort by relevance score (desc) then alphabetically
|
||||
matches.sort(key=lambda x: (-x[1], x[0]))
|
||||
|
||||
# Return top matches limited by the specified limit
|
||||
search_results = [match[0] for match in matches[:limit]]
|
||||
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Search query '{query}' returned {len(search_results)} results (limit: {limit})"
|
||||
)
|
||||
|
||||
return search_results
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = None,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
|
||||
Args:
|
||||
node_label: Label of the starting node,* means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
# Get max_nodes from global_config if not provided
|
||||
if max_nodes is None:
|
||||
max_nodes = self.global_config.get("max_graph_nodes", 1000)
|
||||
else:
|
||||
# Limit max_nodes to not exceed global_config max_graph_nodes
|
||||
max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
|
||||
|
||||
graph = await self._get_graph()
|
||||
|
||||
result = KnowledgeGraph()
|
||||
|
||||
# Handle special case for "*" label
|
||||
if node_label == "*":
|
||||
# Get degrees of all nodes
|
||||
degrees = dict(graph.degree())
|
||||
# Sort nodes by degree in descending order and take top max_nodes
|
||||
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Check if graph is truncated
|
||||
if len(sorted_nodes) > max_nodes:
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"[{self.workspace}] Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
|
||||
)
|
||||
|
||||
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
|
||||
# Create subgraph with the highest degree nodes
|
||||
subgraph = graph.subgraph(limited_nodes)
|
||||
else:
|
||||
# Check if node exists
|
||||
if node_label not in graph:
|
||||
logger.warning(
|
||||
f"[{self.workspace}] Node {node_label} not found in the graph"
|
||||
)
|
||||
return KnowledgeGraph() # Return empty graph
|
||||
|
||||
# Use modified BFS to get nodes, prioritizing high-degree nodes at the same depth
|
||||
bfs_nodes = []
|
||||
visited = set()
|
||||
# Store (node, depth, degree) in the queue
|
||||
queue = deque([(node_label, 0, graph.degree(node_label))])
|
||||
|
||||
# Flag to track if there are unexplored neighbors due to depth limit
|
||||
has_unexplored_neighbors = False
|
||||
|
||||
# Modified breadth-first search with degree-based prioritization
|
||||
while queue and len(bfs_nodes) < max_nodes:
|
||||
# Get the current depth from the first node in queue
|
||||
current_depth = queue[0][1]
|
||||
|
||||
# Collect all nodes at the current depth
|
||||
current_level_nodes = []
|
||||
while queue and queue[0][1] == current_depth:
|
||||
current_level_nodes.append(queue.popleft())
|
||||
|
||||
# Sort nodes at current depth by degree (highest first)
|
||||
current_level_nodes.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# Process all nodes at current depth in order of degree
|
||||
for current_node, depth, degree in current_level_nodes:
|
||||
if current_node not in visited:
|
||||
visited.add(current_node)
|
||||
bfs_nodes.append(current_node)
|
||||
|
||||
# Only explore neighbors if we haven't reached max_depth
|
||||
if depth < max_depth:
|
||||
# Add neighbor nodes to queue with incremented depth
|
||||
neighbors = list(graph.neighbors(current_node))
|
||||
# Filter out already visited neighbors
|
||||
unvisited_neighbors = [
|
||||
n for n in neighbors if n not in visited
|
||||
]
|
||||
# Add neighbors to the queue with their degrees
|
||||
for neighbor in unvisited_neighbors:
|
||||
neighbor_degree = graph.degree(neighbor)
|
||||
queue.append((neighbor, depth + 1, neighbor_degree))
|
||||
else:
|
||||
# Check if there are unexplored neighbors (skipped due to depth limit)
|
||||
neighbors = list(graph.neighbors(current_node))
|
||||
unvisited_neighbors = [
|
||||
n for n in neighbors if n not in visited
|
||||
]
|
||||
if unvisited_neighbors:
|
||||
has_unexplored_neighbors = True
|
||||
|
||||
# Check if we've reached max_nodes
|
||||
if len(bfs_nodes) >= max_nodes:
|
||||
break
|
||||
|
||||
# Check if graph is truncated - either due to max_nodes limit or depth limit
|
||||
if (queue and len(bfs_nodes) >= max_nodes) or has_unexplored_neighbors:
|
||||
if len(bfs_nodes) >= max_nodes:
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"[{self.workspace}] Graph truncated: max_nodes limit {max_nodes} reached"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Graph truncated: found {len(bfs_nodes)} nodes within max_depth {max_depth}"
|
||||
)
|
||||
|
||||
# Create subgraph with BFS discovered nodes
|
||||
subgraph = graph.subgraph(bfs_nodes)
|
||||
|
||||
# Add nodes to result
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
for node in subgraph.nodes():
|
||||
if str(node) in seen_nodes:
|
||||
continue
|
||||
|
||||
node_data = dict(subgraph.nodes[node])
|
||||
# Get entity_type as labels
|
||||
labels = []
|
||||
if "entity_type" in node_data:
|
||||
if isinstance(node_data["entity_type"], list):
|
||||
labels.extend(node_data["entity_type"])
|
||||
else:
|
||||
labels.append(node_data["entity_type"])
|
||||
|
||||
# Create node with properties
|
||||
node_properties = {k: v for k, v in node_data.items()}
|
||||
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=str(node), labels=[str(node)], properties=node_properties
|
||||
)
|
||||
)
|
||||
seen_nodes.add(str(node))
|
||||
|
||||
# Add edges to result
|
||||
for edge in subgraph.edges():
|
||||
source, target = edge
|
||||
# Esure unique edge_id for undirect graph
|
||||
if str(source) > str(target):
|
||||
source, target = target, source
|
||||
edge_id = f"{source}-{target}"
|
||||
if edge_id in seen_edges:
|
||||
continue
|
||||
|
||||
edge_data = dict(subgraph.edges[edge])
|
||||
|
||||
# Create edge with complete information
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type="DIRECTED",
|
||||
source=str(source),
|
||||
target=str(target),
|
||||
properties=edge_data,
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
logger.info(
|
||||
f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_all_nodes(self) -> list[dict]:
|
||||
"""Get all nodes in the graph.
|
||||
|
||||
Returns:
|
||||
A list of all nodes, where each node is a dictionary of its properties
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
all_nodes = []
|
||||
for node_id, node_data in graph.nodes(data=True):
|
||||
node_data_with_id = node_data.copy()
|
||||
node_data_with_id["id"] = node_id
|
||||
all_nodes.append(node_data_with_id)
|
||||
return all_nodes
|
||||
|
||||
async def get_all_edges(self) -> list[dict]:
|
||||
"""Get all edges in the graph.
|
||||
|
||||
Returns:
|
||||
A list of all edges, where each edge is a dictionary of its properties
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
all_edges = []
|
||||
for u, v, edge_data in graph.edges(data=True):
|
||||
edge_data_with_nodes = edge_data.copy()
|
||||
edge_data_with_nodes["source"] = u
|
||||
edge_data_with_nodes["target"] = v
|
||||
all_edges.append(edge_data_with_nodes)
|
||||
return all_edges
|
||||
|
||||
async def index_done_callback(self) -> bool:
|
||||
"""Save data to disk"""
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.info(
|
||||
f"[{self.workspace}] Graph was updated by another process, reloading..."
|
||||
)
|
||||
self._graph = (
|
||||
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
try:
|
||||
# Save data to disk
|
||||
NetworkXStorage.write_nx_graph(
|
||||
self._graph, self._graphml_xml_file, self.workspace
|
||||
)
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
self.storage_updated.value = False
|
||||
return True # Return success
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error saving graph: {e}")
|
||||
return False # Return error
|
||||
|
||||
return True
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all graph data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Remove the graph storage file if it exists
|
||||
2. Reset the graph to an empty state
|
||||
3. Update flags to notify other processes
|
||||
4. Changes is persisted to disk immediately
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
# delete _client_file_name
|
||||
if os.path.exists(self._graphml_xml_file):
|
||||
os.remove(self._graphml_xml_file)
|
||||
self._graph = nx.Graph()
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
self.storage_updated.value = False
|
||||
logger.info(
|
||||
f"[{self.workspace}] Process {os.getpid()} drop graph file:{self._graphml_xml_file}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error dropping graph file:{self._graphml_xml_file}: {e}"
|
||||
)
|
||||
return {"status": "error", "message": str(e)}
|
||||
3045
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/opensearch_impl.py
Normal file
3045
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/opensearch_impl.py
Normal file
File diff suppressed because it is too large
Load Diff
6675
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/postgres_impl.py
Normal file
6675
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/postgres_impl.py
Normal file
File diff suppressed because it is too large
Load Diff
1051
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/qdrant_impl.py
Normal file
1051
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/qdrant_impl.py
Normal file
File diff suppressed because it is too large
Load Diff
1137
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/redis_impl.py
Normal file
1137
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/redis_impl.py
Normal file
File diff suppressed because it is too large
Load Diff
1717
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/shared_storage.py
Normal file
1717
.tmp/lightrag_inspect/lightrag_pkg/lightrag/kg/shared_storage.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user