feat: 重构知识库系统,移除Hermes集成,增强RAG和同步功能

主要变更:
- 移除Hermes智能体及相关回调服务
- 新增知识库RAG、同步、调度、规范化和索引任务服务
- 重构orchestrator服务,增强运行时聊天功能
- 更新前端聊天、政策制度、设置等页面样式和逻辑
- 更新expense_claims和document_intelligence服务
- 删除llm_wiki相关服务和测试文件
- 更新docker-compose配置和启动脚本
This commit is contained in:
caoxiaozhu
2026-05-17 08:38:41 +00:00
parent 212c935308
commit 68f663f2f4
308 changed files with 83729 additions and 13588 deletions

View File

@@ -0,0 +1,265 @@
from ..utils import verbose_debug, VERBOSE_DEBUG
import sys
import os
import logging
import warnings
from typing import Any, Union, AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
# Install Anthropic SDK if not present
if not pm.is_installed("anthropic"):
pm.install("anthropic")
from anthropic import (
AsyncAnthropic,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
safe_unicode_decode,
logger,
)
from lightrag.api import __api_version__
# Custom exception for retry mechanism
class InvalidResponseError(Exception):
"""Custom exception class for triggering retry mechanism"""
pass
# Core Anthropic completion function with retry
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError, InvalidResponseError)
),
)
async def anthropic_complete_if_cache(
model: str,
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
if enable_cot:
logger.debug(
"enable_cot=True is not supported for the Anthropic API and will be ignored."
)
if not api_key:
api_key = os.environ.get("ANTHROPIC_API_KEY")
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
# Set logger level to INFO when VERBOSE_DEBUG is off
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
logging.getLogger("anthropic").setLevel(logging.INFO)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
timeout = kwargs.pop("timeout", None)
anthropic_async_client = (
AsyncAnthropic(
default_headers=default_headers, api_key=api_key, timeout=timeout
)
if base_url is None
else AsyncAnthropic(
base_url=base_url,
default_headers=default_headers,
api_key=api_key,
timeout=timeout,
)
)
messages: list[dict[str, Any]] = []
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
logger.debug("===== Sending Query to Anthropic LLM =====")
logger.debug(f"Model: {model} Base URL: {base_url}")
logger.debug(f"Additional kwargs: {kwargs}")
verbose_debug(f"Query: {prompt}")
verbose_debug(f"System prompt: {system_prompt}")
try:
create_params = {"model": model, "messages": messages, "stream": True, **kwargs}
if system_prompt:
create_params["system"] = system_prompt
response = await anthropic_async_client.messages.create(**create_params)
except APIConnectionError as e:
logger.error(f"Anthropic API Connection Error: {e}")
raise
except RateLimitError as e:
logger.error(f"Anthropic API Rate Limit Error: {e}")
raise
except APITimeoutError as e:
logger.error(f"Anthropic API Timeout Error: {e}")
raise
except Exception as e:
body = getattr(e, "body", None)
request_id = getattr(e, "request_id", None)
req = getattr(e, "request", None)
extra_parts = []
if body:
extra_parts.append(f"Response body: {body}")
if request_id:
extra_parts.append(f"Request ID: {request_id}")
if req is not None:
extra_parts.append(f"Request URL: {req.url}")
extra = ("\n" + "\n".join(extra_parts)) if extra_parts else ""
logger.error(
f"Anthropic API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}{extra}"
)
raise
async def stream_response():
try:
async for event in response:
content = (
event.delta.text
if hasattr(event, "delta")
and hasattr(event.delta, "text")
and event.delta.text
else None
)
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
except Exception as e:
logger.error(f"Error in stream response: {str(e)}")
raise
return stream_response()
# Generic Anthropic completion function
async def anthropic_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await anthropic_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
# Claude 3 Opus specific completion
async def claude_3_opus_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
return await anthropic_complete_if_cache(
"claude-3-opus-20240229",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
# Claude 3 Sonnet specific completion
async def claude_3_sonnet_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
return await anthropic_complete_if_cache(
"claude-3-sonnet-20240229",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
# Claude 3 Haiku specific completion
async def claude_3_haiku_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
return await anthropic_complete_if_cache(
"claude-3-haiku-20240307",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
# Backward-compatibility shim: the previous embedding implementation lived in
# this module under the (misleading) name ``anthropic_embed`` even though it
# called Voyage AI under the hood. The real implementation now lives in
# ``lightrag.llm.voyageai.voyageai_embed``. Keep the old name importable for one
# release cycle so downstream users get a clear deprecation warning instead of
# an ImportError. Remove in a future major version.
def anthropic_embed(*args, **kwargs):
"""Deprecated alias for :func:`lightrag.llm.voyageai.voyageai_embed`.
This shim accepts the same arguments as the original ``anthropic_embed``
function (which was always backed by VoyageAI) and forwards them to
:func:`voyageai_embed`. It will be removed in a future release.
"""
warnings.warn(
"lightrag.llm.anthropic.anthropic_embed is deprecated and will be "
"removed in a future release. Import "
"lightrag.llm.voyageai.voyageai_embed instead.",
DeprecationWarning,
stacklevel=2,
)
from lightrag.llm.voyageai import voyageai_embed
return voyageai_embed.func(*args, **kwargs)

View File

@@ -0,0 +1,22 @@
"""
Azure OpenAI compatibility layer.
This module provides backward compatibility by re-exporting Azure OpenAI functions
from the main openai module where the actual implementation resides.
All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai,
with this module serving as a thin compatibility wrapper for existing code that
imports from lightrag.llm.azure_openai.
"""
from lightrag.llm.openai import (
azure_openai_complete_if_cache,
azure_openai_complete,
azure_openai_embed,
)
__all__ = [
"azure_openai_complete_if_cache",
"azure_openai_complete",
"azure_openai_embed",
]

View File

@@ -0,0 +1,485 @@
import copy
import os
import json
import logging
import pipmaster as pm # Pipmaster for dynamic library install
if not pm.is_installed("aioboto3"):
pm.install("aioboto3")
import aioboto3
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import sys
from lightrag.utils import wrap_embedding_func_with_attrs
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
from typing import Union
# Import botocore exceptions for proper exception handling
try:
from botocore.exceptions import (
ClientError,
ConnectionError as BotocoreConnectionError,
ReadTimeoutError,
)
except ImportError:
# If botocore is not installed, define placeholders
ClientError = Exception
BotocoreConnectionError = Exception
ReadTimeoutError = Exception
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
class BedrockRateLimitError(BedrockError):
"""Error for rate limiting and throttling issues"""
class BedrockConnectionError(BedrockError):
"""Error for network and connection issues"""
class BedrockTimeoutError(BedrockError):
"""Error for timeout issues"""
def _set_env_if_present(key: str, value):
"""Set environment variable only if a non-empty value is provided."""
if value is not None and value != "":
os.environ[key] = value
def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
"""Convert AWS Bedrock exceptions to appropriate custom exceptions.
Args:
e: The exception to handle
operation: Description of the operation for error messages
Raises:
BedrockRateLimitError: For rate limiting and throttling issues (retryable)
BedrockConnectionError: For network and server issues (retryable)
BedrockTimeoutError: For timeout issues (retryable)
BedrockError: For validation and other non-retryable errors
"""
error_message = str(e)
# Handle botocore ClientError with specific error codes
if isinstance(e, ClientError):
error_code = e.response.get("Error", {}).get("Code", "")
error_msg = e.response.get("Error", {}).get("Message", error_message)
# Rate limiting and throttling errors (retryable)
if error_code in [
"ThrottlingException",
"ProvisionedThroughputExceededException",
]:
logging.error(f"{operation} rate limit error: {error_msg}")
raise BedrockRateLimitError(f"Rate limit error: {error_msg}")
# Server errors (retryable)
elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
logging.error(f"{operation} connection error: {error_msg}")
raise BedrockConnectionError(f"Service error: {error_msg}")
# Check for 5xx HTTP status codes (retryable)
elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
logging.error(f"{operation} server error: {error_msg}")
raise BedrockConnectionError(f"Server error: {error_msg}")
# Validation and other client errors (non-retryable)
else:
logging.error(f"{operation} client error: {error_msg}")
raise BedrockError(f"Client error: {error_msg}")
# Connection errors (retryable)
elif isinstance(e, BotocoreConnectionError):
logging.error(f"{operation} connection error: {error_message}")
raise BedrockConnectionError(f"Connection error: {error_message}")
# Timeout errors (retryable)
elif isinstance(e, (ReadTimeoutError, TimeoutError)):
logging.error(f"{operation} timeout error: {error_message}")
raise BedrockTimeoutError(f"Timeout error: {error_message}")
# Custom Bedrock errors (already properly typed)
elif isinstance(
e,
(
BedrockRateLimitError,
BedrockConnectionError,
BedrockTimeoutError,
BedrockError,
),
):
raise
# Unknown errors (non-retryable)
else:
logging.error(f"{operation} unexpected error: {error_message}")
raise BedrockError(f"Unexpected error: {error_message}")
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(BedrockRateLimitError)
| retry_if_exception_type(BedrockConnectionError)
| retry_if_exception_type(BedrockTimeoutError)
),
)
async def bedrock_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if enable_cot:
import logging
logging.debug(
"enable_cot=True is not supported for Bedrock and will be ignored."
)
# Respect existing env; only set if a non-empty value is available
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
_set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
_set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
_set_env_if_present("AWS_SESSION_TOKEN", session_token)
# Region handling: prefer env, else kwarg (optional)
region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None)
kwargs.pop("hashing_kv", None)
# Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
# We'll use this to determine whether to call converse_stream or converse
stream = bool(kwargs.pop("stream", False))
# Remove unsupported args for Bedrock Converse API
for k in [
"response_format",
"tools",
"tool_choice",
"seed",
"presence_penalty",
"frequency_penalty",
"n",
"logprobs",
"top_logprobs",
"max_completion_tokens",
"response_format",
]:
kwargs.pop(k, None)
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
"max_tokens": "maxTokens",
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
# Import logging for error handling
import logging
# For streaming responses, we need a different approach to keep the connection open
if stream:
# Create a session that will be used throughout the streaming process
session = aioboto3.Session()
client = None
# Define the generator function that will manage the client lifecycle
async def stream_generator():
nonlocal client
# Create the client outside the generator to ensure it stays open
client = await session.client(
"bedrock-runtime", region_name=region
).__aenter__()
event_stream = None
iteration_started = False
try:
# Make the API call
response = await client.converse_stream(**args, **kwargs)
event_stream = response.get("stream")
iteration_started = True
# Process the stream
async for event in event_stream:
# Validate event structure
if not event or not isinstance(event, dict):
continue
if "contentBlockDelta" in event:
delta = event["contentBlockDelta"].get("delta", {})
text = delta.get("text")
if text:
yield text
# Handle other event types that might indicate stream end
elif "messageStop" in event:
break
except Exception as e:
# Try to clean up resources if possible
if (
iteration_started
and event_stream
and hasattr(event_stream, "aclose")
and callable(getattr(event_stream, "aclose", None))
):
try:
await event_stream.aclose()
except Exception as close_error:
logging.warning(
f"Failed to close Bedrock event stream: {close_error}"
)
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock streaming")
finally:
# Clean up the event stream
if (
iteration_started
and event_stream
and hasattr(event_stream, "aclose")
and callable(getattr(event_stream, "aclose", None))
):
try:
await event_stream.aclose()
except Exception as close_error:
logging.warning(
f"Failed to close Bedrock event stream in finally block: {close_error}"
)
# Clean up the client
if client:
try:
await client.__aexit__(None, None, None)
except Exception as client_close_error:
logging.warning(
f"Failed to close Bedrock client: {client_close_error}"
)
# Return the generator that manages its own lifecycle
return stream_generator()
# For non-streaming responses, use the standard async context manager pattern
session = aioboto3.Session()
async with session.client(
"bedrock-runtime", region_name=region
) as bedrock_async_client:
try:
# Use converse for non-streaming responses
response = await bedrock_async_client.converse(**args, **kwargs)
# Validate response structure
if (
not response
or "output" not in response
or "message" not in response["output"]
or "content" not in response["output"]["message"]
or not response["output"]["message"]["content"]
):
raise BedrockError("Invalid response structure from Bedrock API")
content = response["output"]["message"]["content"][0]["text"]
if not content or content.strip() == "":
raise BedrockError("Received empty content from Bedrock API")
return content
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock converse")
# Generic Bedrock completion function
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
result = await bedrock_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
return result
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=8192, model_name="amazon.titan-embed-text-v2:0"
)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(BedrockRateLimitError)
| retry_if_exception_type(BedrockConnectionError)
| retry_if_exception_type(BedrockTimeoutError)
),
)
async def bedrock_embed(
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
# Respect existing env; only set if a non-empty value is available
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
_set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
_set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
_set_env_if_present("AWS_SESSION_TOKEN", session_token)
# Region handling: prefer env
region = os.environ.get("AWS_REGION")
session = aioboto3.Session()
async with session.client(
"bedrock-runtime", region_name=region
) as bedrock_async_client:
try:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
try:
if "v2" in model:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise BedrockError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = await response.get("body").json()
# Validate response structure
if not response_body or "embedding" not in response_body:
raise BedrockError(
f"Invalid embedding response structure for text: {text[:50]}..."
)
embedding = response_body["embedding"]
if not embedding:
raise BedrockError(
f"Received empty embedding for text: {text[:50]}..."
)
embed_texts.append(embedding)
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(
e, "Bedrock embedding (amazon, text chunk)"
)
elif model_provider == "cohere":
try:
body = json.dumps(
{
"texts": texts,
"input_type": "search_document",
"truncate": "NONE",
}
)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
# Validate response structure
if not response_body or "embeddings" not in response_body:
raise BedrockError(
"Invalid embedding response structure from Cohere"
)
embeddings = response_body["embeddings"]
if not embeddings or len(embeddings) != len(texts):
raise BedrockError(
f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
)
embed_texts = embeddings
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock embedding (cohere)")
else:
raise BedrockError(
f"Model provider '{model_provider}' is not supported!"
)
# Final validation
if not embed_texts:
raise BedrockError("No embeddings generated")
return np.array(embed_texts)
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock embedding")

View File

@@ -0,0 +1,740 @@
"""
Module that implements containers for specific LLM bindings.
This module provides container implementations for various Large Language Model
bindings and integrations.
"""
from argparse import ArgumentParser, Namespace
import argparse
import json
from dataclasses import asdict, dataclass, field
from typing import Any, ClassVar, List, get_args, get_origin
from lightrag.utils import get_env_value
from lightrag.constants import DEFAULT_TEMPERATURE
def _resolve_optional_type(field_type: Any) -> Any:
"""Return the concrete type for Optional/Union annotations."""
origin = get_origin(field_type)
if origin in (list, dict, tuple):
return field_type
args = get_args(field_type)
if args:
non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1:
return non_none_args[0]
return field_type
# =============================================================================
# BindingOptions Base Class
# =============================================================================
#
# The BindingOptions class serves as the foundation for all LLM provider bindings
# in LightRAG. It provides a standardized framework for:
#
# 1. Configuration Management:
# - Defines how each LLM provider's configuration parameters are structured
# - Handles default values and type information for each parameter
# - Maps configuration options to command-line arguments and environment variables
#
# 2. Environment Integration:
# - Automatically generates environment variable names from binding parameters
# - Provides methods to create sample .env files for easy configuration
# - Supports configuration via environment variables with fallback to defaults
#
# 3. Command-Line Interface:
# - Dynamically generates command-line arguments for all registered bindings
# - Maintains consistent naming conventions across different LLM providers
# - Provides help text and type validation for each configuration option
#
# 4. Extensibility:
# - Uses class introspection to automatically discover all binding subclasses
# - Requires minimal boilerplate code when adding new LLM provider bindings
# - Maintains separation of concerns between different provider configurations
#
# This design pattern ensures that adding support for a new LLM provider requires
# only defining the provider-specific parameters and help text, while the base
# class handles all the common functionality for argument parsing, environment
# variable handling, and configuration management.
#
# Instances of a derived class of BindingOptions can be used to store multiple
# runtime configurations of options for a single LLM provider. using the
# asdict() method to convert the options to a dictionary.
#
# =============================================================================
@dataclass
class BindingOptions:
"""Base class for binding options."""
# mandatory name of binding
_binding_name: ClassVar[str]
# optional help message for each option
_help: ClassVar[dict[str, str]]
@staticmethod
def _all_class_vars(klass: type, include_inherited=True) -> dict[str, Any]:
"""Print class variables, optionally including inherited ones"""
if include_inherited:
# Get all class variables from MRO
vars_dict = {}
for base in reversed(klass.__mro__[:-1]): # Exclude 'object'
vars_dict.update(
{
k: v
for k, v in base.__dict__.items()
if (
not k.startswith("_")
and not callable(v)
and not isinstance(v, classmethod)
)
}
)
else:
# Only direct class variables
vars_dict = {
k: v
for k, v in klass.__dict__.items()
if (
not k.startswith("_")
and not callable(v)
and not isinstance(v, classmethod)
)
}
return vars_dict
@classmethod
def add_args(cls, parser: ArgumentParser):
group = parser.add_argument_group(f"{cls._binding_name} binding options")
for arg_item in cls.args_env_name_type_value():
# Handle JSON parsing for list types
if arg_item["type"] is List[str]:
def json_list_parser(value):
try:
parsed = json.loads(value)
if not isinstance(parsed, list):
raise argparse.ArgumentTypeError(
f"Expected JSON array, got {type(parsed).__name__}"
)
return parsed
except json.JSONDecodeError as e:
raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
# Get environment variable with JSON parsing
env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
if env_value is not argparse.SUPPRESS:
try:
env_value = json_list_parser(env_value)
except argparse.ArgumentTypeError:
env_value = argparse.SUPPRESS
group.add_argument(
f"--{arg_item['argname']}",
type=json_list_parser,
default=env_value,
help=arg_item["help"],
)
# Handle JSON parsing for dict types
elif arg_item["type"] is dict:
def json_dict_parser(value):
try:
parsed = json.loads(value)
if not isinstance(parsed, dict):
raise argparse.ArgumentTypeError(
f"Expected JSON object, got {type(parsed).__name__}"
)
return parsed
except json.JSONDecodeError as e:
raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
# Get environment variable with JSON parsing
env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
if env_value is not argparse.SUPPRESS:
try:
env_value = json_dict_parser(env_value)
except argparse.ArgumentTypeError:
env_value = argparse.SUPPRESS
group.add_argument(
f"--{arg_item['argname']}",
type=json_dict_parser,
default=env_value,
help=arg_item["help"],
)
# Handle boolean types specially to avoid argparse bool() constructor issues
elif arg_item["type"] is bool:
def bool_parser(value):
"""Custom boolean parser that handles string representations correctly"""
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ("true", "1", "yes", "t", "on")
return bool(value)
# Get environment variable with proper type conversion
env_value = get_env_value(
f"{arg_item['env_name']}", argparse.SUPPRESS, bool
)
group.add_argument(
f"--{arg_item['argname']}",
type=bool_parser,
default=env_value,
help=arg_item["help"],
)
else:
resolved_type = arg_item["type"]
if resolved_type is not None:
resolved_type = _resolve_optional_type(resolved_type)
group.add_argument(
f"--{arg_item['argname']}",
type=resolved_type,
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
help=arg_item["help"],
)
@classmethod
def args_env_name_type_value(cls):
import dataclasses
args_prefix = f"{cls._binding_name}".replace("_", "-")
env_var_prefix = f"{cls._binding_name}_".upper()
help = cls._help
# Check if this is a dataclass and use dataclass fields
if dataclasses.is_dataclass(cls):
for field in dataclasses.fields(cls):
# Skip private fields
if field.name.startswith("_"):
continue
# Get default value
if field.default is not dataclasses.MISSING:
default_value = field.default
elif field.default_factory is not dataclasses.MISSING:
default_value = field.default_factory()
else:
default_value = None
argdef = {
"argname": f"{args_prefix}-{field.name}",
"env_name": f"{env_var_prefix}{field.name.upper()}",
"type": _resolve_optional_type(field.type),
"default": default_value,
"help": f"{cls._binding_name} -- " + help.get(field.name, ""),
}
yield argdef
else:
# Fallback to old method for non-dataclass classes
class_vars = {
key: value
for key, value in cls._all_class_vars(cls).items()
if not callable(value) and not key.startswith("_")
}
# Get type hints to properly detect List[str] types
type_hints = {}
for base in cls.__mro__:
if hasattr(base, "__annotations__"):
type_hints.update(base.__annotations__)
for class_var in class_vars:
# Use type hint if available, otherwise fall back to type of value
var_type = type_hints.get(class_var, type(class_vars[class_var]))
argdef = {
"argname": f"{args_prefix}-{class_var}",
"env_name": f"{env_var_prefix}{class_var.upper()}",
"type": var_type,
"default": class_vars[class_var],
"help": f"{cls._binding_name} -- " + help.get(class_var, ""),
}
yield argdef
@classmethod
def generate_dot_env_sample(cls):
"""
Generate a sample .env file for all LightRAG binding options.
This method creates a .env file that includes all the binding options
defined by the subclasses of BindingOptions. It uses the args_env_name_type_value()
method to get the list of all options and their default values.
Returns:
str: A string containing the contents of the sample .env file.
"""
from io import StringIO
sample_top = (
"#" * 80
+ "\n"
+ (
"# Autogenerated .env entries list for LightRAG binding options\n"
"#\n"
"# To generate run:\n"
"# $ python -m lightrag.llm.binding_options\n"
)
+ "#" * 80
+ "\n"
)
sample_bottom = (
("#\n# End of .env entries for LightRAG binding options\n")
+ "#" * 80
+ "\n"
)
sample_stream = StringIO()
sample_stream.write(sample_top)
for klass in cls.__subclasses__():
for arg_item in klass.args_env_name_type_value():
if arg_item["help"]:
sample_stream.write(f"# {arg_item['help']}\n")
# Handle JSON formatting for list and dict types
if arg_item["type"] is List[str] or arg_item["type"] is dict:
default_value = json.dumps(arg_item["default"])
else:
default_value = arg_item["default"]
sample_stream.write(f"# {arg_item['env_name']}={default_value}\n\n")
sample_stream.write(sample_bottom)
return sample_stream.getvalue()
@classmethod
def options_dict(cls, args: Namespace) -> dict[str, Any]:
"""
Extract options dictionary for a specific binding from parsed arguments.
This method filters the parsed command-line arguments to return only those
that belong to the specific binding class. It removes the binding prefix
from argument names to create a clean options dictionary.
Args:
args (Namespace): Parsed command-line arguments containing all binding options
Returns:
dict[str, Any]: Dictionary mapping option names (without prefix) to their values
Example:
If args contains {'ollama_num_ctx': 512, 'other_option': 'value'}
and this is called on OllamaOptions, it returns {'num_ctx': 512}
"""
prefix = cls._binding_name + "_"
skipchars = len(prefix)
options = {
key[skipchars:]: value
for key, value in vars(args).items()
if key.startswith(prefix)
}
return options
def asdict(self) -> dict[str, Any]:
"""
Convert an instance of binding options to a dictionary.
This method uses dataclasses.asdict() to convert the dataclass instance
into a dictionary representation, including all its fields and values.
Returns:
dict[str, Any]: Dictionary representation of the binding options instance
"""
return asdict(self)
# =============================================================================
# Binding Options for Ollama
# =============================================================================
#
# Ollama binding options provide configuration for the Ollama local LLM server.
# These options control model behavior, sampling parameters, hardware utilization,
# and performance settings. The parameters are based on Ollama's API specification
# and provide fine-grained control over model inference and generation.
#
# The _OllamaOptionsMixin defines the complete set of available options, while
# OllamaEmbeddingOptions and OllamaLLMOptions provide specialized configurations
# for embedding and language model tasks respectively.
# =============================================================================
@dataclass
class _OllamaOptionsMixin:
"""Options for Ollama bindings."""
# Core context and generation parameters
num_ctx: int = 32768 # Context window size (number of tokens)
num_predict: int = 128 # Maximum number of tokens to predict
num_keep: int = 0 # Number of tokens to keep from the initial prompt
seed: int = -1 # Random seed for generation (-1 for random)
# Sampling parameters
temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0-2.0)
top_k: int = 40 # Top-k sampling parameter
top_p: float = 0.9 # Top-p (nucleus) sampling parameter
tfs_z: float = 1.0 # Tail free sampling parameter
typical_p: float = 1.0 # Typical probability mass
min_p: float = 0.0 # Minimum probability threshold
# Repetition control
repeat_last_n: int = 64 # Number of tokens to consider for repetition penalty
repeat_penalty: float = 1.1 # Penalty for repetition
presence_penalty: float = 0.0 # Penalty for token presence
frequency_penalty: float = 0.0 # Penalty for token frequency
# Mirostat sampling
mirostat: int = (
# Mirostat sampling algorithm (0=disabled, 1=Mirostat 1.0, 2=Mirostat 2.0)
0
)
mirostat_tau: float = 5.0 # Mirostat target entropy
mirostat_eta: float = 0.1 # Mirostat learning rate
# Hardware and performance parameters
numa: bool = False # Enable NUMA optimization
num_batch: int = 512 # Batch size for processing
num_gpu: int = -1 # Number of GPUs to use (-1 for auto)
main_gpu: int = 0 # Main GPU index
low_vram: bool = False # Optimize for low VRAM
num_thread: int = 0 # Number of CPU threads (0 for auto)
# Memory and model parameters
f16_kv: bool = True # Use half-precision for key/value cache
logits_all: bool = False # Return logits for all tokens
vocab_only: bool = False # Only load vocabulary
use_mmap: bool = True # Use memory mapping for model files
use_mlock: bool = False # Lock model in memory
embedding_only: bool = False # Only use for embeddings
# Output control
penalize_newline: bool = True # Penalize newline tokens
stop: List[str] = field(default_factory=list) # Stop sequences
# optional help strings
_help: ClassVar[dict[str, str]] = {
"num_ctx": "Context window size (number of tokens)",
"num_predict": "Maximum number of tokens to predict",
"num_keep": "Number of tokens to keep from the initial prompt",
"seed": "Random seed for generation (-1 for random)",
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
"top_k": "Top-k sampling parameter (0 = disabled)",
"top_p": "Top-p (nucleus) sampling parameter (0.0-1.0)",
"tfs_z": "Tail free sampling parameter (1.0 = disabled)",
"typical_p": "Typical probability mass (1.0 = disabled)",
"min_p": "Minimum probability threshold (0.0 = disabled)",
"repeat_last_n": "Number of tokens to consider for repetition penalty",
"repeat_penalty": "Penalty for repetition (1.0 = no penalty)",
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
"mirostat": "Mirostat sampling algorithm (0=disabled, 1=Mirostat 1.0, 2=Mirostat 2.0)",
"mirostat_tau": "Mirostat target entropy",
"mirostat_eta": "Mirostat learning rate",
"numa": "Enable NUMA optimization",
"num_batch": "Batch size for processing",
"num_gpu": "Number of GPUs to use (-1 for auto)",
"main_gpu": "Main GPU index",
"low_vram": "Optimize for low VRAM",
"num_thread": "Number of CPU threads (0 for auto)",
"f16_kv": "Use half-precision for key/value cache",
"logits_all": "Return logits for all tokens",
"vocab_only": "Only load vocabulary",
"use_mmap": "Use memory mapping for model files",
"use_mlock": "Lock model in memory",
"embedding_only": "Only use for embeddings",
"penalize_newline": "Penalize newline tokens",
"stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
}
@dataclass
class OllamaEmbeddingOptions(_OllamaOptionsMixin, BindingOptions):
"""Options for Ollama embeddings with specialized configuration for embedding tasks."""
# mandatory name of binding
_binding_name: ClassVar[str] = "ollama_embedding"
@dataclass
class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
"""Options for Ollama LLM with specialized configuration for LLM tasks."""
# mandatory name of binding
_binding_name: ClassVar[str] = "ollama_llm"
# =============================================================================
# Binding Options for Gemini
# =============================================================================
@dataclass
class GeminiLLMOptions(BindingOptions):
"""Options for Google Gemini models."""
_binding_name: ClassVar[str] = "gemini_llm"
temperature: float = DEFAULT_TEMPERATURE
top_p: float = 0.95
top_k: int = 40
max_output_tokens: int | None = None
candidate_count: int = 1
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
stop_sequences: List[str] = field(default_factory=list)
seed: int | None = None
thinking_config: dict | None = None
safety_settings: dict | None = None
_help: ClassVar[dict[str, str]] = {
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
"top_p": "Nucleus sampling parameter (0.0-1.0)",
"top_k": "Limits sampling to the top K tokens (1 disables the limit)",
"max_output_tokens": "Maximum tokens generated in the response",
"candidate_count": "Number of candidates returned per request",
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
"stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
"seed": "Random seed for reproducible generation (leave empty for random)",
"thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
"safety_settings": "JSON object with Gemini safety settings overrides",
}
@dataclass
class GeminiEmbeddingOptions(BindingOptions):
"""Options for Google Gemini embedding models."""
_binding_name: ClassVar[str] = "gemini_embedding"
task_type: str | None = None
_help: ClassVar[dict[str, str]] = {
"task_type": "Task type for embedding optimization. If not specified, automatically determined from context (RETRIEVAL_QUERY for queries, RETRIEVAL_DOCUMENT for documents). Supported types: RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, CODE_RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION",
}
# =============================================================================
# Binding Options for OpenAI
# =============================================================================
#
# OpenAI binding options provide configuration for OpenAI's API and Azure OpenAI.
# These options control model behavior, sampling parameters, and generation settings.
# The parameters are based on OpenAI's API specification and provide fine-grained
# control over model inference and generation.
#
# =============================================================================
@dataclass
class OpenAILLMOptions(BindingOptions):
"""Options for OpenAI LLM with configuration for OpenAI and Azure OpenAI API calls."""
# mandatory name of binding
_binding_name: ClassVar[str] = "openai_llm"
# Sampling and generation parameters
frequency_penalty: float = 0.0 # Penalty for token frequency (-2.0 to 2.0)
max_completion_tokens: int = None # Maximum number of tokens to generate
presence_penalty: float = 0.0 # Penalty for token presence (-2.0 to 2.0)
reasoning_effort: str = "medium" # Reasoning effort level (low, medium, high)
safety_identifier: str = "" # Safety identifier for content filtering
service_tier: str = "" # Service tier for API usage
stop: List[str] = field(default_factory=list) # Stop sequences
temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0 to 2.0)
top_p: float = 1.0 # Nucleus sampling parameter (0.0 to 1.0)
max_tokens: int = None # Maximum number of tokens to generate(deprecated, use max_completion_tokens instead)
extra_body: dict = None # Extra body parameters for OpenRouter of vLLM
# Help descriptions
_help: ClassVar[dict[str, str]] = {
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0, positive values discourage repetition)",
"max_completion_tokens": "Maximum number of tokens to generate (optional, leave empty for model default)",
"presence_penalty": "Penalty for token presence (-2.0 to 2.0, positive values encourage new topics)",
"reasoning_effort": "Reasoning effort level for o1 models (low, medium, high)",
"safety_identifier": "Safety identifier for content filtering (optional)",
"service_tier": "Service tier for API usage (optional)",
"stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
"top_p": "Nucleus sampling parameter (0.0-1.0, lower = more focused)",
"max_tokens": "Maximum number of tokens to generate (deprecated, use max_completion_tokens instead)",
"extra_body": 'Extra body parameters for OpenRouter of vLLM (JSON dict, e.g., \'"reasoning": {"reasoning": {"enabled": false}}\')',
}
# =============================================================================
# Main Section - For Testing and Sample Generation
# =============================================================================
#
# When run as a script, this module:
# 1. Generates and prints a sample .env file with all binding options
# 2. If "test" argument is provided, demonstrates argument parsing with Ollama binding
#
# Usage:
# python -m lightrag.llm.binding_options # Generate .env sample
# python -m lightrag.llm.binding_options test # Test argument parsing
#
# =============================================================================
if __name__ == "__main__":
import sys
import dotenv
# from io import StringIO
dotenv.load_dotenv(dotenv_path=".env", override=False)
# env_strstream = StringIO(
# ("OLLAMA_LLM_TEMPERATURE=0.1\nOLLAMA_EMBEDDING_TEMPERATURE=0.2\n")
# )
# # Load environment variables from .env file
# dotenv.load_dotenv(stream=env_strstream)
if len(sys.argv) > 1 and sys.argv[1] == "test":
# Add arguments for OllamaEmbeddingOptions, OllamaLLMOptions, and OpenAILLMOptions
parser = ArgumentParser(description="Test binding options")
OllamaEmbeddingOptions.add_args(parser)
OllamaLLMOptions.add_args(parser)
OpenAILLMOptions.add_args(parser)
# Parse arguments test
args = parser.parse_args(
[
"--ollama-embedding-num_ctx",
"1024",
"--ollama-llm-num_ctx",
"2048",
"--openai-llm-temperature",
"0.7",
"--openai-llm-max_completion_tokens",
"1000",
"--openai-llm-stop",
'["</s>", "\\n\\n"]',
"--openai-llm-reasoning",
'{"effort": "high", "max_tokens": 2000, "exclude": false, "enabled": true}',
]
)
print("Final args for LLM and Embedding:")
print(f"{args}\n")
print("Ollama LLM options:")
print(OllamaLLMOptions.options_dict(args))
print("\nOllama Embedding options:")
print(OllamaEmbeddingOptions.options_dict(args))
print("\nOpenAI LLM options:")
print(OpenAILLMOptions.options_dict(args))
# Test creating OpenAI options instance
openai_options = OpenAILLMOptions(
temperature=0.8,
max_completion_tokens=1500,
frequency_penalty=0.1,
presence_penalty=0.2,
stop=["<|end|>", "\n\n"],
)
print("\nOpenAI LLM options instance:")
print(openai_options.asdict())
# Test creating OpenAI options instance with reasoning parameter
openai_options_with_reasoning = OpenAILLMOptions(
temperature=0.9,
max_completion_tokens=2000,
reasoning={
"effort": "medium",
"max_tokens": 1500,
"exclude": True,
"enabled": True,
},
)
print("\nOpenAI LLM options instance with reasoning:")
print(openai_options_with_reasoning.asdict())
# Test dict parsing functionality
print("\n" + "=" * 50)
print("TESTING DICT PARSING FUNCTIONALITY")
print("=" * 50)
# Test valid JSON dict parsing
test_parser = ArgumentParser(description="Test dict parsing")
OpenAILLMOptions.add_args(test_parser)
try:
test_args = test_parser.parse_args(
["--openai-llm-reasoning", '{"effort": "low", "max_tokens": 1000}']
)
print("✓ Valid JSON dict parsing successful:")
print(
f" Parsed reasoning: {OpenAILLMOptions.options_dict(test_args)['reasoning']}"
)
except Exception as e:
print(f"✗ Valid JSON dict parsing failed: {e}")
# Test invalid JSON dict parsing
try:
test_args = test_parser.parse_args(
[
"--openai-llm-reasoning",
'{"effort": "low", "max_tokens": 1000', # Missing closing brace
]
)
print("✗ Invalid JSON should have failed but didn't")
except SystemExit:
print("✓ Invalid JSON dict parsing correctly rejected")
except Exception as e:
print(f"✓ Invalid JSON dict parsing correctly rejected: {e}")
# Test non-dict JSON parsing
try:
test_args = test_parser.parse_args(
[
"--openai-llm-reasoning",
'["not", "a", "dict"]', # Array instead of dict
]
)
print("✗ Non-dict JSON should have failed but didn't")
except SystemExit:
print("✓ Non-dict JSON parsing correctly rejected")
except Exception as e:
print(f"✓ Non-dict JSON parsing correctly rejected: {e}")
print("\n" + "=" * 50)
print("TESTING ENVIRONMENT VARIABLE SUPPORT")
print("=" * 50)
# Test environment variable support for dict
import os
os.environ["OPENAI_LLM_REASONING"] = (
'{"effort": "high", "max_tokens": 3000, "exclude": false}'
)
env_parser = ArgumentParser(description="Test env var dict parsing")
OpenAILLMOptions.add_args(env_parser)
try:
env_args = env_parser.parse_args(
[]
) # No command line args, should use env var
reasoning_from_env = OpenAILLMOptions.options_dict(env_args).get(
"reasoning"
)
if reasoning_from_env:
print("✓ Environment variable dict parsing successful:")
print(f" Parsed reasoning from env: {reasoning_from_env}")
else:
print("✗ Environment variable dict parsing failed: No reasoning found")
except Exception as e:
print(f"✗ Environment variable dict parsing failed: {e}")
finally:
# Clean up environment variable
if "OPENAI_LLM_REASONING" in os.environ:
del os.environ["OPENAI_LLM_REASONING"]
else:
print(BindingOptions.generate_dot_env_sample())

