import asyncio import os from typing import Any, final, Optional, Dict from dataclasses import dataclass, fields import numpy as np from lightrag.utils import logger, compute_mdhash_id from ..base import BaseVectorStorage from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH from ..kg.shared_storage import get_data_init_lock import pipmaster as pm if not pm.is_installed("pymilvus"): pm.install("pymilvus>=2.6.2") import configparser from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore from packaging import version config = configparser.ConfigParser() config.read("config.ini", "utf-8") # Supported index types SUPPORTED_INDEX_TYPES = { "AUTOINDEX", "HNSW", "HNSW_SQ", "HNSW_PQ", "HNSW_PRQ", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "DISKANN", "SCANN", } # Supported metric types SUPPORTED_METRIC_TYPES = {"COSINE", "L2", "IP"} # HNSW_SQ quantization types SUPPORTED_SQ_TYPES = {"SQ4U", "SQ6", "SQ8", "BF16", "FP16"} SUPPORTED_REFINE_TYPES = {"SQ6", "SQ8", "BF16", "FP16", "FP32"} # Index type version requirements # Important: HNSW_SQ was first introduced in Milvus 2.6.8 (not 2.5) INDEX_VERSION_REQUIREMENTS = { "HNSW_SQ": "2.6.8", # HNSW_SQ requires Milvus 2.6.8+ (supports sq_types such as SQ4U, SQ6, SQ8, BF16, FP16) } def _get_env_bool(key: str, default: bool = False) -> bool: """Parse environment variable as boolean""" val = os.environ.get(key, "").lower() if val in ("true", "1", "yes", "on"): return True elif val in ("false", "0", "no", "off"): return False return default def _get_env_int(key: str, default: int) -> int: """Parse environment variable as integer""" val = os.environ.get(key, "") if val: try: return int(val) except ValueError: logger.warning( f"Invalid integer value for {key}: {val}, using default {default}" ) return default @dataclass class MilvusIndexConfig: """ Milvus vector index configuration class Supports configuration via environment variables or initialization parameters. Initialization parameters take precedence over environment variables. """ # Base configuration index_type: Optional[str] = None metric_type: Optional[str] = None # HNSW series parameters hnsw_m: Optional[int] = None hnsw_ef_construction: Optional[int] = None hnsw_ef: Optional[int] = None # HNSW_SQ specific parameters sq_type: Optional[str] = None sq_refine: Optional[bool] = None sq_refine_type: Optional[str] = None sq_refine_k: Optional[int] = None # IVF series parameters ivf_nlist: Optional[int] = None ivf_nprobe: Optional[int] = None def __post_init__(self): """Load configuration from environment variables (init parameters take precedence)""" # Index type self.index_type = ( self.index_type or os.environ.get("MILVUS_INDEX_TYPE", "AUTOINDEX") ).upper() # Metric type self.metric_type = ( self.metric_type or os.environ.get("MILVUS_METRIC_TYPE", "COSINE") ).upper() # HNSW parameters # Defaults aligned with Milvus 2.4+ official documentation if self.hnsw_m is None: self.hnsw_m = _get_env_int("MILVUS_HNSW_M", 16) if self.hnsw_ef_construction is None: self.hnsw_ef_construction = _get_env_int("MILVUS_HNSW_EF_CONSTRUCTION", 360) if self.hnsw_ef is None: self.hnsw_ef = _get_env_int("MILVUS_HNSW_EF", 200) # HNSW_SQ parameters if self.sq_type is None: self.sq_type = os.environ.get("MILVUS_HNSW_SQ_TYPE", "SQ8").upper() if self.sq_refine is None: self.sq_refine = _get_env_bool("MILVUS_HNSW_SQ_REFINE", False) if self.sq_refine_type is None: self.sq_refine_type = os.environ.get( "MILVUS_HNSW_SQ_REFINE_TYPE", "FP32" ).upper() if self.sq_refine_k is None: self.sq_refine_k = _get_env_int("MILVUS_HNSW_SQ_REFINE_K", 10) # IVF parameters if self.ivf_nlist is None: self.ivf_nlist = _get_env_int("MILVUS_IVF_NLIST", 1024) if self.ivf_nprobe is None: self.ivf_nprobe = _get_env_int("MILVUS_IVF_NPROBE", 16) # Validate configuration self._validate() def _validate(self): """Validate configuration validity""" if self.index_type not in SUPPORTED_INDEX_TYPES: raise ValueError( f"Unsupported index type: {self.index_type}. " f"Supported: {SUPPORTED_INDEX_TYPES}" ) if self.metric_type not in SUPPORTED_METRIC_TYPES: raise ValueError( f"Unsupported metric type: {self.metric_type}. " f"Supported: {SUPPORTED_METRIC_TYPES}" ) if self.index_type == "HNSW_SQ": if self.sq_type not in SUPPORTED_SQ_TYPES: raise ValueError( f"Unsupported sq_type: {self.sq_type}. " f"Supported: {SUPPORTED_SQ_TYPES}" ) if self.sq_refine and self.sq_refine_type not in SUPPORTED_REFINE_TYPES: raise ValueError( f"Unsupported refine_type: {self.sq_refine_type}. " f"Supported: {SUPPORTED_REFINE_TYPES}" ) # Parameter range validation if not (2 <= self.hnsw_m <= 2048): raise ValueError(f"hnsw_m must be in [2, 2048], got {self.hnsw_m}") if self.hnsw_ef_construction < 1: raise ValueError( f"hnsw_ef_construction must be >= 1, got {self.hnsw_ef_construction}" ) if self.ivf_nlist < 1 or self.ivf_nlist > 65536: raise ValueError(f"ivf_nlist must be in [1, 65536], got {self.ivf_nlist}") def validate_milvus_version(self, server_version: str) -> None: """ Validate Milvus server version supports the configured index type Args: server_version: Milvus server version string (e.g., "2.6.9") Raises: ValueError: Version does not meet index type requirements """ current_ver = version.parse( server_version.split("-")[0] ) # Handle "2.6.9-dev" format # Check HNSW_SQ index type version requirements (requires 2.6.8+) if self.index_type == "HNSW_SQ": required = INDEX_VERSION_REQUIREMENTS["HNSW_SQ"] if current_ver < version.parse(required): raise ValueError( f"HNSW_SQ requires Milvus {required}+, " f"current version: {server_version}" ) logger.info( f"Milvus version {server_version} validated for index type " f"{self.index_type}" + (f" with sq_type {self.sq_type}" if self.index_type == "HNSW_SQ" else "") ) def build_index_params(self, index_params, field_name: str = "vector"): """ Build pymilvus index parameters Args: index_params: IndexParams instance (from compatibility helper or client.prepare_index_params()) field_name: Vector field name Returns: IndexParams object, or a dict fallback when direct API creation is needed. """ if index_params is None: if self.index_type == "AUTOINDEX": logger.info( "Using AUTOINDEX with direct API fallback because IndexParams is unavailable" ) return { "field_name": field_name, "index_type": self.index_type, "metric_type": self.metric_type, "params": {}, } raise RuntimeError( f"IndexParams not available but required for index type " f"'{self.index_type}'. Ensure pymilvus is installed correctly." ) params: Dict[str, Any] = {} # HNSW series indexes if self.index_type in ("HNSW", "HNSW_SQ", "HNSW_PQ", "HNSW_PRQ"): params["M"] = self.hnsw_m params["efConstruction"] = self.hnsw_ef_construction # HNSW_SQ specific parameters if self.index_type == "HNSW_SQ": params["sq_type"] = self.sq_type if self.sq_refine: params["refine"] = True params["refine_type"] = self.sq_refine_type # IVF series indexes elif self.index_type in ("IVF_FLAT", "IVF_SQ8", "IVF_PQ"): params["nlist"] = self.ivf_nlist # DISKANN / SCANN have no additional params index_params.add_index( field_name=field_name, index_type=self.index_type, metric_type=self.metric_type, params=params, ) logger.info( f"Milvus index configured: type={self.index_type}, " f"metric={self.metric_type}, params={params}" ) return index_params def build_search_params(self) -> Dict[str, Any]: """ Build search parameters Returns: Search parameters dictionary """ search_params: Dict[str, Any] = {} if self.index_type in ("HNSW", "HNSW_SQ", "HNSW_PQ", "HNSW_PRQ"): search_params["ef"] = self.hnsw_ef if self.index_type == "HNSW_SQ" and self.sq_refine: search_params["refine_k"] = self.sq_refine_k elif self.index_type in ("IVF_FLAT", "IVF_SQ8", "IVF_PQ"): search_params["nprobe"] = self.ivf_nprobe return {"params": search_params} if search_params else {} @classmethod def get_config_field_names(cls) -> set: """Get all configuration field names from the dataclass. This method provides a single source of truth for configuration parameter names, eliminating the need to maintain duplicate hardcoded lists elsewhere. Returns: Set of field names that can be used to extract configuration from kwargs """ return {f.name for f in fields(cls)} def to_dict(self) -> Dict[str, Any]: """Export configuration as dictionary (for logging/debugging)""" return { "index_type": self.index_type, "metric_type": self.metric_type, "hnsw_m": self.hnsw_m, "hnsw_ef_construction": self.hnsw_ef_construction, "hnsw_ef": self.hnsw_ef, "sq_type": self.sq_type if self.index_type == "HNSW_SQ" else None, "sq_refine": self.sq_refine if self.index_type == "HNSW_SQ" else None, "sq_refine_type": ( self.sq_refine_type if self.index_type == "HNSW_SQ" and self.sq_refine else None ), "sq_refine_k": ( self.sq_refine_k if self.index_type == "HNSW_SQ" and self.sq_refine else None ), "ivf_nlist": ( self.ivf_nlist if self.index_type.startswith("IVF") else None ), "ivf_nprobe": ( self.ivf_nprobe if self.index_type.startswith("IVF") else None ), } @final @dataclass class MilvusVectorDBStorage(BaseVectorStorage): def _get_milvus_connection_kwargs(self, include_db_name: bool = True) -> dict: """Build Milvus connection kwargs from env/config.""" connection_kwargs = { "uri": os.environ.get( "MILVUS_URI", config.get( "milvus", "uri", fallback=os.path.join( self.global_config["working_dir"], "milvus_lite.db" ), ), ), "user": os.environ.get( "MILVUS_USER", config.get("milvus", "user", fallback=None) ), "password": os.environ.get( "MILVUS_PASSWORD", config.get("milvus", "password", fallback=None), ), "token": os.environ.get( "MILVUS_TOKEN", config.get("milvus", "token", fallback=None) ), } db_name = os.environ.get( "MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None), ) if include_db_name and db_name: connection_kwargs["db_name"] = db_name return connection_kwargs def _get_milvus_db_name(self) -> Optional[str]: """Return the configured Milvus database name, if any.""" db_name = self._get_milvus_connection_kwargs(include_db_name=True).get( "db_name" ) if db_name is None: return None normalized_name = str(db_name).strip() return normalized_name or None def _create_milvus_client(self) -> MilvusClient: """Create a Milvus client and ensure the configured database exists.""" client = MilvusClient( **self._get_milvus_connection_kwargs(include_db_name=False) ) db_name = self._get_milvus_db_name() if not db_name: return client existing_databases = set(client.list_databases()) if db_name not in existing_databases: logger.warning( f"[{self.workspace}] Milvus database '{db_name}' not found, creating it" ) client.create_database(db_name) use_database = getattr(client, "use_database", None) or getattr( client, "using_database", None ) if callable(use_database): use_database(db_name) logger.debug( f"[{self.workspace}] Using Milvus database '{db_name}' for namespace '{self.namespace}'" ) return client return MilvusClient(**self._get_milvus_connection_kwargs(include_db_name=True)) def _create_schema_for_namespace(self) -> CollectionSchema: """Create schema based on the current instance's namespace""" # Get vector dimension from embedding_func dimension = self.embedding_func.embedding_dim # Base fields (common to all collections) base_fields = [ FieldSchema( name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True ), FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension), FieldSchema(name="created_at", dtype=DataType.INT64), ] # Determine specific fields based on namespace if self.namespace.endswith("entities"): specific_fields = [ FieldSchema( name="entity_name", dtype=DataType.VARCHAR, max_length=512, nullable=True, ), FieldSchema( name="file_path", dtype=DataType.VARCHAR, max_length=DEFAULT_MAX_FILE_PATH_LENGTH, nullable=True, ), ] description = "LightRAG entities vector storage" elif self.namespace.endswith("relationships"): specific_fields = [ FieldSchema( name="src_id", dtype=DataType.VARCHAR, max_length=512, nullable=True ), FieldSchema( name="tgt_id", dtype=DataType.VARCHAR, max_length=512, nullable=True ), FieldSchema( name="file_path", dtype=DataType.VARCHAR, max_length=DEFAULT_MAX_FILE_PATH_LENGTH, nullable=True, ), ] description = "LightRAG relationships vector storage" elif self.namespace.endswith("chunks"): specific_fields = [ FieldSchema( name="full_doc_id", dtype=DataType.VARCHAR, max_length=64, nullable=True, ), FieldSchema( name="file_path", dtype=DataType.VARCHAR, max_length=DEFAULT_MAX_FILE_PATH_LENGTH, nullable=True, ), ] description = "LightRAG chunks vector storage" else: # Default generic schema (backward compatibility) specific_fields = [ FieldSchema( name="file_path", dtype=DataType.VARCHAR, max_length=DEFAULT_MAX_FILE_PATH_LENGTH, nullable=True, ), ] description = "LightRAG generic vector storage" # Merge all fields all_fields = base_fields + specific_fields return CollectionSchema( fields=all_fields, description=description, enable_dynamic_field=True, # Support dynamic fields ) def _get_index_params(self): """Get IndexParams in a version-compatible way""" try: # Try to use client's prepare_index_params method (most common) if hasattr(self._client, "prepare_index_params"): return self._client.prepare_index_params() except Exception: pass try: # Try to import IndexParams from different possible locations from pymilvus.client.prepare import IndexParams # type: ignore return IndexParams() except ImportError: pass try: from pymilvus.client.types import IndexParams # type: ignore return IndexParams() except ImportError: pass try: from pymilvus import IndexParams # type: ignore return IndexParams() except ImportError: pass # If all else fails, return None to use fallback method return None def _create_scalar_index_fallback(self, field_name: str, index_type: str): """Fallback method to create scalar index using direct API""" # Skip unsupported index types if index_type == "SORTED": logger.info( f"[{self.workspace}] Skipping SORTED index for {field_name} (not supported in this Milvus version)" ) return try: self._client.create_index( collection_name=self.final_namespace, field_name=field_name, index_params={"index_type": index_type}, ) logger.debug( f"[{self.workspace}] Created {field_name} index using fallback method" ) except Exception as e: logger.info( f"[{self.workspace}] Could not create {field_name} index using fallback method: {e}" ) def _create_indexes_after_collection(self): """Create indexes after collection is created""" # Build vector index using index configuration # Use compatibility helper to get IndexParams index_params_for_vector = self._get_index_params() vector_index_params = self.index_config.build_index_params( index_params_for_vector, field_name="vector" ) # Re-raise exceptions to surface vector index creation failures if isinstance(vector_index_params, dict): self._client.create_index( collection_name=self.final_namespace, field_name=vector_index_params["field_name"], index_params={ "index_type": vector_index_params["index_type"], "metric_type": vector_index_params["metric_type"], "params": vector_index_params["params"], }, ) else: self._client.create_index( collection_name=self.final_namespace, index_params=vector_index_params, ) logger.debug( f"[{self.workspace}] Created vector index with config: {self.index_config.to_dict()}" ) # Create scalar indexes based on namespace # Wrap scalar index creation in try-except to allow graceful degradation try: # Try to get IndexParams in a version-compatible way scalar_index_params = self._get_index_params() if scalar_index_params is not None: # Create scalar indexes based on namespace if self.namespace.endswith("entities"): # Create indexes for entity fields try: entity_name_index = self._get_index_params() entity_name_index.add_index( field_name="entity_name", index_type="INVERTED" ) self._client.create_index( collection_name=self.final_namespace, index_params=entity_name_index, ) except Exception as e: logger.debug( f"[{self.workspace}] IndexParams method failed for entity_name: {e}" ) self._create_scalar_index_fallback("entity_name", "INVERTED") elif self.namespace.endswith("relationships"): # Create indexes for relationship fields try: src_id_index = self._get_index_params() src_id_index.add_index( field_name="src_id", index_type="INVERTED" ) self._client.create_index( collection_name=self.final_namespace, index_params=src_id_index, ) except Exception as e: logger.debug( f"[{self.workspace}] IndexParams method failed for src_id: {e}" ) self._create_scalar_index_fallback("src_id", "INVERTED") try: tgt_id_index = self._get_index_params() tgt_id_index.add_index( field_name="tgt_id", index_type="INVERTED" ) self._client.create_index( collection_name=self.final_namespace, index_params=tgt_id_index, ) except Exception as e: logger.debug( f"[{self.workspace}] IndexParams method failed for tgt_id: {e}" ) self._create_scalar_index_fallback("tgt_id", "INVERTED") elif self.namespace.endswith("chunks"): # Create indexes for chunk fields try: doc_id_index = self._get_index_params() doc_id_index.add_index( field_name="full_doc_id", index_type="INVERTED" ) self._client.create_index( collection_name=self.final_namespace, index_params=doc_id_index, ) except Exception as e: logger.debug( f"[{self.workspace}] IndexParams method failed for full_doc_id: {e}" ) self._create_scalar_index_fallback("full_doc_id", "INVERTED") else: # Fallback to direct API calls if IndexParams is not available logger.info( f"[{self.workspace}] IndexParams not available, using fallback methods for {self.namespace}" ) # Create scalar indexes using fallback if self.namespace.endswith("entities"): self._create_scalar_index_fallback("entity_name", "INVERTED") elif self.namespace.endswith("relationships"): self._create_scalar_index_fallback("src_id", "INVERTED") self._create_scalar_index_fallback("tgt_id", "INVERTED") elif self.namespace.endswith("chunks"): self._create_scalar_index_fallback("full_doc_id", "INVERTED") logger.info( f"[{self.workspace}] Created indexes for collection: {self.namespace}" ) except Exception as e: # Scalar index failures are logged as warnings (not critical) logger.warning( f"[{self.workspace}] Failed to create some scalar indexes for {self.namespace}: {e}" ) def _get_required_fields_for_namespace(self) -> dict: """Get required core field definitions for current namespace""" # Base fields (common to all types) base_fields = { "id": {"type": "VarChar", "is_primary": True}, "vector": {"type": "FloatVector"}, "created_at": {"type": "Int64"}, } # Add specific fields based on namespace if self.namespace.endswith("entities"): specific_fields = { "entity_name": {"type": "VarChar"}, "file_path": {"type": "VarChar"}, } elif self.namespace.endswith("relationships"): specific_fields = { "src_id": {"type": "VarChar"}, "tgt_id": {"type": "VarChar"}, "file_path": {"type": "VarChar"}, } elif self.namespace.endswith("chunks"): specific_fields = { "full_doc_id": {"type": "VarChar"}, "file_path": {"type": "VarChar"}, } else: specific_fields = { "file_path": {"type": "VarChar"}, } return {**base_fields, **specific_fields} def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool: """Check compatibility of a single field""" field_name = existing_field.get("name", "unknown") existing_type = existing_field.get("type") expected_type = expected_config.get("type") logger.debug( f"[{self.workspace}] Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}" ) # Convert DataType enum values to string names if needed original_existing_type = existing_type if hasattr(existing_type, "name"): existing_type = existing_type.name logger.debug( f"[{self.workspace}] Converted enum to name: {original_existing_type} -> {existing_type}" ) elif isinstance(existing_type, int): # Map common Milvus internal type codes to type names for backward compatibility type_mapping = { 21: "VarChar", 101: "FloatVector", 5: "Int64", 9: "Double", } mapped_type = type_mapping.get(existing_type, str(existing_type)) logger.debug( f"[{self.workspace}] Mapped numeric type: {existing_type} -> {mapped_type}" ) existing_type = mapped_type # Normalize type names for comparison type_aliases = { "VARCHAR": "VarChar", "String": "VarChar", "FLOAT_VECTOR": "FloatVector", "INT64": "Int64", "BigInt": "Int64", "DOUBLE": "Double", "Float": "Double", } original_existing = existing_type original_expected = expected_type existing_type = type_aliases.get(existing_type, existing_type) expected_type = type_aliases.get(expected_type, expected_type) if original_existing != existing_type or original_expected != expected_type: logger.debug( f"[{self.workspace}] Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}" ) # Basic type compatibility check type_compatible = existing_type == expected_type logger.debug( f"[{self.workspace}] Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}" ) if not type_compatible: logger.warning( f"[{self.workspace}] Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}" ) return False # Primary key check - be more flexible about primary key detection if expected_config.get("is_primary"): # Check multiple possible field names for primary key status is_primary = ( existing_field.get("is_primary_key", False) or existing_field.get("is_primary", False) or existing_field.get("primary_key", False) ) logger.debug( f"[{self.workspace}] Primary key check for '{field_name}': expected=True, actual={is_primary}" ) logger.debug( f"[{self.workspace}] Raw field data for '{field_name}': {existing_field}" ) # For ID field, be more lenient - if it's the ID field, assume it should be primary if field_name == "id" and not is_primary: logger.info( f"[{self.workspace}] ID field '{field_name}' not marked as primary in existing collection, but treating as compatible" ) # Don't fail for ID field primary key mismatch elif not is_primary: logger.warning( f"[{self.workspace}] Primary key mismatch for field '{field_name}': expected primary key, but field is not primary" ) return False logger.debug(f"[{self.workspace}] Field '{field_name}' is compatible") return True def _check_vector_dimension(self, collection_info: dict): """Check vector dimension compatibility""" current_dimension = self.embedding_func.embedding_dim # Find vector field dimension for field in collection_info.get("fields", []): if field.get("name") == "vector": field_type = field.get("type") # Extract type name from DataType enum or string type_name = None if hasattr(field_type, "name"): type_name = field_type.name elif isinstance(field_type, str): type_name = field_type else: type_name = str(field_type) # Check if it's a vector type (supports multiple formats) if type_name in ["FloatVector", "FLOAT_VECTOR"]: existing_dimension = field.get("params", {}).get("dim") # Convert both to int for comparison to handle type mismatches # (Milvus API may return string "1024" vs int 1024) try: existing_dim_int = ( int(existing_dimension) if existing_dimension is not None else None ) current_dim_int = ( int(current_dimension) if current_dimension is not None else None ) except (TypeError, ValueError) as e: logger.error( f"[{self.workspace}] Failed to parse dimensions: existing={existing_dimension} (type={type(existing_dimension)}), " f"current={current_dimension} (type={type(current_dimension)}), error={e}" ) raise ValueError( f"Invalid dimension values for collection '{self.final_namespace}': " f"existing={existing_dimension}, current={current_dimension}" ) from e if existing_dim_int != current_dim_int: raise ValueError( f"Vector dimension mismatch for collection '{self.final_namespace}': " f"existing={existing_dim_int}, current={current_dim_int}" ) logger.debug( f"[{self.workspace}] Vector dimension check passed: {current_dim_int}" ) return # If no vector field found, this might be an old collection created with simple schema logger.warning( f"[{self.workspace}] Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema." ) logger.warning( f"[{self.workspace}] Consider recreating the collection for optimal performance." ) return def _check_file_path_length_restriction(self, collection_info: dict) -> bool: """Check if collection has file_path length restrictions that need migration Returns: bool: True if migration is needed, False otherwise """ existing_fields = { field["name"]: field for field in collection_info.get("fields", []) } # Check if file_path field exists and has length restrictions if "file_path" in existing_fields: file_path_field = existing_fields["file_path"] # Get max_length from field params max_length = file_path_field.get("params", {}).get("max_length") if max_length and max_length < DEFAULT_MAX_FILE_PATH_LENGTH: logger.info( f"[{self.workspace}] Collection {self.namespace} has file_path max_length={max_length}, " f"needs migration to {DEFAULT_MAX_FILE_PATH_LENGTH}" ) return True return False def _check_schema_compatibility(self, collection_info: dict): """Check schema field compatibility and detect migration needs""" existing_fields = { field["name"]: field for field in collection_info.get("fields", []) } # Check if this is an old collection created with simple schema has_vector_field = any( field.get("name") == "vector" for field in collection_info.get("fields", []) ) if not has_vector_field: logger.warning( f"[{self.workspace}] Collection {self.namespace} appears to be created with old simple schema (no vector field)" ) logger.warning( f"[{self.workspace}] This collection will work but may have suboptimal performance" ) logger.warning( f"[{self.workspace}] Consider recreating the collection for optimal performance" ) return # Check if migration is needed for file_path length restrictions if self._check_file_path_length_restriction(collection_info): logger.info( f"[{self.workspace}] Starting automatic migration for collection {self.namespace}" ) self._migrate_collection_schema() return # For collections with vector field, check basic compatibility # Only check for critical incompatibilities, not missing optional fields critical_fields = {"id": {"type": "VarChar", "is_primary": True}} incompatible_fields = [] for field_name, expected_config in critical_fields.items(): if field_name in existing_fields: existing_field = existing_fields[field_name] if not self._is_field_compatible(existing_field, expected_config): incompatible_fields.append( f"{field_name}: expected {expected_config['type']}, " f"got {existing_field.get('type')}" ) if incompatible_fields: raise ValueError( f"Critical schema incompatibility in collection '{self.final_namespace}': {incompatible_fields}" ) # Get all expected fields for informational purposes expected_fields = self._get_required_fields_for_namespace() missing_fields = [ field for field in expected_fields if field not in existing_fields ] if missing_fields: logger.info( f"[{self.workspace}] Collection {self.namespace} missing optional fields: {missing_fields}" ) logger.info( "These fields would be available in a newly created collection for better performance" ) logger.debug( f"[{self.workspace}] Schema compatibility check passed for {self.namespace}" ) def _migrate_collection_schema(self): """Migrate collection schema using query_iterator - completely solves query window limitations""" original_collection_name = self.final_namespace temp_collection_name = f"{self.final_namespace}_temp" iterator = None try: logger.info( f"[{self.workspace}] Starting iterator-based schema migration for {self.namespace}" ) # Step 1: Create temporary collection with new schema logger.info( f"[{self.workspace}] Step 1: Creating temporary collection: {temp_collection_name}" ) # Temporarily update final_namespace for index creation self.final_namespace = temp_collection_name new_schema = self._create_schema_for_namespace() self._client.create_collection( collection_name=temp_collection_name, schema=new_schema ) try: self._create_indexes_after_collection() except Exception as index_error: logger.warning( f"[{self.workspace}] Failed to create indexes for new collection: {index_error}" ) # Continue with migration even if index creation fails # Load the new collection self._client.load_collection(temp_collection_name) # Step 2: Copy data using query_iterator (solves query window limitation) logger.info( f"[{self.workspace}] Step 2: Copying data using query_iterator from: {original_collection_name}" ) # Create query iterator try: iterator = self._client.query_iterator( collection_name=original_collection_name, batch_size=2000, # Adjustable batch size for optimal performance output_fields=["*"], # Get all fields ) logger.debug(f"[{self.workspace}] Query iterator created successfully") except Exception as iterator_error: logger.error( f"[{self.workspace}] Failed to create query iterator: {iterator_error}" ) raise # Iterate through all data total_migrated = 0 batch_number = 1 while True: try: batch_data = iterator.next() if not batch_data: # No more data available break # Insert batch data to new collection try: self._client.insert( collection_name=temp_collection_name, data=batch_data ) total_migrated += len(batch_data) logger.info( f"[{self.workspace}] Iterator batch {batch_number}: " f"processed {len(batch_data)} records, total migrated: {total_migrated}" ) batch_number += 1 except Exception as batch_error: logger.error( f"[{self.workspace}] Failed to insert iterator batch {batch_number}: {batch_error}" ) raise except Exception as next_error: logger.error( f"[{self.workspace}] Iterator next() failed at batch {batch_number}: {next_error}" ) raise if total_migrated > 0: logger.info( f"[{self.workspace}] Successfully migrated {total_migrated} records using iterator" ) else: logger.info( f"[{self.workspace}] No data found in original collection, migration completed" ) # Step 3: Rename origin collection (keep for safety) logger.info( f"[{self.workspace}] Step 3: Rename origin collection to {original_collection_name}_old" ) try: self._client.rename_collection( original_collection_name, f"{original_collection_name}_old" ) except Exception as rename_error: try: logger.warning( f"[{self.workspace}] Try to drop origin collection instead" ) self._client.drop_collection(original_collection_name) except Exception as e: logger.error( f"[{self.workspace}] Rename operation failed: {rename_error}" ) raise e # Step 4: Rename temporary collection to original name logger.info( f"[{self.workspace}] Step 4: Renaming collection {temp_collection_name} -> {original_collection_name}" ) try: self._client.rename_collection( temp_collection_name, original_collection_name ) logger.info(f"[{self.workspace}] Rename operation completed") except Exception as rename_error: logger.error( f"[{self.workspace}] Rename operation failed: {rename_error}" ) raise RuntimeError( f"Failed to rename collection: {rename_error}" ) from rename_error # Restore final_namespace self.final_namespace = original_collection_name except Exception as e: logger.error( f"[{self.workspace}] Iterator-based migration failed for {self.namespace}: {e}" ) # Attempt cleanup of temporary collection if it exists try: if self._client and self._client.has_collection(temp_collection_name): logger.info( f"[{self.workspace}] Cleaning up failed migration temporary collection" ) self._client.drop_collection(temp_collection_name) except Exception as cleanup_error: logger.warning( f"[{self.workspace}] Failed to cleanup temporary collection: {cleanup_error}" ) # Re-raise the original error raise RuntimeError( f"Iterator-based migration failed for collection {self.namespace}: {e}" ) from e finally: # Ensure iterator is properly closed if iterator: try: iterator.close() logger.debug( f"[{self.workspace}] Query iterator closed successfully" ) except Exception as close_error: logger.warning( f"[{self.workspace}] Failed to close query iterator: {close_error}" ) def _validate_collection_compatibility(self): """Validate existing collection's dimension and schema compatibility""" try: collection_info = self._client.describe_collection(self.final_namespace) # 1. Check vector dimension self._check_vector_dimension(collection_info) # 2. Check schema compatibility self._check_schema_compatibility(collection_info) logger.info( f"[{self.workspace}] VectorDB Collection '{self.namespace}' compatibility validation passed" ) except Exception as e: logger.error( f"[{self.workspace}] Collection compatibility validation failed for {self.namespace}: {e}" ) raise @staticmethod def _is_missing_vector_index_error(error: Exception) -> bool: """Return True when the error indicates the collection lacks a vector index.""" error_message = str(error).lower() return ( "no vector index" in error_message or "please create index firstly" in error_message ) def _repair_missing_vector_index(self): """Create indexes for an existing collection that is missing its vector index.""" logger.warning( f"[{self.workspace}] Collection '{self.namespace}' is missing a vector index, attempting repair" ) self._create_indexes_after_collection() def _ensure_collection_loaded(self): """Ensure the collection is loaded into memory for search operations""" try: # Check if collection exists first if not self._client.has_collection(self.final_namespace): logger.error( f"[{self.workspace}] Collection {self.namespace} does not exist" ) raise ValueError(f"Collection {self.final_namespace} does not exist") # Load the collection if it's not already loaded # In Milvus, collections need to be loaded before they can be searched self._client.load_collection(self.final_namespace) # logger.debug(f"[{self.workspace}] Collection {self.namespace} loaded successfully") except Exception as e: logger.error( f"[{self.workspace}] Failed to load collection {self.namespace}: {e}" ) raise def _create_collection_if_not_exist(self): """Create collection if not exists and check existing collection compatibility""" try: # Check if our specific collection exists collection_exists = self._client.has_collection(self.final_namespace) logger.info( f"[{self.workspace}] VectorDB collection '{self.namespace}' exists check: {collection_exists}" ) if collection_exists: # Double-check by trying to describe the collection try: self._client.describe_collection(self.final_namespace) self._validate_collection_compatibility() try: # Ensure the collection is loaded after validation self._ensure_collection_loaded() return except Exception as load_error: if not self._is_missing_vector_index_error(load_error): raise try: self._repair_missing_vector_index() self._ensure_collection_loaded() logger.info( f"[{self.workspace}] Repaired missing vector index for existing collection '{self.namespace}'" ) return except Exception as repair_error: raise RuntimeError( f"Index repair failed for collection '{self.final_namespace}'. " f"Original error: {repair_error}" ) from repair_error except Exception as validation_error: # CRITICAL: Collection exists but validation failed # This indicates potential data migration failure or incompatible schema # Stop execution to prevent data loss and require manual intervention logger.error( f"[{self.workspace}] CRITICAL ERROR: Collection '{self.namespace}' exists but validation failed!" ) logger.error( f"[{self.workspace}] This indicates potential data migration failure or schema incompatibility." ) logger.error( f"[{self.workspace}] Validation error: {validation_error}" ) logger.error(f"[{self.workspace}] MANUAL INTERVENTION REQUIRED:") logger.error( f"[{self.workspace}] 1. Check the existing collection schema and data integrity" ) logger.error( f"[{self.workspace}] 2. Backup existing data if needed" ) logger.error( f"[{self.workspace}] 3. Manually resolve schema compatibility issues" ) logger.error( f"[{self.workspace}] 4. Consider dropping and recreating the collection if data is not critical" ) logger.error( f"[{self.workspace}] Program execution stopped to prevent potential data loss." ) # Raise a specific exception to stop execution raise RuntimeError( f"Collection validation failed for '{self.final_namespace}'. " f"Data migration failure detected. Manual intervention required to prevent data loss. " f"Original error: {validation_error}" ) # Collection doesn't exist, create new collection logger.info(f"[{self.workspace}] Creating new collection: {self.namespace}") schema = self._create_schema_for_namespace() # Create collection with schema only first self._client.create_collection( collection_name=self.final_namespace, schema=schema ) # Then create indexes self._create_indexes_after_collection() # Load the newly created collection self._ensure_collection_loaded() logger.info( f"[{self.workspace}] Successfully created Milvus collection: {self.namespace}" ) except RuntimeError: # Re-raise RuntimeError (validation failures) without modification # These are critical errors that should stop execution raise except Exception as e: logger.error( f"[{self.workspace}] Error in _create_collection_if_not_exist for {self.namespace}: {e}" ) # If there's any error (other than validation failure), try to force create the collection logger.info( f"[{self.workspace}] Attempting to force create collection {self.namespace}..." ) try: # Try to drop the collection first if it exists in a bad state try: if self._client.has_collection(self.final_namespace): logger.info( f"[{self.workspace}] Dropping potentially corrupted collection {self.namespace}" ) self._client.drop_collection(self.final_namespace) except Exception as drop_error: logger.warning( f"[{self.workspace}] Could not drop collection {self.namespace}: {drop_error}" ) # Create fresh collection schema = self._create_schema_for_namespace() self._client.create_collection( collection_name=self.final_namespace, schema=schema ) self._create_indexes_after_collection() # Load the newly created collection self._ensure_collection_loaded() logger.info( f"[{self.workspace}] Successfully force-created collection {self.namespace}" ) except Exception as create_error: logger.error( f"[{self.workspace}] Failed to force-create collection {self.namespace}: {create_error}" ) raise def __post_init__(self): self._validate_embedding_func() # Extract MilvusIndexConfig parameters from vector_db_storage_cls_kwargs # # IMPORTANT: This approach allows Milvus index configuration via vector_db_storage_cls_kwargs, # which is the RECOMMENDED method for framework integration (e.g., RAGAnything). # # All 11 index configuration parameters can be passed through vector_db_storage_cls_kwargs: # - index_type, metric_type # - hnsw_m, hnsw_ef_construction, hnsw_ef # - sq_type, sq_refine, sq_refine_type, sq_refine_k # - ivf_nlist, ivf_nprobe # # Example: # LightRAG( # vector_storage="MilvusVectorDBStorage", # vector_db_storage_cls_kwargs={ # "cosine_better_than_threshold": 0.2, # "index_type": "HNSW", # "metric_type": "COSINE", # "hnsw_m": 32, # "hnsw_ef_construction": 256, # } # ) # # Use MilvusIndexConfig.get_config_field_names() to dynamically extract valid parameters. # This ensures we always stay in sync with the MilvusIndexConfig dataclass definition. kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) index_config_keys = MilvusIndexConfig.get_config_field_names() index_config_params = { k: v for k, v in kwargs.items() if k in index_config_keys } # Initialize index configuration (if not already set) # Configuration priority: init params from kwargs > environment variables > defaults if not hasattr(self, "index_config") or self.index_config is None: self.index_config = MilvusIndexConfig(**index_config_params) # Check for MILVUS_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all Milvus storage instances milvus_workspace = os.environ.get("MILVUS_WORKSPACE") if milvus_workspace and milvus_workspace.strip(): # Use environment variable value, overriding the passed workspace parameter effective_workspace = milvus_workspace.strip() logger.info( f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" ) else: # Use the workspace parameter passed during initialization effective_workspace = self.workspace if effective_workspace: logger.debug( f"Using passed workspace parameter: '{effective_workspace}'" ) # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic if effective_workspace: self.final_namespace = f"{effective_workspace}_{self.namespace}" logger.debug( f"Final namespace with workspace prefix: '{self.final_namespace}'" ) else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") cosine_threshold = kwargs.get("cosine_better_than_threshold") if cosine_threshold is None: raise ValueError( "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" ) self.cosine_better_than_threshold = cosine_threshold # Ensure created_at is in meta_fields if "created_at" not in self.meta_fields: self.meta_fields.add("created_at") # Initialize client as None - will be created in initialize() method self._client = None self._max_batch_size = self.global_config["embedding_batch_num"] self._initialized = False async def initialize(self): """Initialize Milvus collection""" async with get_data_init_lock(): if self._initialized: return try: # Create MilvusClient if not already created if self._client is None: self._client = self._create_milvus_client() logger.debug( f"[{self.workspace}] MilvusClient created successfully" ) # Validate Milvus version compatibility with configured index if self.index_config.index_type in INDEX_VERSION_REQUIREMENTS: try: server_version = self._client.get_server_version() self.index_config.validate_milvus_version(server_version) except Exception as version_error: logger.error( f"[{self.workspace}] Milvus version validation failed: {version_error}" ) raise # Create collection and check compatibility self._create_collection_if_not_exist() self._initialized = True logger.info( f"[{self.workspace}] Milvus collection '{self.namespace}' initialized successfully" ) except Exception as e: logger.error( f"[{self.workspace}] Failed to initialize Milvus collection '{self.namespace}': {e}" ) raise async def upsert(self, data: dict[str, dict[str, Any]]) -> None: # logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") if not data: return # Ensure collection is loaded before upserting self._ensure_collection_loaded() import time current_time = int(time.time()) list_data: list[dict[str, Any]] = [ { "id": k, "created_at": current_time, **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, } for k, v in data.items() ] contents = [v["content"] for v in data.values()] batches = [ contents[i : i + self._max_batch_size] for i in range(0, len(contents), self._max_batch_size) ] embedding_tasks = [ self.embedding_func(batch, context="document") for batch in batches ] embeddings_list = await asyncio.gather(*embedding_tasks) embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["vector"] = embeddings[i] results = self._client.upsert( collection_name=self.final_namespace, data=list_data ) return results async def query( self, query: str, top_k: int, query_embedding: list[float] = None ) -> list[dict[str, Any]]: # Ensure collection is loaded before querying self._ensure_collection_loaded() # Use provided embedding or compute it if query_embedding is not None: embedding = [query_embedding] # Milvus expects a list of embeddings else: embedding = await self.embedding_func( [query], context="query", _priority=5 ) # higher priority for query # Include all meta_fields (created_at is now always included) output_fields = list(self.meta_fields) # Build search params from index config search_params_base = self.index_config.build_search_params() # Merge with metric type and radius threshold search_params = { "metric_type": self.index_config.metric_type, "params": { **search_params_base.get("params", {}), "radius": self.cosine_better_than_threshold, }, } results = self._client.search( collection_name=self.final_namespace, data=embedding, limit=top_k, output_fields=output_fields, search_params=search_params, ) return [ { **dp["entity"], "id": dp["id"], "distance": dp["distance"], "created_at": dp.get("created_at"), } for dp in results[0] ] async def index_done_callback(self) -> None: # Milvus handles persistence automatically pass async def delete_entity(self, entity_name: str) -> None: """Delete an entity from the vector database Args: entity_name: The name of the entity to delete """ try: # Compute entity ID from name entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug( f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}" ) # Delete the entity from Milvus collection result = self._client.delete( collection_name=self.final_namespace, pks=[entity_id] ) if result and result.get("delete_count", 0) > 0: logger.debug( f"[{self.workspace}] Successfully deleted entity {entity_name}" ) else: logger.debug( f"[{self.workspace}] Entity {entity_name} not found in storage" ) except Exception as e: logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity Args: entity_name: The name of the entity whose relations should be deleted """ try: # Ensure collection is loaded before querying self._ensure_collection_loaded() # Search for relations where entity is either source or target expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' # Find all relations involving this entity results = self._client.query( collection_name=self.final_namespace, filter=expr, output_fields=["id"] ) if not results or len(results) == 0: logger.debug( f"[{self.workspace}] No relations found for entity {entity_name}" ) return # Extract IDs of relations to delete relation_ids = [item["id"] for item in results] logger.debug( f"[{self.workspace}] Found {len(relation_ids)} relations for entity {entity_name}" ) # Delete the relations if relation_ids: delete_result = self._client.delete( collection_name=self.final_namespace, pks=relation_ids ) logger.debug( f"[{self.workspace}] Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}" ) except Exception as e: logger.error( f"[{self.workspace}] Error deleting relations for {entity_name}: {e}" ) async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs Args: ids: List of vector IDs to be deleted """ try: # Ensure collection is loaded before deleting self._ensure_collection_loaded() # Delete vectors by IDs result = self._client.delete(collection_name=self.final_namespace, pks=ids) if result and result.get("delete_count", 0) > 0: logger.debug( f"[{self.workspace}] Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}" ) else: logger.debug( f"[{self.workspace}] No vectors were deleted from {self.namespace}" ) except Exception as e: logger.error( f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}" ) async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get vector data by its ID Args: id: The unique identifier of the vector Returns: The vector data if found, or None if not found """ try: # Ensure collection is loaded before querying self._ensure_collection_loaded() # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] # Query Milvus for a specific ID result = self._client.query( collection_name=self.final_namespace, filter=f'id == "{id}"', output_fields=output_fields, ) if not result or len(result) == 0: return None return result[0] except Exception as e: logger.error( f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}" ) return None async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get multiple vector data by their IDs Args: ids: List of unique identifiers Returns: List of vector data objects that were found """ if not ids: return [] try: # Ensure collection is loaded before querying self._ensure_collection_loaded() # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] # Prepare the ID filter expression id_list = '", "'.join(ids) filter_expr = f'id in ["{id_list}"]' # Query Milvus with the filter result = self._client.query( collection_name=self.final_namespace, filter=filter_expr, output_fields=output_fields, ) if not result: return [] result_map: dict[str, dict[str, Any]] = {} for row in result: if not row: continue row_id = row.get("id") if row_id is not None: result_map[str(row_id)] = row ordered_results: list[dict[str, Any] | None] = [] for requested_id in ids: ordered_results.append(result_map.get(str(requested_id))) return ordered_results except Exception as e: logger.error( f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" ) return [] async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]: """Get vectors by their IDs, returning only ID and vector data for efficiency Args: ids: List of unique identifiers Returns: Dictionary mapping IDs to their vector embeddings Format: {id: [vector_values], ...} """ if not ids: return {} try: # Ensure collection is loaded before querying self._ensure_collection_loaded() # Prepare the ID filter expression id_list = '", "'.join(ids) filter_expr = f'id in ["{id_list}"]' # Query Milvus with the filter, requesting only vector field result = self._client.query( collection_name=self.final_namespace, filter=filter_expr, output_fields=["vector"], ) vectors_dict = {} for item in result: if item and "vector" in item and "id" in item: # Convert numpy array to list if needed vector_data = item["vector"] if isinstance(vector_data, np.ndarray): vector_data = vector_data.tolist() vectors_dict[item["id"]] = vector_data return vectors_dict except Exception as e: logger.error( f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}" ) return {} async def drop(self) -> dict[str, str]: """Drop all vector data from storage and clean up resources This method will delete all data from the Milvus collection. Returns: dict[str, str]: Operation status and message - On success: {"status": "success", "message": "data dropped"} - On failure: {"status": "error", "message": ""} """ try: # Drop the collection and recreate it if self._client.has_collection(self.final_namespace): self._client.drop_collection(self.final_namespace) # Recreate the collection self._create_collection_if_not_exist() logger.info( f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}" ) return {"status": "success", "message": "data dropped"} except Exception as e: logger.error( f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}" ) return {"status": "error", "message": str(e)}