View File

@@ -0,0 +1,69 @@
import sys
if sys.version_info < (3, 9):
pass
else:
pass
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("lmdeploy"):
pm.install("lmdeploy")
from openai import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import numpy as np
import aiohttp
import base64
import struct
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def siliconcloud_embedding(
texts: list[str],
model: str = "netease-youdao/bce-embedding-base_v1",
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
max_token_size: int = 8192,
api_key: str = None,
) -> np.ndarray:
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
if "code" in content:
raise ValueError(content)
base64_strings = [item["embedding"] for item in content["data"]]
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)

View File

@@ -0,0 +1,623 @@
"""
Gemini LLM binding for LightRAG.
This module provides asynchronous helpers that adapt Google's Gemini models
to the same interface used by the rest of the LightRAG LLM bindings. The
implementation mirrors the OpenAI helpers while relying on the official
``google-genai`` client under the hood.
"""
from __future__ import annotations
import os
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
logger,
remove_think_tags,
safe_unicode_decode,
wrap_embedding_func_with_attrs,
)
import pipmaster as pm
# Install the Google Gemini client and its dependencies on demand
if not pm.is_installed("google-genai"):
pm.install("google-genai")
if not pm.is_installed("google-api-core"):
pm.install("google-api-core")
from google import genai # type: ignore
from google.genai import types # type: ignore
from google.api_core import exceptions as google_api_exceptions # type: ignore
class InvalidResponseError(Exception):
"""Custom exception class for triggering retry mechanism when Gemini returns empty responses"""
pass
@lru_cache(maxsize=8)
def _get_gemini_client(
api_key: str, base_url: str | None, timeout: int | None = None
) -> genai.Client:
"""
Create (or fetch cached) Gemini client.
Args:
api_key: Google Gemini API key (not used in Vertex AI mode).
base_url: Optional custom API endpoint.
timeout: Optional request timeout in milliseconds.
Returns:
genai.Client: Configured Gemini client instance.
"""
client_kwargs: dict[str, Any] = {}
# Add Vertex AI support
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
if use_vertexai:
# Vertex AI mode: use project/location, NOT api_key
client_kwargs["vertexai"] = True
project = os.getenv("GOOGLE_CLOUD_PROJECT")
if project:
location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
client_kwargs["project"] = project
if location:
client_kwargs["location"] = location
else:
raise ValueError(
"GOOGLE_CLOUD_PROJECT must be set when using Vertex AI mode"
)
else:
# Standard Gemini API mode: use api_key
client_kwargs["api_key"] = api_key
if base_url and base_url != "DEFAULT_GEMINI_ENDPOINT" or timeout is not None:
try:
http_options_kwargs = {}
if base_url and base_url != "DEFAULT_GEMINI_ENDPOINT":
http_options_kwargs["base_url"] = base_url
if timeout is not None:
http_options_kwargs["timeout"] = timeout
client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
except Exception as e:
logger.error("Failed to apply custom Gemini http_options: %s", e)
raise e
return genai.Client(**client_kwargs)
def _ensure_api_key(api_key: str | None) -> str:
# In Vertex AI mode, API key is not required
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
if use_vertexai:
# Return empty string for Vertex AI mode (not used)
return ""
key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
if not key:
raise ValueError(
"Gemini API key not provided. "
"Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
)
return key
def _build_generation_config(
base_config: dict[str, Any] | None,
system_prompt: str | None,
keyword_extraction: bool,
) -> types.GenerateContentConfig | None:
config_data = dict(base_config or {})
if system_prompt:
if config_data.get("system_instruction"):
config_data["system_instruction"] = (
f"{config_data['system_instruction']}\n{system_prompt}"
)
else:
config_data["system_instruction"] = system_prompt
if keyword_extraction and not config_data.get("response_mime_type"):
config_data["response_mime_type"] = "application/json"
# Remove entries that are explicitly set to None to avoid type errors
sanitized = {
key: value
for key, value in config_data.items()
if value is not None and value != ""
}
if not sanitized:
return None
return types.GenerateContentConfig(**sanitized)
def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
if not history_messages:
return ""
history_lines: list[str] = []
for message in history_messages:
role = message.get("role", "user")
content = message.get("content", "")
history_lines.append(f"[{role}] {content}")
return "\n".join(history_lines)
def _extract_response_text(
response: Any, extract_thoughts: bool = False
) -> tuple[str, str]:
"""
Extract text content from Gemini response, separating regular content from thoughts.
Args:
response: Gemini API response object
extract_thoughts: Whether to extract thought content separately
Returns:
Tuple of (regular_text, thought_text)
"""
candidates = getattr(response, "candidates", None)
if not candidates:
return ("", "")
regular_parts: list[str] = []
thought_parts: list[str] = []
for candidate in candidates:
if not getattr(candidate, "content", None):
continue
# Use 'or []' to handle None values from parts attribute
for part in getattr(candidate.content, "parts", None) or []:
text = getattr(part, "text", None)
if not text:
continue
# Check if this part is thought content using the 'thought' attribute
is_thought = getattr(part, "thought", False)
if is_thought and extract_thoughts:
thought_parts.append(text)
elif not is_thought:
regular_parts.append(text)
return ("\n".join(regular_parts), "\n".join(thought_parts))
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(google_api_exceptions.InternalServerError)
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
| retry_if_exception_type(google_api_exceptions.BadGateway)
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
| retry_if_exception_type(google_api_exceptions.Aborted)
| retry_if_exception_type(google_api_exceptions.Unknown)
| retry_if_exception_type(InvalidResponseError)
),
)
async def gemini_complete_if_cache(
model: str,
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
stream: bool | None = None,
keyword_extraction: bool = False,
generation_config: dict[str, Any] | None = None,
timeout: int | None = None,
**_: Any,
) -> str | AsyncIterator[str]:
"""
Complete a prompt using Gemini's API with Chain of Thought (COT) support.
This function supports automatic integration of reasoning content from Gemini models
that provide Chain of Thought capabilities via the thinking_config API feature.
COT Integration:
- When enable_cot=True: Thought content is wrapped in <think>...</think> tags
- When enable_cot=False: Thought content is filtered out, only regular content returned
- Thought content is identified by the 'thought' attribute on response parts
- Requires thinking_config to be enabled in generation_config for API to return thoughts
Args:
model: The Gemini model to use.
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
api_key: Optional Gemini API key. If None, uses environment variable.
base_url: Optional custom API endpoint.
generation_config: Optional generation configuration dict.
keyword_extraction: Whether to use JSON response format.
token_tracker: Optional token usage tracker for monitoring API usage.
stream: Whether to stream the response.
hashing_kv: Storage interface (for interface parity with other bindings).
enable_cot: Whether to include Chain of Thought content in the response.
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
**_: Additional keyword arguments (ignored).
Returns:
The completed text (with COT content if enable_cot=True) or an async iterator
of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
Raises:
RuntimeError: If the response from Gemini is empty.
ValueError: If API key is not provided or configured.
"""
key = _ensure_api_key(api_key)
# Convert timeout from seconds to milliseconds for Gemini API
timeout_ms = timeout * 1000 if timeout else None
client = _get_gemini_client(key, base_url, timeout_ms)
history_block = _format_history_messages(history_messages)
prompt_sections = []
if history_block:
prompt_sections.append(history_block)
prompt_sections.append(f"[user] {prompt}")
combined_prompt = "\n".join(prompt_sections)
config_obj = _build_generation_config(
generation_config,
system_prompt=system_prompt,
keyword_extraction=keyword_extraction,
)
request_kwargs: dict[str, Any] = {
"model": model,
"contents": [combined_prompt],
}
if config_obj is not None:
request_kwargs["config"] = config_obj
if stream:
async def _async_stream() -> AsyncIterator[str]:
# COT state tracking for streaming
cot_active = False
cot_started = False
initial_content_seen = False
usage_metadata = None
try:
# Use native async streaming from genai SDK
# Note: generate_content_stream returns Awaitable[AsyncIterator], need to await first
stream = await client.aio.models.generate_content_stream(
**request_kwargs
)
async for chunk in stream:
usage = getattr(chunk, "usage_metadata", None)
if usage is not None:
usage_metadata = usage
# Extract both regular and thought content
regular_text, thought_text = _extract_response_text(
chunk, extract_thoughts=True
)
if enable_cot:
# Process regular content
if regular_text:
if not initial_content_seen:
initial_content_seen = True
# Close COT section if it was active
if cot_active:
yield "</think>"
cot_active = False
# Process and yield regular content
if "\\u" in regular_text:
regular_text = safe_unicode_decode(
regular_text.encode("utf-8")
)
yield regular_text
# Process thought content
if thought_text:
if not initial_content_seen and not cot_started:
# Start COT section
yield "<think>"
cot_active = True
cot_started = True
# Yield thought content if COT is active
if cot_active:
if "\\u" in thought_text:
thought_text = safe_unicode_decode(
thought_text.encode("utf-8")
)
yield thought_text
else:
# COT disabled - only yield regular content
if regular_text:
if "\\u" in regular_text:
regular_text = safe_unicode_decode(
regular_text.encode("utf-8")
)
yield regular_text
# Ensure COT is properly closed if still active
if cot_active:
yield "</think>"
cot_active = False
except Exception as exc:
# Try to close COT tag before re-raising
if cot_active:
try:
yield "</think>"
except Exception:
pass
raise exc
finally:
# Track token usage after streaming completes
if token_tracker and usage_metadata:
token_tracker.add_usage(
{
"prompt_tokens": getattr(
usage_metadata, "prompt_token_count", 0
),
"completion_tokens": getattr(
usage_metadata, "candidates_token_count", 0
),
"total_tokens": getattr(
usage_metadata, "total_token_count", 0
),
}
)
return _async_stream()
# Non-streaming: use native async client
response = await client.aio.models.generate_content(**request_kwargs)
# Extract both regular text and thought text
regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
# Apply COT filtering logic based on enable_cot parameter
if enable_cot:
# Include thought content wrapped in <think> tags
if thought_text and thought_text.strip():
if not regular_text or regular_text.strip() == "":
# Only thought content available
final_text = f"<think>{thought_text}</think>"
else:
# Both content types present: prepend thought to regular content
final_text = f"<think>{thought_text}</think>{regular_text}"
else:
# No thought content, use regular content only
final_text = regular_text or ""
else:
# Filter out thought content, return only regular content
final_text = regular_text or ""
if not final_text:
raise InvalidResponseError("Gemini response did not contain any text content.")
if "\\u" in final_text:
final_text = safe_unicode_decode(final_text.encode("utf-8"))
final_text = remove_think_tags(final_text)
usage = getattr(response, "usage_metadata", None)
if token_tracker and usage:
token_tracker.add_usage(
{
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
"completion_tokens": getattr(usage, "candidates_token_count", 0),
"total_tokens": getattr(usage, "total_token_count", 0),
}
)
logger.debug("Gemini response length: %s", len(final_text))
return final_text
async def gemini_model_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
keyword_extraction: bool = False,
**kwargs: Any,
) -> str | AsyncIterator[str]:
hashing_kv = kwargs.get("hashing_kv")
model_name = None
if hashing_kv is not None:
model_name = hashing_kv.global_config.get("llm_model_name")
if model_name is None:
model_name = kwargs.pop("model_name", None)
if model_name is None:
raise ValueError("Gemini model name not provided in configuration.")
return await gemini_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
keyword_extraction=keyword_extraction,
**kwargs,
)
@wrap_embedding_func_with_attrs(
embedding_dim=1536,
max_token_size=2048,
model_name="gemini-embedding-001",
supports_asymmetric=True,
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(google_api_exceptions.InternalServerError)
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
| retry_if_exception_type(google_api_exceptions.BadGateway)
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
| retry_if_exception_type(google_api_exceptions.Aborted)
| retry_if_exception_type(google_api_exceptions.Unknown)
),
)
async def gemini_embed(
texts: list[str],
model: str = "gemini-embedding-001",
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
max_token_size: int | None = None,
task_type: str | None = None,
timeout: int | None = None,
token_tracker: Any | None = None,
context: str = "document",
) -> np.ndarray:
"""Generate embeddings for a list of texts using Gemini's API.
This function uses Google's Gemini embedding model to generate text embeddings.
It supports dynamic dimension control and automatic normalization for dimensions
less than 3072.
Args:
texts: List of texts to embed.
model: The Gemini embedding model to use. Default is "gemini-embedding-001".
base_url: Optional custom API endpoint.
api_key: Optional Gemini API key. If None, uses environment variables.
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
or the EMBEDDING_DIM environment variable.
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
max_token_size: Maximum tokens per text. This parameter is automatically
injected by the EmbeddingFunc wrapper when the underlying function
signature supports it (via inspect.signature check). Gemini API will
automatically truncate texts exceeding this limit (autoTruncate=True
by default), so no client-side truncation is needed.
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
QUESTION_ANSWERING, FACT_VERIFICATION.
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
token_tracker: Optional token usage tracker for monitoring API usage.
context: The embedding context - "query" for search queries, "document" for indexed content.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
when supports_asymmetric=True. Default is "document".
Returns:
A numpy array of embeddings, one per input text. For dimensions < 3072,
the embeddings are L2-normalized to ensure optimal semantic similarity performance.
Raises:
ValueError: If API key is not provided or configured.
RuntimeError: If the response from Gemini is invalid or empty.
Note:
- For dimension 3072: Embeddings are already normalized by the API
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
- Normalization ensures accurate semantic similarity via cosine distance
- Gemini API automatically truncates texts exceeding max_token_size (autoTruncate=True)
"""
# Note: max_token_size is received but not used for client-side truncation.
# Gemini API handles truncation automatically with autoTruncate=True (default).
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
key = _ensure_api_key(api_key)
# Convert timeout from seconds to milliseconds for Gemini API
timeout_ms = timeout * 1000 if timeout else None
client = _get_gemini_client(key, base_url, timeout_ms)
# Prepare embedding configuration
config_kwargs: dict[str, Any] = {}
# Add task_type to config
if task_type is None:
if context == "query":
task_type = "RETRIEVAL_QUERY"
elif context == "document":
task_type = "RETRIEVAL_DOCUMENT"
else:
task_type = "RETRIEVAL_DOCUMENT" # Default for backward compatibility
config_kwargs["task_type"] = task_type
# Add output_dimensionality if embedding_dim is provided
if embedding_dim is not None:
config_kwargs["output_dimensionality"] = embedding_dim
# Create config object if we have parameters
config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
request_kwargs: dict[str, Any] = {
"model": model,
"contents": texts,
}
if config_obj is not None:
request_kwargs["config"] = config_obj
# Use native async client for embedding
response = await client.aio.models.embed_content(**request_kwargs)
# Extract embeddings from response
if not hasattr(response, "embeddings") or not response.embeddings:
raise RuntimeError("Gemini response did not contain embeddings.")
# Convert embeddings to numpy array
embeddings = np.array(
[np.array(e.values, dtype=np.float32) for e in response.embeddings]
)
# Apply L2 normalization for dimensions < 3072
# The 3072 dimension embedding is already normalized by Gemini API
if embedding_dim and embedding_dim < 3072:
# Normalize each embedding vector to unit length
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
# Avoid division by zero
norms = np.where(norms == 0, 1, norms)
embeddings = embeddings / norms
logger.debug(
f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
)
# Track token usage if tracker is provided
# Note: Gemini embedding API may not provide usage metadata
if token_tracker and hasattr(response, "usage_metadata"):
usage = response.usage_metadata
token_counts = {
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
"total_tokens": getattr(usage, "total_token_count", 0),
}
token_tracker.add_usage(token_counts)
logger.debug(
f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
)
return embeddings
__all__ = [
"gemini_complete_if_cache",
"gemini_model_complete",
"gemini_embed",
]

View File

@@ -0,0 +1,206 @@
import copy
import os
from functools import lru_cache
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("transformers"):
pm.install("transformers")
if not pm.is_installed("torch"):
pm.install("torch")
if not pm.is_installed("numpy"):
pm.install("numpy")
from transformers import AutoTokenizer, AutoModelForCausalLM
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
import torch
import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
return hf_model, hf_tokenizer
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def hf_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
**kwargs,
) -> str:
if enable_cot:
from lightrag.utils import logger
logger.debug(
"enable_cot=True is not supported for Hugging Face local models and will be ignored."
)
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
kwargs.pop("hashing_kv", None)
input_prompt = ""
try:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
try:
ori_message = copy.deepcopy(messages)
if messages[0]["role"] == "system":
messages[1]["content"] = (
"<system>"
+ messages[0]["content"]
+ "</system>\n"
+ messages[1]["content"]
)
messages = messages[1:]
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
len_message = len(ori_message)
for msgid in range(len_message):
input_prompt = (
input_prompt
+ "<"
+ ori_message[msgid]["role"]
+ ">"
+ ori_message[msgid]["content"]
+ "</"
+ ori_message[msgid]["role"]
+ ">\n"
)
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
return response_text
async def hf_model_complete(
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
enable_cot: bool = False,
**kwargs,
) -> str:
kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
result = await hf_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
return result
@wrap_embedding_func_with_attrs(
embedding_dim=1024,
max_token_size=8192,
model_name="hf_embedding_model",
supports_asymmetric=True,
)
async def hf_embed(
texts: list[str],
tokenizer,
embed_model,
context: str = "document",
query_prefix: str | None = None,
document_prefix: str | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using a Hugging Face model.
Args:
texts (list[str]): List of input texts to embed.
tokenizer: Hugging Face tokenizer.
embed_model: Hugging Face model for generating embeddings.
context (str): Context indicating whether the texts are "query" or "document".
query_prefix (str | None): Optional prefix to add to query texts.
document_prefix (str | None): Optional prefix to add to document texts.
Returns:
np.ndarray: Array of embeddings.
"""
# Detect the appropriate device
if torch.cuda.is_available():
device = next(embed_model.parameters()).device # Use CUDA if available
elif torch.backends.mps.is_available():
device = torch.device("mps") # Use MPS for Apple Silicon
else:
device = torch.device("cpu") # Fallback to CPU
# Move the model to the detected device
embed_model = embed_model.to(device)
# Apply context-based prefixes if provided
if context == "query" and query_prefix:
texts = [query_prefix + text for text in texts]
elif context == "document" and document_prefix:
texts = [document_prefix + text for text in texts]
# Tokenize the input texts and move them to the same device
encoded_texts = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).to(device)
# Perform inference
with torch.no_grad():
outputs = embed_model(
input_ids=encoded_texts["input_ids"],
attention_mask=encoded_texts["attention_mask"],
)
embeddings = outputs.last_hidden_state.mean(dim=1)
# Convert embeddings to NumPy
if embeddings.dtype == torch.bfloat16:
return embeddings.detach().to(torch.float32).cpu().numpy()
else:
return embeddings.detach().cpu().numpy()

View File

@@ -0,0 +1,183 @@
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("aiohttp"):
pm.install("aiohttp")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
import numpy as np
import base64
import aiohttp
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import wrap_embedding_func_with_attrs, logger
async def fetch_data(url, headers, data):
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
error_text = await response.text()
# Check if the error response is HTML (common for 502, 503, etc.)
content_type = response.headers.get("content-type", "").lower()
is_html_error = (
error_text.strip().startswith("<!DOCTYPE html>")
or "text/html" in content_type
)
if is_html_error:
# Provide clean, user-friendly error messages for HTML error pages
if response.status == 502:
clean_error = "Bad Gateway (502) - Jina AI service temporarily unavailable. Please try again in a few minutes."
elif response.status == 503:
clean_error = "Service Unavailable (503) - Jina AI service is temporarily overloaded. Please try again later."
elif response.status == 504:
clean_error = "Gateway Timeout (504) - Jina AI service request timed out. Please try again."
else:
clean_error = f"HTTP {response.status} - Jina AI service error. Please try again later."
else:
# Use original error text if it's not HTML
clean_error = error_text
logger.error(f"Jina API error {response.status}: {clean_error}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Jina API error: {clean_error}",
)
response_json = await response.json()
data_list = response_json.get("data", [])
return data_list
@wrap_embedding_func_with_attrs(
embedding_dim=2048,
max_token_size=8192,
model_name="jina-embeddings-v4",
supports_asymmetric=True,
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(aiohttp.ClientError)
| retry_if_exception_type(aiohttp.ClientResponseError)
),
)
async def jina_embed(
texts: list[str],
model: str = "jina-embeddings-v4",
embedding_dim: int = 2048,
late_chunking: bool = False,
base_url: str = None,
api_key: str = None,
context: str | None = None,
task: str | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using Jina AI's API.
Args:
texts: List of texts to embed.
model: The Jina embedding model to use (default: jina-embeddings-v4).
Supported models: jina-embeddings-v3, jina-embeddings-v4, etc.
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the Jina API for dimension reduction.
late_chunking: Whether to use late chunking.
base_url: Optional base URL for the Jina API.
api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
context: The embedding context - "query" for search queries, "document" for indexed content.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
when supports_asymmetric=True. When ``task`` is left at its default of None,
``context`` drives the task selection.
task: Embedding task mode. Default is None so that ``context`` (when present)
picks the right Jina task:
- "retrieval.query" for context="query"
- "retrieval.passage" for context="document"
- "text-matching" otherwise (true backward-compatible default)
Any explicit non-None task value overrides context-based selection.
Returns:
A numpy array of embeddings, one per input text.
Raises:
aiohttp.ClientError: If there is a connection error with the Jina API.
aiohttp.ClientResponseError: If the Jina API returns an error response.
"""
if api_key:
os.environ["JINA_API_KEY"] = api_key
if "JINA_API_KEY" not in os.environ:
raise ValueError("JINA_API_KEY environment variable is required")
url = base_url or "https://api.jina.ai/v1/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
}
# Determine task based on context if not explicitly provided
if task is None:
if context == "query":
task = "retrieval.query"
elif context == "document":
task = "retrieval.passage"
else:
task = "text-matching" # Default for backward compatibility
data = {
"model": model,
"task": task,
"dimensions": embedding_dim,
"embedding_type": "base64",
"input": texts,
}
# Only add optional parameters if they have non-default values
if late_chunking:
data["late_chunking"] = late_chunking
logger.debug(
f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
)
try:
data_list = await fetch_data(url, headers, data)
if not data_list:
logger.error("Jina API returned empty data list")
raise ValueError("Jina API returned empty data list")
if len(data_list) != len(texts):
logger.error(
f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
)
raise ValueError(
f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
)
embeddings = np.array(
[
np.frombuffer(base64.b64decode(dp["embedding"]), dtype=np.float32)
for dp in data_list
]
)
logger.debug(f"Jina embeddings generated: shape {embeddings.shape}")
return embeddings
except Exception as e:
logger.error(f"Jina embedding error: {e}")
raise

View File

@@ -0,0 +1,208 @@
import pipmaster as pm
from llama_index.core.llms import (
ChatMessage,
MessageRole,
ChatResponse,
)
from typing import List, Optional
from lightrag.utils import logger
# Install required dependencies
if not pm.is_installed("llama-index"):
pm.install("llama-index")
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.settings import Settings as LlamaIndexSettings
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
import numpy as np
def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
"""
Configure LlamaIndex settings.
Args:
settings: LlamaIndex Settings instance. If None, uses default settings.
**kwargs: Additional settings to override/configure
"""
if settings is None:
settings = LlamaIndexSettings()
# Update settings with any provided kwargs
for key, value in kwargs.items():
if hasattr(settings, key):
setattr(settings, key, value)
else:
logger.warning(f"Unknown LlamaIndex setting: {key}")
# Set as global settings
LlamaIndexSettings.set_global(settings)
return settings
def format_chat_messages(messages):
"""Format chat messages into LlamaIndex format."""
formatted_messages = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
formatted_messages.append(
ChatMessage(role=MessageRole.SYSTEM, content=content)
)
elif role == "assistant":
formatted_messages.append(
ChatMessage(role=MessageRole.ASSISTANT, content=content)
)
elif role == "user":
formatted_messages.append(
ChatMessage(role=MessageRole.USER, content=content)
)
else:
logger.warning(f"Unknown role {role}, treating as user message")
formatted_messages.append(
ChatMessage(role=MessageRole.USER, content=content)
)
return formatted_messages
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def llama_index_complete_if_cache(
model: str,
prompt: str,
system_prompt: Optional[str] = None,
history_messages: List[dict] = [],
enable_cot: bool = False,
chat_kwargs={},
) -> str:
"""Complete the prompt using LlamaIndex."""
if enable_cot:
logger.debug(
"enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
)
try:
# Format messages for chat
formatted_messages = []
# Add system message if provided
if system_prompt:
formatted_messages.append(
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
)
# Add history messages
for msg in history_messages:
formatted_messages.append(
ChatMessage(
role=MessageRole.USER
if msg["role"] == "user"
else MessageRole.ASSISTANT,
content=msg["content"],
)
)
# Add current prompt
formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))
response: ChatResponse = await model.achat(
messages=formatted_messages, **chat_kwargs
)
# In newer versions, the response is in message.content
content = response.message.content
return content
except Exception as e:
logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
raise
async def llama_index_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
settings: LlamaIndexSettings = None,
**kwargs,
) -> str:
"""
Main completion function for LlamaIndex
Args:
prompt: Input prompt
system_prompt: Optional system prompt
history_messages: Optional chat history
keyword_extraction: Whether to extract keywords from response
settings: Optional LlamaIndex settings
**kwargs: Additional arguments
"""
if history_messages is None:
history_messages = []
kwargs.pop("keyword_extraction", None)
result = await llama_index_complete_if_cache(
kwargs.get("llm_instance"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def llama_index_embed(
texts: list[str],
embed_model: BaseEmbedding = None,
settings: LlamaIndexSettings = None,
**kwargs,
) -> np.ndarray:
"""
Generate embeddings using LlamaIndex
Args:
texts: List of texts to embed
embed_model: LlamaIndex embedding model
settings: Optional LlamaIndex settings
**kwargs: Additional arguments
"""
if settings:
configure_llama_index(settings)
if embed_model is None:
raise ValueError("embed_model must be provided")
# Use _get_text_embeddings for batch processing
embeddings = embed_model._get_text_embeddings(texts)
return np.array(embeddings)

View File

@@ -0,0 +1,154 @@
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("lmdeploy"):
pm.install("lmdeploy[all]")
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from functools import lru_cache
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
model,
tp=1,
chat_template=None,
log_level="WARNING",
model_format="hf",
quant_policy=0,
):
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
lmdeploy_pipe = pipeline(
model_path=model,
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=(
ChatTemplateConfig(model_name=chat_template) if chat_template else None
),
log_level="WARNING",
)
return lmdeploy_pipe
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def lmdeploy_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
chat_template=None,
model_format="hf",
quant_policy=0,
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
if enable_cot:
from lightrag.utils import logger
logger.debug(
"enable_cot=True is not supported for lmdeploy and will be ignored."
)
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
kwargs.pop("hashing_kv", None)
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
do_preprocess = kwargs.pop("do_preprocess", True)
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
"`do_sample` parameter is not supported by lmdeploy until "
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
)
else:
do_sample = True
gen_params.update(do_sample=do_sample)
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
log_level="WARNING",
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
response = ""
async for res in lmdeploy_pipe.generate(
messages,
gen_config=gen_config,
do_preprocess=do_preprocess,
stream_response=False,
session_id=1,
):
response += res.response
return response

View File

@@ -0,0 +1,177 @@
import sys
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
if not pm.is_installed("aiohttp"):
pm.install("aiohttp")
import aiohttp
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from typing import Union, List
import numpy as np
from lightrag.utils import (
wrap_embedding_func_with_attrs,
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def lollms_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
base_url="http://localhost:9600",
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Client implementation for lollms generation."""
if enable_cot:
from lightrag.utils import logger
logger.debug("enable_cot=True is not supported for lollms and will be ignored.")
stream = True if kwargs.get("stream") else False
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
if api_key
else {"Content-Type": "application/json"}
)
# Extract lollms specific parameters
request_data = {
"prompt": prompt,
"model_name": model,
"personality": kwargs.get("personality", -1),
"n_predict": kwargs.get("n_predict", None),
"stream": stream,
"temperature": kwargs.get("temperature", 1.0),
"top_k": kwargs.get("top_k", 50),
"top_p": kwargs.get("top_p", 0.95),
"repeat_penalty": kwargs.get("repeat_penalty", 0.8),
"repeat_last_n": kwargs.get("repeat_last_n", 40),
"seed": kwargs.get("seed", None),
"n_threads": kwargs.get("n_threads", 8),
}
# Prepare the full prompt including history
full_prompt = ""
if system_prompt:
full_prompt += f"{system_prompt}\n"
for msg in history_messages:
full_prompt += f"{msg['role']}: {msg['content']}\n"
full_prompt += prompt
request_data["prompt"] = full_prompt
timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
if stream:
async def inner():
async with session.post(
f"{base_url}/lollms_generate", json=request_data
) as response:
async for line in response.content:
yield line.decode().strip()
return inner()
else:
async with session.post(
f"{base_url}/lollms_generate", json=request_data
) as response:
return await response.text()
async def lollms_model_complete(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Complete function for lollms model generation."""
# Extract and remove keyword_extraction from kwargs if present
keyword_extraction = kwargs.pop("keyword_extraction", None)
# Get model name from config
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
# If keyword extraction is needed, we might need to modify the prompt
# or add specific parameters for JSON output (if lollms supports it)
if keyword_extraction:
# Note: You might need to adjust this based on how lollms handles structured output
pass
return await lollms_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=8192, model_name="lollms_embedding_model"
)
async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:
"""
Generate embeddings for a list of texts using lollms server.
Args:
texts: List of strings to embed
embed_model: Model name (not used directly as lollms uses configured vectorizer)
base_url: URL of the lollms server
**kwargs: Additional arguments passed to the request
Returns:
np.ndarray: Array of embeddings
"""
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": api_key}
if api_key
else {"Content-Type": "application/json"}
)
async with aiohttp.ClientSession(headers=headers) as session:
embeddings = []
for text in texts:
request_data = {"text": text}
async with session.post(
f"{base_url}/lollms_embed",
json=request_data,
) as response:
result = await response.json()
embeddings.append(result["vector"])
return np.array(embeddings)

View File

@@ -0,0 +1,68 @@
import sys
import os
if sys.version_info < (3, 9):
pass
else:
pass
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
)
import numpy as np
@wrap_embedding_func_with_attrs(
embedding_dim=2048, max_token_size=8192, model_name="nvidia_embedding_model"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def nvidia_openai_embed(
texts: list[str],
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
base_url: str = "https://integrate.api.nvidia.com/v1",
api_key: str = None,
input_type: str = "passage", # query for retrieval, passage for embedding
trunc: str = "NONE", # NONE or START or END
encode: str = "float", # float or base64
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model,
input=texts,
encoding_format=encode,
extra_body={"input_type": input_type, "truncate": trunc},
)
return np.array([dp.embedding for dp in response.data])

View File

@@ -0,0 +1,260 @@
from collections.abc import AsyncIterator
import os
import re
import pipmaster as pm
# install specific modules
if not pm.is_installed("ollama"):
pm.install("ollama")
import ollama
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from lightrag.api import __api_version__
import numpy as np
from typing import Optional, Union
from lightrag.utils import (
wrap_embedding_func_with_attrs,
logger,
)
_OLLAMA_CLOUD_HOST = "https://ollama.com"
_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
if host:
return host
try:
model_name_str = str(model) if model is not None else ""
except (TypeError, ValueError, AttributeError) as e:
logger.warning(f"Failed to convert model to string: {e}, using empty string")
model_name_str = ""
if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
logger.debug(
f"Detected cloud model '{model_name_str}', using Ollama Cloud host"
)
return _OLLAMA_CLOUD_HOST
return host
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def _ollama_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if enable_cot:
logger.debug("enable_cot=True is not supported for ollama and will be ignored.")
stream = True if kwargs.get("stream") else False
kwargs.pop("max_tokens", None)
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
if timeout == 0:
timeout = None
kwargs.pop("hashing_kv", None)
api_key = kwargs.pop("api_key", None)
# fallback to environment variable when not provided explicitly
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")
headers = {
"Content-Type": "application/json",
"User-Agent": f"LightRAG/{__api_version__}",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
host = _coerce_host_for_cloud_model(host, model)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
try:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
if stream:
"""cannot cache stream response and process reasoning"""
async def inner():
try:
async for chunk in response:
yield chunk["message"]["content"]
except Exception as e:
logger.error(f"Error in stream response: {str(e)}")
raise
finally:
try:
await ollama_client._client.aclose()
logger.debug("Successfully closed Ollama client for streaming")
except Exception as close_error:
logger.warning(f"Failed to close Ollama client: {close_error}")
return inner()
else:
model_response = response["message"]["content"]
"""
If the model also wraps its thoughts in a specific tag,
this information is not needed for the final
response and can simply be trimmed.
"""
return model_response
except Exception as e:
try:
await ollama_client._client.aclose()
logger.debug("Successfully closed Ollama client after exception")
except Exception as close_error:
logger.warning(
f"Failed to close Ollama client after exception: {close_error}"
)
raise e
finally:
if not stream:
try:
await ollama_client._client.aclose()
logger.debug(
"Successfully closed Ollama client for non-streaming response"
)
except Exception as close_error:
logger.warning(
f"Failed to close Ollama client in finally block: {close_error}"
)
async def ollama_model_complete(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await _ollama_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@wrap_embedding_func_with_attrs(
embedding_dim=1024,
max_token_size=8192,
model_name="bge-m3:latest",
supports_asymmetric=True,
)
async def ollama_embed(
texts: list[str],
embed_model: str = "bge-m3:latest",
max_token_size: int | None = None,
context: str = "document",
query_prefix: str | None = None,
document_prefix: str | None = None,
**kwargs,
) -> np.ndarray:
"""Generate embeddings using Ollama's API.
Args:
texts: List of texts to embed.
embed_model: The Ollama embedding model to use. Default is "bge-m3:latest".
max_token_size: Maximum tokens per text. This parameter is automatically
injected by the EmbeddingFunc wrapper when the underlying function
signature supports it (via inspect.signature check). Ollama will
automatically truncate texts exceeding the model's context length
(num_ctx), so no client-side truncation is needed.
context: The embedding context - "query" for search queries, "document" for indexed content.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
when supports_asymmetric=True. Default is "document".
query_prefix: Optional prefix to prepend to texts when context="query" (e.g., "search_query: ").
document_prefix: Optional prefix to prepend to texts when context="document" (e.g., "search_document: ").
**kwargs: Additional arguments passed to the Ollama client.
Returns:
A numpy array of embeddings, one per input text.
Note:
- Ollama API automatically truncates texts exceeding the model's context length
- The max_token_size parameter is received but not used for client-side truncation
"""
# Apply context-based prefixes if provided
if context == "query" and query_prefix:
texts = [query_prefix + text for text in texts]
elif context == "document" and document_prefix:
texts = [document_prefix + text for text in texts]
# Note: max_token_size is received but not used for client-side truncation.
# Ollama API handles truncation automatically based on the model's num_ctx setting.
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
api_key = kwargs.pop("api_key", None)
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")
headers = {
"Content-Type": "application/json",
"User-Agent": f"LightRAG/{__api_version__}",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
host = _coerce_host_for_cloud_model(host, embed_model)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
try:
options = kwargs.pop("options", {})
data = await ollama_client.embed(
model=embed_model, input=texts, options=options
)
return np.array(data["embeddings"])
except Exception as e:
logger.error(f"Error in ollama_embed: {str(e)}")
try:
await ollama_client._client.aclose()
logger.debug("Successfully closed Ollama client after exception in embed")
except Exception as close_error:
logger.warning(
f"Failed to close Ollama client after exception in embed: {close_error}"
)
raise e
finally:
try:
await ollama_client._client.aclose()
logger.debug("Successfully closed Ollama client after embed")
except Exception as close_error:
logger.warning(f"Failed to close Ollama client after embed: {close_error}")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,197 @@
import os
import numpy as np
import pipmaster as pm # Pipmaster for dynamic library install
# Add Voyage AI import
if not pm.is_installed("voyageai"):
pm.install("voyageai")
import voyageai
from voyageai.error import (
APIConnectionError,
RateLimitError,
ServerError,
ServiceUnavailableError,
Timeout,
TryAgain,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import wrap_embedding_func_with_attrs, logger
# Custom exceptions for VoyageAI errors
class VoyageAIError(Exception):
"""Generic VoyageAI API error"""
pass
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=32000, supports_asymmetric=True
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(
APIConnectionError,
RateLimitError,
ServerError,
ServiceUnavailableError,
Timeout,
TryAgain,
)
),
)
async def voyageai_embed(
texts: list[str],
model: str = "voyage-3",
api_key: str | None = None,
embedding_dim: int | None = None,
input_type: str | None = None,
truncation: bool | None = True,
context: str | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using VoyageAI's API.
Args:
texts: List of texts to embed.
model: The VoyageAI embedding model to use. Options include:
- "voyage-3": General purpose (1024 dims, 32K context)
- "voyage-3-lite": Lightweight (512 dims, 32K context)
- "voyage-3-large": Highest accuracy (1024 dims, 32K context)
- "voyage-code-3": Code optimized (1024 dims, 32K context)
- "voyage-law-2": Legal documents (1024 dims, 16K context)
- "voyage-finance-2": Finance (1024 dims, 32K context)
api_key: Optional VoyageAI API key. If None, falls back to the
``VOYAGE_API_KEY`` environment variable (the name VoyageAI's own
SDK uses), then to ``VOYAGEAI_API_KEY`` for backward compatibility.
embedding_dim: Optional Matryoshka output dimension. Only honored by
models that support dimension reduction (e.g. voyage-3-large);
ignored otherwise. The decorator default is 1024 to match
``voyage-3``; if you select ``voyage-3-lite`` (512 dims) override
``EMBEDDING_DIM`` accordingly so the vector store size matches.
input_type: Optional input type hint for the model. Options:
- "query": For search queries
- "document": For documents to be indexed
- None: Let the model decide (default)
truncation: Whether the API should truncate texts that exceed the model's
token limit. Defaults to True (matches the VoyageAI SDK default).
context: Optional LightRAG embedding context. When ``input_type`` is not
set, "query" maps to ``input_type="query"`` and "document" maps to
``input_type="document"``.
Returns:
A numpy array of embeddings, one per input text.
Raises:
VoyageAIError: If the API call fails or returns invalid data.
"""
if not api_key:
api_key = os.environ.get("VOYAGE_API_KEY") or os.environ.get("VOYAGEAI_API_KEY")
if not api_key:
logger.error(
"VoyageAI API key not provided and neither VOYAGE_API_KEY nor "
"VOYAGEAI_API_KEY environment variable is set"
)
raise ValueError(
"VoyageAI API key is required: pass api_key, or set the "
"VOYAGE_API_KEY (preferred) or VOYAGEAI_API_KEY environment variable"
)
if input_type is None and context in {"query", "document"}:
input_type = context
try:
client = voyageai.AsyncClient(api_key=api_key)
total_chars = sum(len(t) for t in texts)
avg_chars = total_chars / len(texts) if texts else 0
logger.debug(
f"VoyageAI embedding request: {len(texts)} texts, "
f"total_chars={total_chars}, avg_chars={avg_chars:.0f}, model={model}, "
f"input_type={input_type}"
)
# Prepare API call parameters
embed_params = dict(
texts=texts,
model=model,
# Optional parameters -- if None, voyageai client uses defaults
output_dimension=embedding_dim,
truncation=truncation,
input_type=input_type,
)
# Make API call with timing
result = await client.embed(**embed_params)
if not result.embeddings:
err_msg = "VoyageAI API returned empty embeddings"
logger.error(err_msg)
raise VoyageAIError(err_msg)
if len(result.embeddings) != len(texts):
err_msg = f"VoyageAI API returned {len(result.embeddings)} embeddings for {len(texts)} texts"
logger.error(err_msg)
raise VoyageAIError(err_msg)
# Convert to numpy array with timing
embeddings = np.array(result.embeddings, dtype=np.float32)
logger.debug(f"VoyageAI embeddings generated: shape {embeddings.shape}")
return embeddings
except Exception as e:
logger.error(f"VoyageAI embedding error: {e}")
raise
# Optional: a helper function to get available embedding models
def get_available_embedding_models() -> dict[str, dict]:
"""
Returns a dictionary of available Voyage AI embedding models and their properties.
"""
return {
"voyage-3-large": {
"context_length": 32000,
"dimension": 1024,
"description": "Best general-purpose and multilingual",
},
"voyage-3": {
"context_length": 32000,
"dimension": 1024,
"description": "General-purpose and multilingual",
},
"voyage-3-lite": {
"context_length": 32000,
"dimension": 512,
"description": "Optimized for latency and cost",
},
"voyage-code-3": {
"context_length": 32000,
"dimension": 1024,
"description": "Optimized for code",
},
"voyage-finance-2": {
"context_length": 32000,
"dimension": 1024,
"description": "Optimized for finance",
},
"voyage-law-2": {
"context_length": 16000,
"dimension": 1024,
"description": "Optimized for legal",
},
"voyage-multimodal-3": {
"context_length": 32000,
"dimension": 1024,
"description": "Multimodal text and images",
},
}

View File

@@ -0,0 +1,248 @@
import sys
import re
import json
from ..utils import verbose_debug
if sys.version_info < (3, 9):
pass
else:
pass
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("zhipuai"):
pm.install("zhipuai")
from openai import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
from typing import Union, List, Optional, Dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def zhipu_complete_if_cache(
prompt: Union[str, List[Dict[str, str]]],
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
enable_cot: bool = False, # LightRAG output switch: include reasoning_content as <think>...</think>
thinking: Optional[
Dict[str, object]
] = None, # Zhipu request param: use {"type": "enabled"} to enable thinking
**kwargs,
) -> str:
"""Call Zhipu chat completions with optional official thinking support.
Parameter roles:
- `thinking`: forwarded to the Zhipu API as-is. To enable thinking output,
pass a config such as `{"type": "enabled"}`.
- `enable_cot`: LightRAG-only formatting switch. When True and the API
returns `reasoning_content`, it is preserved in the final string as
`<think>...</think>`.
"""
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
messages = []
if not system_prompt:
system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
# Add system prompt if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Add debug logging
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
verbose_debug(f"System prompt: {system_prompt}")
# Remove unsupported kwargs
kwargs = {
k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
}
# `thinking` is an official Zhipu request field. Example:
# {"type": "enabled"} enables reasoning output on supported models.
if thinking is not None:
kwargs["thinking"] = thinking
response = client.chat.completions.create(model=model, messages=messages, **kwargs)
message = response.choices[0].message
content = message.content or ""
reasoning_content = getattr(message, "reasoning_content", "") or ""
if enable_cot and reasoning_content.strip():
if content:
return f"<think>{reasoning_content}</think>{content}"
return f"<think>{reasoning_content}</think>"
return content
async def zhipu_complete(
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
enable_cot: bool = False,
**kwargs,
):
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
if keyword_extraction:
# Add a system prompt to guide the model to return JSON format
extraction_prompt = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements
Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}
Only return the JSON, no other text."""
# Combine with existing system prompt if any
if system_prompt:
system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
else:
system_prompt = extraction_prompt
try:
response = await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
# Try to parse as JSON
try:
data = json.loads(response)
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
# If direct JSON parsing fails, try to extract JSON from text
match = re.search(r"\{[\s\S]*\}", response)
if match:
try:
data = json.loads(match.group())
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
pass
# If all parsing fails, log warning and return empty format
logger.warning(
f"Failed to parse keyword extraction response: {response}"
)
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
except Exception as e:
logger.error(f"Error during keyword extraction: {str(e)}")
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
else:
# For non-keyword-extraction, just return the raw response string
return await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=8192, model_name="embedding-3"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def zhipu_embedding(
texts: list[str],
model: str = "embedding-3",
api_key: str = None,
embedding_dim: int | None = None,
**kwargs,
) -> np.ndarray:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
# Convert single text to list if needed
if isinstance(texts, str):
texts = [texts]
embeddings = []
for text in texts:
try:
request_kwargs = dict(kwargs)
if embedding_dim is not None:
request_kwargs["dimensions"] = embedding_dim
response = client.embeddings.create(
model=model, input=[text], **request_kwargs
)
embeddings.append(response.data[0].embedding)
except Exception as e:
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
return np.array(embeddings)