feat: 重构知识库系统,移除Hermes集成,增强RAG和同步功能
主要变更: - 移除Hermes智能体及相关回调服务 - 新增知识库RAG、同步、调度、规范化和索引任务服务 - 重构orchestrator服务,增强运行时聊天功能 - 更新前端聊天、政策制度、设置等页面样式和逻辑 - 更新expense_claims和document_intelligence服务 - 删除llm_wiki相关服务和测试文件 - 更新docker-compose配置和启动脚本
This commit is contained in:
265
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/anthropic.py
Normal file
265
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/anthropic.py
Normal file
@@ -0,0 +1,265 @@
|
||||
from ..utils import verbose_debug, VERBOSE_DEBUG
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Union, AsyncIterator
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
from typing import AsyncIterator
|
||||
else:
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
# Install Anthropic SDK if not present
|
||||
if not pm.is_installed("anthropic"):
|
||||
pm.install("anthropic")
|
||||
|
||||
from anthropic import (
|
||||
AsyncAnthropic,
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from lightrag.utils import (
|
||||
safe_unicode_decode,
|
||||
logger,
|
||||
)
|
||||
from lightrag.api import __api_version__
|
||||
|
||||
|
||||
# Custom exception for retry mechanism
|
||||
class InvalidResponseError(Exception):
|
||||
"""Custom exception class for triggering retry mechanism"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Core Anthropic completion function with retry
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError, InvalidResponseError)
|
||||
),
|
||||
)
|
||||
async def anthropic_complete_if_cache(
|
||||
model: str,
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
enable_cot: bool = False,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
if enable_cot:
|
||||
logger.debug(
|
||||
"enable_cot=True is not supported for the Anthropic API and will be ignored."
|
||||
)
|
||||
if not api_key:
|
||||
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
|
||||
default_headers = {
|
||||
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Set logger level to INFO when VERBOSE_DEBUG is off
|
||||
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
|
||||
logging.getLogger("anthropic").setLevel(logging.INFO)
|
||||
|
||||
kwargs.pop("hashing_kv", None)
|
||||
kwargs.pop("keyword_extraction", None)
|
||||
timeout = kwargs.pop("timeout", None)
|
||||
|
||||
anthropic_async_client = (
|
||||
AsyncAnthropic(
|
||||
default_headers=default_headers, api_key=api_key, timeout=timeout
|
||||
)
|
||||
if base_url is None
|
||||
else AsyncAnthropic(
|
||||
base_url=base_url,
|
||||
default_headers=default_headers,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
)
|
||||
)
|
||||
|
||||
messages: list[dict[str, Any]] = []
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
logger.debug("===== Sending Query to Anthropic LLM =====")
|
||||
logger.debug(f"Model: {model} Base URL: {base_url}")
|
||||
logger.debug(f"Additional kwargs: {kwargs}")
|
||||
verbose_debug(f"Query: {prompt}")
|
||||
verbose_debug(f"System prompt: {system_prompt}")
|
||||
|
||||
try:
|
||||
create_params = {"model": model, "messages": messages, "stream": True, **kwargs}
|
||||
if system_prompt:
|
||||
create_params["system"] = system_prompt
|
||||
response = await anthropic_async_client.messages.create(**create_params)
|
||||
|
||||
except APIConnectionError as e:
|
||||
logger.error(f"Anthropic API Connection Error: {e}")
|
||||
raise
|
||||
except RateLimitError as e:
|
||||
logger.error(f"Anthropic API Rate Limit Error: {e}")
|
||||
raise
|
||||
except APITimeoutError as e:
|
||||
logger.error(f"Anthropic API Timeout Error: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
body = getattr(e, "body", None)
|
||||
request_id = getattr(e, "request_id", None)
|
||||
req = getattr(e, "request", None)
|
||||
extra_parts = []
|
||||
if body:
|
||||
extra_parts.append(f"Response body: {body}")
|
||||
if request_id:
|
||||
extra_parts.append(f"Request ID: {request_id}")
|
||||
if req is not None:
|
||||
extra_parts.append(f"Request URL: {req.url}")
|
||||
extra = ("\n" + "\n".join(extra_parts)) if extra_parts else ""
|
||||
logger.error(
|
||||
f"Anthropic API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}{extra}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def stream_response():
|
||||
try:
|
||||
async for event in response:
|
||||
content = (
|
||||
event.delta.text
|
||||
if hasattr(event, "delta")
|
||||
and hasattr(event.delta, "text")
|
||||
and event.delta.text
|
||||
else None
|
||||
)
|
||||
if content is None:
|
||||
continue
|
||||
if r"\u" in content:
|
||||
content = safe_unicode_decode(content.encode("utf-8"))
|
||||
yield content
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream response: {str(e)}")
|
||||
raise
|
||||
|
||||
return stream_response()
|
||||
|
||||
|
||||
# Generic Anthropic completion function
|
||||
async def anthropic_complete(
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
enable_cot: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await anthropic_complete_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Claude 3 Opus specific completion
|
||||
async def claude_3_opus_complete(
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
enable_cot: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
return await anthropic_complete_if_cache(
|
||||
"claude-3-opus-20240229",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Claude 3 Sonnet specific completion
|
||||
async def claude_3_sonnet_complete(
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
enable_cot: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
return await anthropic_complete_if_cache(
|
||||
"claude-3-sonnet-20240229",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Claude 3 Haiku specific completion
|
||||
async def claude_3_haiku_complete(
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
enable_cot: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
return await anthropic_complete_if_cache(
|
||||
"claude-3-haiku-20240307",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Backward-compatibility shim: the previous embedding implementation lived in
|
||||
# this module under the (misleading) name ``anthropic_embed`` even though it
|
||||
# called Voyage AI under the hood. The real implementation now lives in
|
||||
# ``lightrag.llm.voyageai.voyageai_embed``. Keep the old name importable for one
|
||||
# release cycle so downstream users get a clear deprecation warning instead of
|
||||
# an ImportError. Remove in a future major version.
|
||||
def anthropic_embed(*args, **kwargs):
|
||||
"""Deprecated alias for :func:`lightrag.llm.voyageai.voyageai_embed`.
|
||||
|
||||
This shim accepts the same arguments as the original ``anthropic_embed``
|
||||
function (which was always backed by VoyageAI) and forwards them to
|
||||
:func:`voyageai_embed`. It will be removed in a future release.
|
||||
"""
|
||||
|
||||
warnings.warn(
|
||||
"lightrag.llm.anthropic.anthropic_embed is deprecated and will be "
|
||||
"removed in a future release. Import "
|
||||
"lightrag.llm.voyageai.voyageai_embed instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from lightrag.llm.voyageai import voyageai_embed
|
||||
|
||||
return voyageai_embed.func(*args, **kwargs)
|
||||
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Azure OpenAI compatibility layer.
|
||||
|
||||
This module provides backward compatibility by re-exporting Azure OpenAI functions
|
||||
from the main openai module where the actual implementation resides.
|
||||
|
||||
All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai,
|
||||
with this module serving as a thin compatibility wrapper for existing code that
|
||||
imports from lightrag.llm.azure_openai.
|
||||
"""
|
||||
|
||||
from lightrag.llm.openai import (
|
||||
azure_openai_complete_if_cache,
|
||||
azure_openai_complete,
|
||||
azure_openai_embed,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"azure_openai_complete_if_cache",
|
||||
"azure_openai_complete",
|
||||
"azure_openai_embed",
|
||||
]
|
||||
485
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/bedrock.py
Normal file
485
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/bedrock.py
Normal file
@@ -0,0 +1,485 @@
|
||||
import copy
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
if not pm.is_installed("aioboto3"):
|
||||
pm.install("aioboto3")
|
||||
import aioboto3
|
||||
import numpy as np
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
import sys
|
||||
from lightrag.utils import wrap_embedding_func_with_attrs
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
from typing import AsyncIterator
|
||||
else:
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Union
|
||||
|
||||
# Import botocore exceptions for proper exception handling
|
||||
try:
|
||||
from botocore.exceptions import (
|
||||
ClientError,
|
||||
ConnectionError as BotocoreConnectionError,
|
||||
ReadTimeoutError,
|
||||
)
|
||||
except ImportError:
|
||||
# If botocore is not installed, define placeholders
|
||||
ClientError = Exception
|
||||
BotocoreConnectionError = Exception
|
||||
ReadTimeoutError = Exception
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
"""Generic error for issues related to Amazon Bedrock"""
|
||||
|
||||
|
||||
class BedrockRateLimitError(BedrockError):
|
||||
"""Error for rate limiting and throttling issues"""
|
||||
|
||||
|
||||
class BedrockConnectionError(BedrockError):
|
||||
"""Error for network and connection issues"""
|
||||
|
||||
|
||||
class BedrockTimeoutError(BedrockError):
|
||||
"""Error for timeout issues"""
|
||||
|
||||
|
||||
def _set_env_if_present(key: str, value):
|
||||
"""Set environment variable only if a non-empty value is provided."""
|
||||
if value is not None and value != "":
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
|
||||
"""Convert AWS Bedrock exceptions to appropriate custom exceptions.
|
||||
|
||||
Args:
|
||||
e: The exception to handle
|
||||
operation: Description of the operation for error messages
|
||||
|
||||
Raises:
|
||||
BedrockRateLimitError: For rate limiting and throttling issues (retryable)
|
||||
BedrockConnectionError: For network and server issues (retryable)
|
||||
BedrockTimeoutError: For timeout issues (retryable)
|
||||
BedrockError: For validation and other non-retryable errors
|
||||
"""
|
||||
error_message = str(e)
|
||||
|
||||
# Handle botocore ClientError with specific error codes
|
||||
if isinstance(e, ClientError):
|
||||
error_code = e.response.get("Error", {}).get("Code", "")
|
||||
error_msg = e.response.get("Error", {}).get("Message", error_message)
|
||||
|
||||
# Rate limiting and throttling errors (retryable)
|
||||
if error_code in [
|
||||
"ThrottlingException",
|
||||
"ProvisionedThroughputExceededException",
|
||||
]:
|
||||
logging.error(f"{operation} rate limit error: {error_msg}")
|
||||
raise BedrockRateLimitError(f"Rate limit error: {error_msg}")
|
||||
|
||||
# Server errors (retryable)
|
||||
elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
|
||||
logging.error(f"{operation} connection error: {error_msg}")
|
||||
raise BedrockConnectionError(f"Service error: {error_msg}")
|
||||
|
||||
# Check for 5xx HTTP status codes (retryable)
|
||||
elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
|
||||
logging.error(f"{operation} server error: {error_msg}")
|
||||
raise BedrockConnectionError(f"Server error: {error_msg}")
|
||||
|
||||
# Validation and other client errors (non-retryable)
|
||||
else:
|
||||
logging.error(f"{operation} client error: {error_msg}")
|
||||
raise BedrockError(f"Client error: {error_msg}")
|
||||
|
||||
# Connection errors (retryable)
|
||||
elif isinstance(e, BotocoreConnectionError):
|
||||
logging.error(f"{operation} connection error: {error_message}")
|
||||
raise BedrockConnectionError(f"Connection error: {error_message}")
|
||||
|
||||
# Timeout errors (retryable)
|
||||
elif isinstance(e, (ReadTimeoutError, TimeoutError)):
|
||||
logging.error(f"{operation} timeout error: {error_message}")
|
||||
raise BedrockTimeoutError(f"Timeout error: {error_message}")
|
||||
|
||||
# Custom Bedrock errors (already properly typed)
|
||||
elif isinstance(
|
||||
e,
|
||||
(
|
||||
BedrockRateLimitError,
|
||||
BedrockConnectionError,
|
||||
BedrockTimeoutError,
|
||||
BedrockError,
|
||||
),
|
||||
):
|
||||
raise
|
||||
|
||||
# Unknown errors (non-retryable)
|
||||
else:
|
||||
logging.error(f"{operation} unexpected error: {error_message}")
|
||||
raise BedrockError(f"Unexpected error: {error_message}")
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(BedrockRateLimitError)
|
||||
| retry_if_exception_type(BedrockConnectionError)
|
||||
| retry_if_exception_type(BedrockTimeoutError)
|
||||
),
|
||||
)
|
||||
async def bedrock_complete_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if enable_cot:
|
||||
import logging
|
||||
|
||||
logging.debug(
|
||||
"enable_cot=True is not supported for Bedrock and will be ignored."
|
||||
)
|
||||
# Respect existing env; only set if a non-empty value is available
|
||||
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
|
||||
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
|
||||
session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
|
||||
_set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
|
||||
_set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
|
||||
_set_env_if_present("AWS_SESSION_TOKEN", session_token)
|
||||
# Region handling: prefer env, else kwarg (optional)
|
||||
region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None)
|
||||
kwargs.pop("hashing_kv", None)
|
||||
# Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
|
||||
# We'll use this to determine whether to call converse_stream or converse
|
||||
stream = bool(kwargs.pop("stream", False))
|
||||
# Remove unsupported args for Bedrock Converse API
|
||||
for k in [
|
||||
"response_format",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"n",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"max_completion_tokens",
|
||||
"response_format",
|
||||
]:
|
||||
kwargs.pop(k, None)
|
||||
# Fix message history format
|
||||
messages = []
|
||||
for history_message in history_messages:
|
||||
message = copy.copy(history_message)
|
||||
message["content"] = [{"text": message["content"]}]
|
||||
messages.append(message)
|
||||
|
||||
# Add user prompt
|
||||
messages.append({"role": "user", "content": [{"text": prompt}]})
|
||||
|
||||
# Initialize Converse API arguments
|
||||
args = {"modelId": model, "messages": messages}
|
||||
|
||||
# Define system prompt
|
||||
if system_prompt:
|
||||
args["system"] = [{"text": system_prompt}]
|
||||
|
||||
# Map and set up inference parameters
|
||||
inference_params_map = {
|
||||
"max_tokens": "maxTokens",
|
||||
"top_p": "topP",
|
||||
"stop_sequences": "stopSequences",
|
||||
}
|
||||
if inference_params := list(
|
||||
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
|
||||
):
|
||||
args["inferenceConfig"] = {}
|
||||
for param in inference_params:
|
||||
args["inferenceConfig"][inference_params_map.get(param, param)] = (
|
||||
kwargs.pop(param)
|
||||
)
|
||||
|
||||
# Import logging for error handling
|
||||
import logging
|
||||
|
||||
# For streaming responses, we need a different approach to keep the connection open
|
||||
if stream:
|
||||
# Create a session that will be used throughout the streaming process
|
||||
session = aioboto3.Session()
|
||||
client = None
|
||||
|
||||
# Define the generator function that will manage the client lifecycle
|
||||
async def stream_generator():
|
||||
nonlocal client
|
||||
|
||||
# Create the client outside the generator to ensure it stays open
|
||||
client = await session.client(
|
||||
"bedrock-runtime", region_name=region
|
||||
).__aenter__()
|
||||
event_stream = None
|
||||
iteration_started = False
|
||||
|
||||
try:
|
||||
# Make the API call
|
||||
response = await client.converse_stream(**args, **kwargs)
|
||||
event_stream = response.get("stream")
|
||||
iteration_started = True
|
||||
|
||||
# Process the stream
|
||||
async for event in event_stream:
|
||||
# Validate event structure
|
||||
if not event or not isinstance(event, dict):
|
||||
continue
|
||||
|
||||
if "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"].get("delta", {})
|
||||
text = delta.get("text")
|
||||
if text:
|
||||
yield text
|
||||
# Handle other event types that might indicate stream end
|
||||
elif "messageStop" in event:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
# Try to clean up resources if possible
|
||||
if (
|
||||
iteration_started
|
||||
and event_stream
|
||||
and hasattr(event_stream, "aclose")
|
||||
and callable(getattr(event_stream, "aclose", None))
|
||||
):
|
||||
try:
|
||||
await event_stream.aclose()
|
||||
except Exception as close_error:
|
||||
logging.warning(
|
||||
f"Failed to close Bedrock event stream: {close_error}"
|
||||
)
|
||||
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(e, "Bedrock streaming")
|
||||
|
||||
finally:
|
||||
# Clean up the event stream
|
||||
if (
|
||||
iteration_started
|
||||
and event_stream
|
||||
and hasattr(event_stream, "aclose")
|
||||
and callable(getattr(event_stream, "aclose", None))
|
||||
):
|
||||
try:
|
||||
await event_stream.aclose()
|
||||
except Exception as close_error:
|
||||
logging.warning(
|
||||
f"Failed to close Bedrock event stream in finally block: {close_error}"
|
||||
)
|
||||
|
||||
# Clean up the client
|
||||
if client:
|
||||
try:
|
||||
await client.__aexit__(None, None, None)
|
||||
except Exception as client_close_error:
|
||||
logging.warning(
|
||||
f"Failed to close Bedrock client: {client_close_error}"
|
||||
)
|
||||
|
||||
# Return the generator that manages its own lifecycle
|
||||
return stream_generator()
|
||||
|
||||
# For non-streaming responses, use the standard async context manager pattern
|
||||
session = aioboto3.Session()
|
||||
async with session.client(
|
||||
"bedrock-runtime", region_name=region
|
||||
) as bedrock_async_client:
|
||||
try:
|
||||
# Use converse for non-streaming responses
|
||||
response = await bedrock_async_client.converse(**args, **kwargs)
|
||||
|
||||
# Validate response structure
|
||||
if (
|
||||
not response
|
||||
or "output" not in response
|
||||
or "message" not in response["output"]
|
||||
or "content" not in response["output"]["message"]
|
||||
or not response["output"]["message"]["content"]
|
||||
):
|
||||
raise BedrockError("Invalid response structure from Bedrock API")
|
||||
|
||||
content = response["output"]["message"]["content"][0]["text"]
|
||||
|
||||
if not content or content.strip() == "":
|
||||
raise BedrockError("Received empty content from Bedrock API")
|
||||
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(e, "Bedrock converse")
|
||||
|
||||
|
||||
# Generic Bedrock completion function
|
||||
async def bedrock_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
kwargs.pop("keyword_extraction", None)
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
result = await bedrock_complete_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=1024, max_token_size=8192, model_name="amazon.titan-embed-text-v2:0"
|
||||
)
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(BedrockRateLimitError)
|
||||
| retry_if_exception_type(BedrockConnectionError)
|
||||
| retry_if_exception_type(BedrockTimeoutError)
|
||||
),
|
||||
)
|
||||
async def bedrock_embed(
|
||||
texts: list[str],
|
||||
model: str = "amazon.titan-embed-text-v2:0",
|
||||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
) -> np.ndarray:
|
||||
# Respect existing env; only set if a non-empty value is available
|
||||
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
|
||||
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
|
||||
session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
|
||||
_set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
|
||||
_set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
|
||||
_set_env_if_present("AWS_SESSION_TOKEN", session_token)
|
||||
|
||||
# Region handling: prefer env
|
||||
region = os.environ.get("AWS_REGION")
|
||||
|
||||
session = aioboto3.Session()
|
||||
async with session.client(
|
||||
"bedrock-runtime", region_name=region
|
||||
) as bedrock_async_client:
|
||||
try:
|
||||
if (model_provider := model.split(".")[0]) == "amazon":
|
||||
embed_texts = []
|
||||
for text in texts:
|
||||
try:
|
||||
if "v2" in model:
|
||||
body = json.dumps(
|
||||
{
|
||||
"inputText": text,
|
||||
# 'dimensions': embedding_dim,
|
||||
"embeddingTypes": ["float"],
|
||||
}
|
||||
)
|
||||
elif "v1" in model:
|
||||
body = json.dumps({"inputText": text})
|
||||
else:
|
||||
raise BedrockError(f"Model {model} is not supported!")
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
modelId=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = await response.get("body").json()
|
||||
|
||||
# Validate response structure
|
||||
if not response_body or "embedding" not in response_body:
|
||||
raise BedrockError(
|
||||
f"Invalid embedding response structure for text: {text[:50]}..."
|
||||
)
|
||||
|
||||
embedding = response_body["embedding"]
|
||||
if not embedding:
|
||||
raise BedrockError(
|
||||
f"Received empty embedding for text: {text[:50]}..."
|
||||
)
|
||||
|
||||
embed_texts.append(embedding)
|
||||
|
||||
except Exception as e:
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(
|
||||
e, "Bedrock embedding (amazon, text chunk)"
|
||||
)
|
||||
|
||||
elif model_provider == "cohere":
|
||||
try:
|
||||
body = json.dumps(
|
||||
{
|
||||
"texts": texts,
|
||||
"input_type": "search_document",
|
||||
"truncate": "NONE",
|
||||
}
|
||||
)
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
model=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
||||
# Validate response structure
|
||||
if not response_body or "embeddings" not in response_body:
|
||||
raise BedrockError(
|
||||
"Invalid embedding response structure from Cohere"
|
||||
)
|
||||
|
||||
embeddings = response_body["embeddings"]
|
||||
if not embeddings or len(embeddings) != len(texts):
|
||||
raise BedrockError(
|
||||
f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
|
||||
)
|
||||
|
||||
embed_texts = embeddings
|
||||
|
||||
except Exception as e:
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(e, "Bedrock embedding (cohere)")
|
||||
|
||||
else:
|
||||
raise BedrockError(
|
||||
f"Model provider '{model_provider}' is not supported!"
|
||||
)
|
||||
|
||||
# Final validation
|
||||
if not embed_texts:
|
||||
raise BedrockError("No embeddings generated")
|
||||
|
||||
return np.array(embed_texts)
|
||||
|
||||
except Exception as e:
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(e, "Bedrock embedding")
|
||||
@@ -0,0 +1,740 @@
|
||||
"""
|
||||
Module that implements containers for specific LLM bindings.
|
||||
|
||||
This module provides container implementations for various Large Language Model
|
||||
bindings and integrations.
|
||||
"""
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, ClassVar, List, get_args, get_origin
|
||||
|
||||
from lightrag.utils import get_env_value
|
||||
from lightrag.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
|
||||
def _resolve_optional_type(field_type: Any) -> Any:
|
||||
"""Return the concrete type for Optional/Union annotations."""
|
||||
origin = get_origin(field_type)
|
||||
if origin in (list, dict, tuple):
|
||||
return field_type
|
||||
|
||||
args = get_args(field_type)
|
||||
if args:
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
return non_none_args[0]
|
||||
return field_type
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BindingOptions Base Class
|
||||
# =============================================================================
|
||||
#
|
||||
# The BindingOptions class serves as the foundation for all LLM provider bindings
|
||||
# in LightRAG. It provides a standardized framework for:
|
||||
#
|
||||
# 1. Configuration Management:
|
||||
# - Defines how each LLM provider's configuration parameters are structured
|
||||
# - Handles default values and type information for each parameter
|
||||
# - Maps configuration options to command-line arguments and environment variables
|
||||
#
|
||||
# 2. Environment Integration:
|
||||
# - Automatically generates environment variable names from binding parameters
|
||||
# - Provides methods to create sample .env files for easy configuration
|
||||
# - Supports configuration via environment variables with fallback to defaults
|
||||
#
|
||||
# 3. Command-Line Interface:
|
||||
# - Dynamically generates command-line arguments for all registered bindings
|
||||
# - Maintains consistent naming conventions across different LLM providers
|
||||
# - Provides help text and type validation for each configuration option
|
||||
#
|
||||
# 4. Extensibility:
|
||||
# - Uses class introspection to automatically discover all binding subclasses
|
||||
# - Requires minimal boilerplate code when adding new LLM provider bindings
|
||||
# - Maintains separation of concerns between different provider configurations
|
||||
#
|
||||
# This design pattern ensures that adding support for a new LLM provider requires
|
||||
# only defining the provider-specific parameters and help text, while the base
|
||||
# class handles all the common functionality for argument parsing, environment
|
||||
# variable handling, and configuration management.
|
||||
#
|
||||
# Instances of a derived class of BindingOptions can be used to store multiple
|
||||
# runtime configurations of options for a single LLM provider. using the
|
||||
# asdict() method to convert the options to a dictionary.
|
||||
#
|
||||
# =============================================================================
|
||||
@dataclass
|
||||
class BindingOptions:
|
||||
"""Base class for binding options."""
|
||||
|
||||
# mandatory name of binding
|
||||
_binding_name: ClassVar[str]
|
||||
|
||||
# optional help message for each option
|
||||
_help: ClassVar[dict[str, str]]
|
||||
|
||||
@staticmethod
|
||||
def _all_class_vars(klass: type, include_inherited=True) -> dict[str, Any]:
|
||||
"""Print class variables, optionally including inherited ones"""
|
||||
if include_inherited:
|
||||
# Get all class variables from MRO
|
||||
vars_dict = {}
|
||||
for base in reversed(klass.__mro__[:-1]): # Exclude 'object'
|
||||
vars_dict.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in base.__dict__.items()
|
||||
if (
|
||||
not k.startswith("_")
|
||||
and not callable(v)
|
||||
and not isinstance(v, classmethod)
|
||||
)
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Only direct class variables
|
||||
vars_dict = {
|
||||
k: v
|
||||
for k, v in klass.__dict__.items()
|
||||
if (
|
||||
not k.startswith("_")
|
||||
and not callable(v)
|
||||
and not isinstance(v, classmethod)
|
||||
)
|
||||
}
|
||||
|
||||
return vars_dict
|
||||
|
||||
@classmethod
|
||||
def add_args(cls, parser: ArgumentParser):
|
||||
group = parser.add_argument_group(f"{cls._binding_name} binding options")
|
||||
for arg_item in cls.args_env_name_type_value():
|
||||
# Handle JSON parsing for list types
|
||||
if arg_item["type"] is List[str]:
|
||||
|
||||
def json_list_parser(value):
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
if not isinstance(parsed, list):
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Expected JSON array, got {type(parsed).__name__}"
|
||||
)
|
||||
return parsed
|
||||
except json.JSONDecodeError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
|
||||
|
||||
# Get environment variable with JSON parsing
|
||||
env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
|
||||
if env_value is not argparse.SUPPRESS:
|
||||
try:
|
||||
env_value = json_list_parser(env_value)
|
||||
except argparse.ArgumentTypeError:
|
||||
env_value = argparse.SUPPRESS
|
||||
|
||||
group.add_argument(
|
||||
f"--{arg_item['argname']}",
|
||||
type=json_list_parser,
|
||||
default=env_value,
|
||||
help=arg_item["help"],
|
||||
)
|
||||
# Handle JSON parsing for dict types
|
||||
elif arg_item["type"] is dict:
|
||||
|
||||
def json_dict_parser(value):
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
if not isinstance(parsed, dict):
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Expected JSON object, got {type(parsed).__name__}"
|
||||
)
|
||||
return parsed
|
||||
except json.JSONDecodeError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
|
||||
|
||||
# Get environment variable with JSON parsing
|
||||
env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
|
||||
if env_value is not argparse.SUPPRESS:
|
||||
try:
|
||||
env_value = json_dict_parser(env_value)
|
||||
except argparse.ArgumentTypeError:
|
||||
env_value = argparse.SUPPRESS
|
||||
|
||||
group.add_argument(
|
||||
f"--{arg_item['argname']}",
|
||||
type=json_dict_parser,
|
||||
default=env_value,
|
||||
help=arg_item["help"],
|
||||
)
|
||||
# Handle boolean types specially to avoid argparse bool() constructor issues
|
||||
elif arg_item["type"] is bool:
|
||||
|
||||
def bool_parser(value):
|
||||
"""Custom boolean parser that handles string representations correctly"""
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ("true", "1", "yes", "t", "on")
|
||||
return bool(value)
|
||||
|
||||
# Get environment variable with proper type conversion
|
||||
env_value = get_env_value(
|
||||
f"{arg_item['env_name']}", argparse.SUPPRESS, bool
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
f"--{arg_item['argname']}",
|
||||
type=bool_parser,
|
||||
default=env_value,
|
||||
help=arg_item["help"],
|
||||
)
|
||||
else:
|
||||
resolved_type = arg_item["type"]
|
||||
if resolved_type is not None:
|
||||
resolved_type = _resolve_optional_type(resolved_type)
|
||||
|
||||
group.add_argument(
|
||||
f"--{arg_item['argname']}",
|
||||
type=resolved_type,
|
||||
default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS),
|
||||
help=arg_item["help"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def args_env_name_type_value(cls):
|
||||
import dataclasses
|
||||
|
||||
args_prefix = f"{cls._binding_name}".replace("_", "-")
|
||||
env_var_prefix = f"{cls._binding_name}_".upper()
|
||||
help = cls._help
|
||||
|
||||
# Check if this is a dataclass and use dataclass fields
|
||||
if dataclasses.is_dataclass(cls):
|
||||
for field in dataclasses.fields(cls):
|
||||
# Skip private fields
|
||||
if field.name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Get default value
|
||||
if field.default is not dataclasses.MISSING:
|
||||
default_value = field.default
|
||||
elif field.default_factory is not dataclasses.MISSING:
|
||||
default_value = field.default_factory()
|
||||
else:
|
||||
default_value = None
|
||||
|
||||
argdef = {
|
||||
"argname": f"{args_prefix}-{field.name}",
|
||||
"env_name": f"{env_var_prefix}{field.name.upper()}",
|
||||
"type": _resolve_optional_type(field.type),
|
||||
"default": default_value,
|
||||
"help": f"{cls._binding_name} -- " + help.get(field.name, ""),
|
||||
}
|
||||
|
||||
yield argdef
|
||||
else:
|
||||
# Fallback to old method for non-dataclass classes
|
||||
class_vars = {
|
||||
key: value
|
||||
for key, value in cls._all_class_vars(cls).items()
|
||||
if not callable(value) and not key.startswith("_")
|
||||
}
|
||||
|
||||
# Get type hints to properly detect List[str] types
|
||||
type_hints = {}
|
||||
for base in cls.__mro__:
|
||||
if hasattr(base, "__annotations__"):
|
||||
type_hints.update(base.__annotations__)
|
||||
|
||||
for class_var in class_vars:
|
||||
# Use type hint if available, otherwise fall back to type of value
|
||||
var_type = type_hints.get(class_var, type(class_vars[class_var]))
|
||||
|
||||
argdef = {
|
||||
"argname": f"{args_prefix}-{class_var}",
|
||||
"env_name": f"{env_var_prefix}{class_var.upper()}",
|
||||
"type": var_type,
|
||||
"default": class_vars[class_var],
|
||||
"help": f"{cls._binding_name} -- " + help.get(class_var, ""),
|
||||
}
|
||||
|
||||
yield argdef
|
||||
|
||||
@classmethod
|
||||
def generate_dot_env_sample(cls):
|
||||
"""
|
||||
Generate a sample .env file for all LightRAG binding options.
|
||||
|
||||
This method creates a .env file that includes all the binding options
|
||||
defined by the subclasses of BindingOptions. It uses the args_env_name_type_value()
|
||||
method to get the list of all options and their default values.
|
||||
|
||||
Returns:
|
||||
str: A string containing the contents of the sample .env file.
|
||||
"""
|
||||
from io import StringIO
|
||||
|
||||
sample_top = (
|
||||
"#" * 80
|
||||
+ "\n"
|
||||
+ (
|
||||
"# Autogenerated .env entries list for LightRAG binding options\n"
|
||||
"#\n"
|
||||
"# To generate run:\n"
|
||||
"# $ python -m lightrag.llm.binding_options\n"
|
||||
)
|
||||
+ "#" * 80
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
sample_bottom = (
|
||||
("#\n# End of .env entries for LightRAG binding options\n")
|
||||
+ "#" * 80
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
sample_stream = StringIO()
|
||||
sample_stream.write(sample_top)
|
||||
for klass in cls.__subclasses__():
|
||||
for arg_item in klass.args_env_name_type_value():
|
||||
if arg_item["help"]:
|
||||
sample_stream.write(f"# {arg_item['help']}\n")
|
||||
|
||||
# Handle JSON formatting for list and dict types
|
||||
if arg_item["type"] is List[str] or arg_item["type"] is dict:
|
||||
default_value = json.dumps(arg_item["default"])
|
||||
else:
|
||||
default_value = arg_item["default"]
|
||||
|
||||
sample_stream.write(f"# {arg_item['env_name']}={default_value}\n\n")
|
||||
|
||||
sample_stream.write(sample_bottom)
|
||||
return sample_stream.getvalue()
|
||||
|
||||
@classmethod
|
||||
def options_dict(cls, args: Namespace) -> dict[str, Any]:
|
||||
"""
|
||||
Extract options dictionary for a specific binding from parsed arguments.
|
||||
|
||||
This method filters the parsed command-line arguments to return only those
|
||||
that belong to the specific binding class. It removes the binding prefix
|
||||
from argument names to create a clean options dictionary.
|
||||
|
||||
Args:
|
||||
args (Namespace): Parsed command-line arguments containing all binding options
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Dictionary mapping option names (without prefix) to their values
|
||||
|
||||
Example:
|
||||
If args contains {'ollama_num_ctx': 512, 'other_option': 'value'}
|
||||
and this is called on OllamaOptions, it returns {'num_ctx': 512}
|
||||
"""
|
||||
prefix = cls._binding_name + "_"
|
||||
skipchars = len(prefix)
|
||||
options = {
|
||||
key[skipchars:]: value
|
||||
for key, value in vars(args).items()
|
||||
if key.startswith(prefix)
|
||||
}
|
||||
|
||||
return options
|
||||
|
||||
def asdict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert an instance of binding options to a dictionary.
|
||||
|
||||
This method uses dataclasses.asdict() to convert the dataclass instance
|
||||
into a dictionary representation, including all its fields and values.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Dictionary representation of the binding options instance
|
||||
"""
|
||||
return asdict(self)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Binding Options for Ollama
|
||||
# =============================================================================
|
||||
#
|
||||
# Ollama binding options provide configuration for the Ollama local LLM server.
|
||||
# These options control model behavior, sampling parameters, hardware utilization,
|
||||
# and performance settings. The parameters are based on Ollama's API specification
|
||||
# and provide fine-grained control over model inference and generation.
|
||||
#
|
||||
# The _OllamaOptionsMixin defines the complete set of available options, while
|
||||
# OllamaEmbeddingOptions and OllamaLLMOptions provide specialized configurations
|
||||
# for embedding and language model tasks respectively.
|
||||
# =============================================================================
|
||||
@dataclass
|
||||
class _OllamaOptionsMixin:
|
||||
"""Options for Ollama bindings."""
|
||||
|
||||
# Core context and generation parameters
|
||||
num_ctx: int = 32768 # Context window size (number of tokens)
|
||||
num_predict: int = 128 # Maximum number of tokens to predict
|
||||
num_keep: int = 0 # Number of tokens to keep from the initial prompt
|
||||
seed: int = -1 # Random seed for generation (-1 for random)
|
||||
|
||||
# Sampling parameters
|
||||
temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0-2.0)
|
||||
top_k: int = 40 # Top-k sampling parameter
|
||||
top_p: float = 0.9 # Top-p (nucleus) sampling parameter
|
||||
tfs_z: float = 1.0 # Tail free sampling parameter
|
||||
typical_p: float = 1.0 # Typical probability mass
|
||||
min_p: float = 0.0 # Minimum probability threshold
|
||||
|
||||
# Repetition control
|
||||
repeat_last_n: int = 64 # Number of tokens to consider for repetition penalty
|
||||
repeat_penalty: float = 1.1 # Penalty for repetition
|
||||
presence_penalty: float = 0.0 # Penalty for token presence
|
||||
frequency_penalty: float = 0.0 # Penalty for token frequency
|
||||
|
||||
# Mirostat sampling
|
||||
mirostat: int = (
|
||||
# Mirostat sampling algorithm (0=disabled, 1=Mirostat 1.0, 2=Mirostat 2.0)
|
||||
0
|
||||
)
|
||||
mirostat_tau: float = 5.0 # Mirostat target entropy
|
||||
mirostat_eta: float = 0.1 # Mirostat learning rate
|
||||
|
||||
# Hardware and performance parameters
|
||||
numa: bool = False # Enable NUMA optimization
|
||||
num_batch: int = 512 # Batch size for processing
|
||||
num_gpu: int = -1 # Number of GPUs to use (-1 for auto)
|
||||
main_gpu: int = 0 # Main GPU index
|
||||
low_vram: bool = False # Optimize for low VRAM
|
||||
num_thread: int = 0 # Number of CPU threads (0 for auto)
|
||||
|
||||
# Memory and model parameters
|
||||
f16_kv: bool = True # Use half-precision for key/value cache
|
||||
logits_all: bool = False # Return logits for all tokens
|
||||
vocab_only: bool = False # Only load vocabulary
|
||||
use_mmap: bool = True # Use memory mapping for model files
|
||||
use_mlock: bool = False # Lock model in memory
|
||||
embedding_only: bool = False # Only use for embeddings
|
||||
|
||||
# Output control
|
||||
penalize_newline: bool = True # Penalize newline tokens
|
||||
stop: List[str] = field(default_factory=list) # Stop sequences
|
||||
|
||||
# optional help strings
|
||||
_help: ClassVar[dict[str, str]] = {
|
||||
"num_ctx": "Context window size (number of tokens)",
|
||||
"num_predict": "Maximum number of tokens to predict",
|
||||
"num_keep": "Number of tokens to keep from the initial prompt",
|
||||
"seed": "Random seed for generation (-1 for random)",
|
||||
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
|
||||
"top_k": "Top-k sampling parameter (0 = disabled)",
|
||||
"top_p": "Top-p (nucleus) sampling parameter (0.0-1.0)",
|
||||
"tfs_z": "Tail free sampling parameter (1.0 = disabled)",
|
||||
"typical_p": "Typical probability mass (1.0 = disabled)",
|
||||
"min_p": "Minimum probability threshold (0.0 = disabled)",
|
||||
"repeat_last_n": "Number of tokens to consider for repetition penalty",
|
||||
"repeat_penalty": "Penalty for repetition (1.0 = no penalty)",
|
||||
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
|
||||
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
|
||||
"mirostat": "Mirostat sampling algorithm (0=disabled, 1=Mirostat 1.0, 2=Mirostat 2.0)",
|
||||
"mirostat_tau": "Mirostat target entropy",
|
||||
"mirostat_eta": "Mirostat learning rate",
|
||||
"numa": "Enable NUMA optimization",
|
||||
"num_batch": "Batch size for processing",
|
||||
"num_gpu": "Number of GPUs to use (-1 for auto)",
|
||||
"main_gpu": "Main GPU index",
|
||||
"low_vram": "Optimize for low VRAM",
|
||||
"num_thread": "Number of CPU threads (0 for auto)",
|
||||
"f16_kv": "Use half-precision for key/value cache",
|
||||
"logits_all": "Return logits for all tokens",
|
||||
"vocab_only": "Only load vocabulary",
|
||||
"use_mmap": "Use memory mapping for model files",
|
||||
"use_mlock": "Lock model in memory",
|
||||
"embedding_only": "Only use for embeddings",
|
||||
"penalize_newline": "Penalize newline tokens",
|
||||
"stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OllamaEmbeddingOptions(_OllamaOptionsMixin, BindingOptions):
|
||||
"""Options for Ollama embeddings with specialized configuration for embedding tasks."""
|
||||
|
||||
# mandatory name of binding
|
||||
_binding_name: ClassVar[str] = "ollama_embedding"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
|
||||
"""Options for Ollama LLM with specialized configuration for LLM tasks."""
|
||||
|
||||
# mandatory name of binding
|
||||
_binding_name: ClassVar[str] = "ollama_llm"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Binding Options for Gemini
|
||||
# =============================================================================
|
||||
@dataclass
|
||||
class GeminiLLMOptions(BindingOptions):
|
||||
"""Options for Google Gemini models."""
|
||||
|
||||
_binding_name: ClassVar[str] = "gemini_llm"
|
||||
|
||||
temperature: float = DEFAULT_TEMPERATURE
|
||||
top_p: float = 0.95
|
||||
top_k: int = 40
|
||||
max_output_tokens: int | None = None
|
||||
candidate_count: int = 1
|
||||
presence_penalty: float = 0.0
|
||||
frequency_penalty: float = 0.0
|
||||
stop_sequences: List[str] = field(default_factory=list)
|
||||
seed: int | None = None
|
||||
thinking_config: dict | None = None
|
||||
safety_settings: dict | None = None
|
||||
|
||||
_help: ClassVar[dict[str, str]] = {
|
||||
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
|
||||
"top_p": "Nucleus sampling parameter (0.0-1.0)",
|
||||
"top_k": "Limits sampling to the top K tokens (1 disables the limit)",
|
||||
"max_output_tokens": "Maximum tokens generated in the response",
|
||||
"candidate_count": "Number of candidates returned per request",
|
||||
"presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
|
||||
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
|
||||
"stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')",
|
||||
"seed": "Random seed for reproducible generation (leave empty for random)",
|
||||
"thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')",
|
||||
"safety_settings": "JSON object with Gemini safety settings overrides",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeminiEmbeddingOptions(BindingOptions):
|
||||
"""Options for Google Gemini embedding models."""
|
||||
|
||||
_binding_name: ClassVar[str] = "gemini_embedding"
|
||||
|
||||
task_type: str | None = None
|
||||
|
||||
_help: ClassVar[dict[str, str]] = {
|
||||
"task_type": "Task type for embedding optimization. If not specified, automatically determined from context (RETRIEVAL_QUERY for queries, RETRIEVAL_DOCUMENT for documents). Supported types: RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, CODE_RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION",
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Binding Options for OpenAI
|
||||
# =============================================================================
|
||||
#
|
||||
# OpenAI binding options provide configuration for OpenAI's API and Azure OpenAI.
|
||||
# These options control model behavior, sampling parameters, and generation settings.
|
||||
# The parameters are based on OpenAI's API specification and provide fine-grained
|
||||
# control over model inference and generation.
|
||||
#
|
||||
# =============================================================================
|
||||
@dataclass
|
||||
class OpenAILLMOptions(BindingOptions):
|
||||
"""Options for OpenAI LLM with configuration for OpenAI and Azure OpenAI API calls."""
|
||||
|
||||
# mandatory name of binding
|
||||
_binding_name: ClassVar[str] = "openai_llm"
|
||||
|
||||
# Sampling and generation parameters
|
||||
frequency_penalty: float = 0.0 # Penalty for token frequency (-2.0 to 2.0)
|
||||
max_completion_tokens: int = None # Maximum number of tokens to generate
|
||||
presence_penalty: float = 0.0 # Penalty for token presence (-2.0 to 2.0)
|
||||
reasoning_effort: str = "medium" # Reasoning effort level (low, medium, high)
|
||||
safety_identifier: str = "" # Safety identifier for content filtering
|
||||
service_tier: str = "" # Service tier for API usage
|
||||
stop: List[str] = field(default_factory=list) # Stop sequences
|
||||
temperature: float = DEFAULT_TEMPERATURE # Controls randomness (0.0 to 2.0)
|
||||
top_p: float = 1.0 # Nucleus sampling parameter (0.0 to 1.0)
|
||||
max_tokens: int = None # Maximum number of tokens to generate(deprecated, use max_completion_tokens instead)
|
||||
extra_body: dict = None # Extra body parameters for OpenRouter of vLLM
|
||||
|
||||
# Help descriptions
|
||||
_help: ClassVar[dict[str, str]] = {
|
||||
"frequency_penalty": "Penalty for token frequency (-2.0 to 2.0, positive values discourage repetition)",
|
||||
"max_completion_tokens": "Maximum number of tokens to generate (optional, leave empty for model default)",
|
||||
"presence_penalty": "Penalty for token presence (-2.0 to 2.0, positive values encourage new topics)",
|
||||
"reasoning_effort": "Reasoning effort level for o1 models (low, medium, high)",
|
||||
"safety_identifier": "Safety identifier for content filtering (optional)",
|
||||
"service_tier": "Service tier for API usage (optional)",
|
||||
"stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
|
||||
"temperature": "Controls randomness (0.0-2.0, higher = more creative)",
|
||||
"top_p": "Nucleus sampling parameter (0.0-1.0, lower = more focused)",
|
||||
"max_tokens": "Maximum number of tokens to generate (deprecated, use max_completion_tokens instead)",
|
||||
"extra_body": 'Extra body parameters for OpenRouter of vLLM (JSON dict, e.g., \'"reasoning": {"reasoning": {"enabled": false}}\')',
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Section - For Testing and Sample Generation
|
||||
# =============================================================================
|
||||
#
|
||||
# When run as a script, this module:
|
||||
# 1. Generates and prints a sample .env file with all binding options
|
||||
# 2. If "test" argument is provided, demonstrates argument parsing with Ollama binding
|
||||
#
|
||||
# Usage:
|
||||
# python -m lightrag.llm.binding_options # Generate .env sample
|
||||
# python -m lightrag.llm.binding_options test # Test argument parsing
|
||||
#
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
import dotenv
|
||||
# from io import StringIO
|
||||
|
||||
dotenv.load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
# env_strstream = StringIO(
|
||||
# ("OLLAMA_LLM_TEMPERATURE=0.1\nOLLAMA_EMBEDDING_TEMPERATURE=0.2\n")
|
||||
# )
|
||||
# # Load environment variables from .env file
|
||||
# dotenv.load_dotenv(stream=env_strstream)
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "test":
|
||||
# Add arguments for OllamaEmbeddingOptions, OllamaLLMOptions, and OpenAILLMOptions
|
||||
parser = ArgumentParser(description="Test binding options")
|
||||
OllamaEmbeddingOptions.add_args(parser)
|
||||
OllamaLLMOptions.add_args(parser)
|
||||
OpenAILLMOptions.add_args(parser)
|
||||
|
||||
# Parse arguments test
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--ollama-embedding-num_ctx",
|
||||
"1024",
|
||||
"--ollama-llm-num_ctx",
|
||||
"2048",
|
||||
"--openai-llm-temperature",
|
||||
"0.7",
|
||||
"--openai-llm-max_completion_tokens",
|
||||
"1000",
|
||||
"--openai-llm-stop",
|
||||
'["</s>", "\\n\\n"]',
|
||||
"--openai-llm-reasoning",
|
||||
'{"effort": "high", "max_tokens": 2000, "exclude": false, "enabled": true}',
|
||||
]
|
||||
)
|
||||
print("Final args for LLM and Embedding:")
|
||||
print(f"{args}\n")
|
||||
|
||||
print("Ollama LLM options:")
|
||||
print(OllamaLLMOptions.options_dict(args))
|
||||
|
||||
print("\nOllama Embedding options:")
|
||||
print(OllamaEmbeddingOptions.options_dict(args))
|
||||
|
||||
print("\nOpenAI LLM options:")
|
||||
print(OpenAILLMOptions.options_dict(args))
|
||||
|
||||
# Test creating OpenAI options instance
|
||||
openai_options = OpenAILLMOptions(
|
||||
temperature=0.8,
|
||||
max_completion_tokens=1500,
|
||||
frequency_penalty=0.1,
|
||||
presence_penalty=0.2,
|
||||
stop=["<|end|>", "\n\n"],
|
||||
)
|
||||
print("\nOpenAI LLM options instance:")
|
||||
print(openai_options.asdict())
|
||||
|
||||
# Test creating OpenAI options instance with reasoning parameter
|
||||
openai_options_with_reasoning = OpenAILLMOptions(
|
||||
temperature=0.9,
|
||||
max_completion_tokens=2000,
|
||||
reasoning={
|
||||
"effort": "medium",
|
||||
"max_tokens": 1500,
|
||||
"exclude": True,
|
||||
"enabled": True,
|
||||
},
|
||||
)
|
||||
print("\nOpenAI LLM options instance with reasoning:")
|
||||
print(openai_options_with_reasoning.asdict())
|
||||
|
||||
# Test dict parsing functionality
|
||||
print("\n" + "=" * 50)
|
||||
print("TESTING DICT PARSING FUNCTIONALITY")
|
||||
print("=" * 50)
|
||||
|
||||
# Test valid JSON dict parsing
|
||||
test_parser = ArgumentParser(description="Test dict parsing")
|
||||
OpenAILLMOptions.add_args(test_parser)
|
||||
|
||||
try:
|
||||
test_args = test_parser.parse_args(
|
||||
["--openai-llm-reasoning", '{"effort": "low", "max_tokens": 1000}']
|
||||
)
|
||||
print("✓ Valid JSON dict parsing successful:")
|
||||
print(
|
||||
f" Parsed reasoning: {OpenAILLMOptions.options_dict(test_args)['reasoning']}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"✗ Valid JSON dict parsing failed: {e}")
|
||||
|
||||
# Test invalid JSON dict parsing
|
||||
try:
|
||||
test_args = test_parser.parse_args(
|
||||
[
|
||||
"--openai-llm-reasoning",
|
||||
'{"effort": "low", "max_tokens": 1000', # Missing closing brace
|
||||
]
|
||||
)
|
||||
print("✗ Invalid JSON should have failed but didn't")
|
||||
except SystemExit:
|
||||
print("✓ Invalid JSON dict parsing correctly rejected")
|
||||
except Exception as e:
|
||||
print(f"✓ Invalid JSON dict parsing correctly rejected: {e}")
|
||||
|
||||
# Test non-dict JSON parsing
|
||||
try:
|
||||
test_args = test_parser.parse_args(
|
||||
[
|
||||
"--openai-llm-reasoning",
|
||||
'["not", "a", "dict"]', # Array instead of dict
|
||||
]
|
||||
)
|
||||
print("✗ Non-dict JSON should have failed but didn't")
|
||||
except SystemExit:
|
||||
print("✓ Non-dict JSON parsing correctly rejected")
|
||||
except Exception as e:
|
||||
print(f"✓ Non-dict JSON parsing correctly rejected: {e}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("TESTING ENVIRONMENT VARIABLE SUPPORT")
|
||||
print("=" * 50)
|
||||
|
||||
# Test environment variable support for dict
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_LLM_REASONING"] = (
|
||||
'{"effort": "high", "max_tokens": 3000, "exclude": false}'
|
||||
)
|
||||
|
||||
env_parser = ArgumentParser(description="Test env var dict parsing")
|
||||
OpenAILLMOptions.add_args(env_parser)
|
||||
|
||||
try:
|
||||
env_args = env_parser.parse_args(
|
||||
[]
|
||||
) # No command line args, should use env var
|
||||
reasoning_from_env = OpenAILLMOptions.options_dict(env_args).get(
|
||||
"reasoning"
|
||||
)
|
||||
if reasoning_from_env:
|
||||
print("✓ Environment variable dict parsing successful:")
|
||||
print(f" Parsed reasoning from env: {reasoning_from_env}")
|
||||
else:
|
||||
print("✗ Environment variable dict parsing failed: No reasoning found")
|
||||
except Exception as e:
|
||||
print(f"✗ Environment variable dict parsing failed: {e}")
|
||||
finally:
|
||||
# Clean up environment variable
|
||||
if "OPENAI_LLM_REASONING" in os.environ:
|
||||
del os.environ["OPENAI_LLM_REASONING"]
|
||||
|
||||
else:
|
||||
print(BindingOptions.generate_dot_env_sample())
|
||||
@@ -0,0 +1,69 @@
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("lmdeploy"):
|
||||
pm.install("lmdeploy")
|
||||
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
|
||||
import numpy as np
|
||||
import aiohttp
|
||||
import base64
|
||||
import struct
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def siliconcloud_embedding(
|
||||
texts: list[str],
|
||||
model: str = "netease-youdao/bce-embedding-base_v1",
|
||||
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
|
||||
max_token_size: int = 8192,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key and not api_key.startswith("Bearer "):
|
||||
api_key = "Bearer " + api_key
|
||||
|
||||
headers = {"Authorization": api_key, "Content-Type": "application/json"}
|
||||
|
||||
truncate_texts = [text[0:max_token_size] for text in texts]
|
||||
|
||||
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
|
||||
|
||||
base64_strings = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(base_url, headers=headers, json=payload) as response:
|
||||
content = await response.json()
|
||||
if "code" in content:
|
||||
raise ValueError(content)
|
||||
base64_strings = [item["embedding"] for item in content["data"]]
|
||||
|
||||
embeddings = []
|
||||
for string in base64_strings:
|
||||
decode_bytes = base64.b64decode(string)
|
||||
n = len(decode_bytes) // 4
|
||||
float_array = struct.unpack("<" + "f" * n, decode_bytes)
|
||||
embeddings.append(float_array)
|
||||
return np.array(embeddings)
|
||||
623
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/gemini.py
Normal file
623
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/gemini.py
Normal file
@@ -0,0 +1,623 @@
|
||||
"""
|
||||
Gemini LLM binding for LightRAG.
|
||||
|
||||
This module provides asynchronous helpers that adapt Google's Gemini models
|
||||
to the same interface used by the rest of the LightRAG LLM bindings. The
|
||||
implementation mirrors the OpenAI helpers while relying on the official
|
||||
``google-genai`` client under the hood.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
remove_think_tags,
|
||||
safe_unicode_decode,
|
||||
wrap_embedding_func_with_attrs,
|
||||
)
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
# Install the Google Gemini client and its dependencies on demand
|
||||
if not pm.is_installed("google-genai"):
|
||||
pm.install("google-genai")
|
||||
if not pm.is_installed("google-api-core"):
|
||||
pm.install("google-api-core")
|
||||
|
||||
from google import genai # type: ignore
|
||||
from google.genai import types # type: ignore
|
||||
from google.api_core import exceptions as google_api_exceptions # type: ignore
|
||||
|
||||
|
||||
class InvalidResponseError(Exception):
|
||||
"""Custom exception class for triggering retry mechanism when Gemini returns empty responses"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _get_gemini_client(
|
||||
api_key: str, base_url: str | None, timeout: int | None = None
|
||||
) -> genai.Client:
|
||||
"""
|
||||
Create (or fetch cached) Gemini client.
|
||||
|
||||
Args:
|
||||
api_key: Google Gemini API key (not used in Vertex AI mode).
|
||||
base_url: Optional custom API endpoint.
|
||||
timeout: Optional request timeout in milliseconds.
|
||||
|
||||
Returns:
|
||||
genai.Client: Configured Gemini client instance.
|
||||
"""
|
||||
client_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Add Vertex AI support
|
||||
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
if use_vertexai:
|
||||
# Vertex AI mode: use project/location, NOT api_key
|
||||
client_kwargs["vertexai"] = True
|
||||
project = os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
if project:
|
||||
location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
|
||||
client_kwargs["project"] = project
|
||||
if location:
|
||||
client_kwargs["location"] = location
|
||||
else:
|
||||
raise ValueError(
|
||||
"GOOGLE_CLOUD_PROJECT must be set when using Vertex AI mode"
|
||||
)
|
||||
else:
|
||||
# Standard Gemini API mode: use api_key
|
||||
client_kwargs["api_key"] = api_key
|
||||
|
||||
if base_url and base_url != "DEFAULT_GEMINI_ENDPOINT" or timeout is not None:
|
||||
try:
|
||||
http_options_kwargs = {}
|
||||
if base_url and base_url != "DEFAULT_GEMINI_ENDPOINT":
|
||||
http_options_kwargs["base_url"] = base_url
|
||||
if timeout is not None:
|
||||
http_options_kwargs["timeout"] = timeout
|
||||
|
||||
client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs)
|
||||
except Exception as e:
|
||||
logger.error("Failed to apply custom Gemini http_options: %s", e)
|
||||
raise e
|
||||
|
||||
return genai.Client(**client_kwargs)
|
||||
|
||||
|
||||
def _ensure_api_key(api_key: str | None) -> str:
|
||||
# In Vertex AI mode, API key is not required
|
||||
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
if use_vertexai:
|
||||
# Return empty string for Vertex AI mode (not used)
|
||||
return ""
|
||||
|
||||
key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
if not key:
|
||||
raise ValueError(
|
||||
"Gemini API key not provided. "
|
||||
"Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
|
||||
)
|
||||
return key
|
||||
|
||||
|
||||
def _build_generation_config(
|
||||
base_config: dict[str, Any] | None,
|
||||
system_prompt: str | None,
|
||||
keyword_extraction: bool,
|
||||
) -> types.GenerateContentConfig | None:
|
||||
config_data = dict(base_config or {})
|
||||
|
||||
if system_prompt:
|
||||
if config_data.get("system_instruction"):
|
||||
config_data["system_instruction"] = (
|
||||
f"{config_data['system_instruction']}\n{system_prompt}"
|
||||
)
|
||||
else:
|
||||
config_data["system_instruction"] = system_prompt
|
||||
|
||||
if keyword_extraction and not config_data.get("response_mime_type"):
|
||||
config_data["response_mime_type"] = "application/json"
|
||||
|
||||
# Remove entries that are explicitly set to None to avoid type errors
|
||||
sanitized = {
|
||||
key: value
|
||||
for key, value in config_data.items()
|
||||
if value is not None and value != ""
|
||||
}
|
||||
|
||||
if not sanitized:
|
||||
return None
|
||||
|
||||
return types.GenerateContentConfig(**sanitized)
|
||||
|
||||
|
||||
def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
|
||||
if not history_messages:
|
||||
return ""
|
||||
|
||||
history_lines: list[str] = []
|
||||
for message in history_messages:
|
||||
role = message.get("role", "user")
|
||||
content = message.get("content", "")
|
||||
history_lines.append(f"[{role}] {content}")
|
||||
|
||||
return "\n".join(history_lines)
|
||||
|
||||
|
||||
def _extract_response_text(
|
||||
response: Any, extract_thoughts: bool = False
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Extract text content from Gemini response, separating regular content from thoughts.
|
||||
|
||||
Args:
|
||||
response: Gemini API response object
|
||||
extract_thoughts: Whether to extract thought content separately
|
||||
|
||||
Returns:
|
||||
Tuple of (regular_text, thought_text)
|
||||
"""
|
||||
candidates = getattr(response, "candidates", None)
|
||||
if not candidates:
|
||||
return ("", "")
|
||||
|
||||
regular_parts: list[str] = []
|
||||
thought_parts: list[str] = []
|
||||
|
||||
for candidate in candidates:
|
||||
if not getattr(candidate, "content", None):
|
||||
continue
|
||||
# Use 'or []' to handle None values from parts attribute
|
||||
for part in getattr(candidate.content, "parts", None) or []:
|
||||
text = getattr(part, "text", None)
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# Check if this part is thought content using the 'thought' attribute
|
||||
is_thought = getattr(part, "thought", False)
|
||||
|
||||
if is_thought and extract_thoughts:
|
||||
thought_parts.append(text)
|
||||
elif not is_thought:
|
||||
regular_parts.append(text)
|
||||
|
||||
return ("\n".join(regular_parts), "\n".join(thought_parts))
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(google_api_exceptions.InternalServerError)
|
||||
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
|
||||
| retry_if_exception_type(google_api_exceptions.BadGateway)
|
||||
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
|
||||
| retry_if_exception_type(google_api_exceptions.Aborted)
|
||||
| retry_if_exception_type(google_api_exceptions.Unknown)
|
||||
| retry_if_exception_type(InvalidResponseError)
|
||||
),
|
||||
)
|
||||
async def gemini_complete_if_cache(
|
||||
model: str,
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
enable_cot: bool = False,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
stream: bool | None = None,
|
||||
keyword_extraction: bool = False,
|
||||
generation_config: dict[str, Any] | None = None,
|
||||
timeout: int | None = None,
|
||||
**_: Any,
|
||||
) -> str | AsyncIterator[str]:
|
||||
"""
|
||||
Complete a prompt using Gemini's API with Chain of Thought (COT) support.
|
||||
|
||||
This function supports automatic integration of reasoning content from Gemini models
|
||||
that provide Chain of Thought capabilities via the thinking_config API feature.
|
||||
|
||||
COT Integration:
|
||||
- When enable_cot=True: Thought content is wrapped in <think>...</think> tags
|
||||
- When enable_cot=False: Thought content is filtered out, only regular content returned
|
||||
- Thought content is identified by the 'thought' attribute on response parts
|
||||
- Requires thinking_config to be enabled in generation_config for API to return thoughts
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use.
|
||||
prompt: The prompt to complete.
|
||||
system_prompt: Optional system prompt to include.
|
||||
history_messages: Optional list of previous messages in the conversation.
|
||||
api_key: Optional Gemini API key. If None, uses environment variable.
|
||||
base_url: Optional custom API endpoint.
|
||||
generation_config: Optional generation configuration dict.
|
||||
keyword_extraction: Whether to use JSON response format.
|
||||
token_tracker: Optional token usage tracker for monitoring API usage.
|
||||
stream: Whether to stream the response.
|
||||
hashing_kv: Storage interface (for interface parity with other bindings).
|
||||
enable_cot: Whether to include Chain of Thought content in the response.
|
||||
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
|
||||
**_: Additional keyword arguments (ignored).
|
||||
|
||||
Returns:
|
||||
The completed text (with COT content if enable_cot=True) or an async iterator
|
||||
of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the response from Gemini is empty.
|
||||
ValueError: If API key is not provided or configured.
|
||||
"""
|
||||
key = _ensure_api_key(api_key)
|
||||
# Convert timeout from seconds to milliseconds for Gemini API
|
||||
timeout_ms = timeout * 1000 if timeout else None
|
||||
client = _get_gemini_client(key, base_url, timeout_ms)
|
||||
|
||||
history_block = _format_history_messages(history_messages)
|
||||
prompt_sections = []
|
||||
if history_block:
|
||||
prompt_sections.append(history_block)
|
||||
prompt_sections.append(f"[user] {prompt}")
|
||||
combined_prompt = "\n".join(prompt_sections)
|
||||
|
||||
config_obj = _build_generation_config(
|
||||
generation_config,
|
||||
system_prompt=system_prompt,
|
||||
keyword_extraction=keyword_extraction,
|
||||
)
|
||||
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"contents": [combined_prompt],
|
||||
}
|
||||
if config_obj is not None:
|
||||
request_kwargs["config"] = config_obj
|
||||
|
||||
if stream:
|
||||
|
||||
async def _async_stream() -> AsyncIterator[str]:
|
||||
# COT state tracking for streaming
|
||||
cot_active = False
|
||||
cot_started = False
|
||||
initial_content_seen = False
|
||||
usage_metadata = None
|
||||
|
||||
try:
|
||||
# Use native async streaming from genai SDK
|
||||
# Note: generate_content_stream returns Awaitable[AsyncIterator], need to await first
|
||||
stream = await client.aio.models.generate_content_stream(
|
||||
**request_kwargs
|
||||
)
|
||||
async for chunk in stream:
|
||||
usage = getattr(chunk, "usage_metadata", None)
|
||||
if usage is not None:
|
||||
usage_metadata = usage
|
||||
|
||||
# Extract both regular and thought content
|
||||
regular_text, thought_text = _extract_response_text(
|
||||
chunk, extract_thoughts=True
|
||||
)
|
||||
|
||||
if enable_cot:
|
||||
# Process regular content
|
||||
if regular_text:
|
||||
if not initial_content_seen:
|
||||
initial_content_seen = True
|
||||
|
||||
# Close COT section if it was active
|
||||
if cot_active:
|
||||
yield "</think>"
|
||||
cot_active = False
|
||||
|
||||
# Process and yield regular content
|
||||
if "\\u" in regular_text:
|
||||
regular_text = safe_unicode_decode(
|
||||
regular_text.encode("utf-8")
|
||||
)
|
||||
yield regular_text
|
||||
|
||||
# Process thought content
|
||||
if thought_text:
|
||||
if not initial_content_seen and not cot_started:
|
||||
# Start COT section
|
||||
yield "<think>"
|
||||
cot_active = True
|
||||
cot_started = True
|
||||
|
||||
# Yield thought content if COT is active
|
||||
if cot_active:
|
||||
if "\\u" in thought_text:
|
||||
thought_text = safe_unicode_decode(
|
||||
thought_text.encode("utf-8")
|
||||
)
|
||||
yield thought_text
|
||||
else:
|
||||
# COT disabled - only yield regular content
|
||||
if regular_text:
|
||||
if "\\u" in regular_text:
|
||||
regular_text = safe_unicode_decode(
|
||||
regular_text.encode("utf-8")
|
||||
)
|
||||
yield regular_text
|
||||
|
||||
# Ensure COT is properly closed if still active
|
||||
if cot_active:
|
||||
yield "</think>"
|
||||
cot_active = False
|
||||
|
||||
except Exception as exc:
|
||||
# Try to close COT tag before re-raising
|
||||
if cot_active:
|
||||
try:
|
||||
yield "</think>"
|
||||
except Exception:
|
||||
pass
|
||||
raise exc
|
||||
finally:
|
||||
# Track token usage after streaming completes
|
||||
if token_tracker and usage_metadata:
|
||||
token_tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": getattr(
|
||||
usage_metadata, "prompt_token_count", 0
|
||||
),
|
||||
"completion_tokens": getattr(
|
||||
usage_metadata, "candidates_token_count", 0
|
||||
),
|
||||
"total_tokens": getattr(
|
||||
usage_metadata, "total_token_count", 0
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return _async_stream()
|
||||
|
||||
# Non-streaming: use native async client
|
||||
response = await client.aio.models.generate_content(**request_kwargs)
|
||||
|
||||
# Extract both regular text and thought text
|
||||
regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
|
||||
|
||||
# Apply COT filtering logic based on enable_cot parameter
|
||||
if enable_cot:
|
||||
# Include thought content wrapped in <think> tags
|
||||
if thought_text and thought_text.strip():
|
||||
if not regular_text or regular_text.strip() == "":
|
||||
# Only thought content available
|
||||
final_text = f"<think>{thought_text}</think>"
|
||||
else:
|
||||
# Both content types present: prepend thought to regular content
|
||||
final_text = f"<think>{thought_text}</think>{regular_text}"
|
||||
else:
|
||||
# No thought content, use regular content only
|
||||
final_text = regular_text or ""
|
||||
else:
|
||||
# Filter out thought content, return only regular content
|
||||
final_text = regular_text or ""
|
||||
|
||||
if not final_text:
|
||||
raise InvalidResponseError("Gemini response did not contain any text content.")
|
||||
|
||||
if "\\u" in final_text:
|
||||
final_text = safe_unicode_decode(final_text.encode("utf-8"))
|
||||
|
||||
final_text = remove_think_tags(final_text)
|
||||
|
||||
usage = getattr(response, "usage_metadata", None)
|
||||
if token_tracker and usage:
|
||||
token_tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
|
||||
"completion_tokens": getattr(usage, "candidates_token_count", 0),
|
||||
"total_tokens": getattr(usage, "total_token_count", 0),
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug("Gemini response length: %s", len(final_text))
|
||||
return final_text
|
||||
|
||||
|
||||
async def gemini_model_complete(
|
||||
prompt: str,
|
||||
system_prompt: str | None = None,
|
||||
history_messages: list[dict[str, Any]] | None = None,
|
||||
keyword_extraction: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> str | AsyncIterator[str]:
|
||||
hashing_kv = kwargs.get("hashing_kv")
|
||||
model_name = None
|
||||
if hashing_kv is not None:
|
||||
model_name = hashing_kv.global_config.get("llm_model_name")
|
||||
if model_name is None:
|
||||
model_name = kwargs.pop("model_name", None)
|
||||
if model_name is None:
|
||||
raise ValueError("Gemini model name not provided in configuration.")
|
||||
|
||||
return await gemini_complete_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
keyword_extraction=keyword_extraction,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=1536,
|
||||
max_token_size=2048,
|
||||
model_name="gemini-embedding-001",
|
||||
supports_asymmetric=True,
|
||||
)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(google_api_exceptions.InternalServerError)
|
||||
| retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google_api_exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google_api_exceptions.GatewayTimeout)
|
||||
| retry_if_exception_type(google_api_exceptions.BadGateway)
|
||||
| retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
|
||||
| retry_if_exception_type(google_api_exceptions.Aborted)
|
||||
| retry_if_exception_type(google_api_exceptions.Unknown)
|
||||
),
|
||||
)
|
||||
async def gemini_embed(
|
||||
texts: list[str],
|
||||
model: str = "gemini-embedding-001",
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
embedding_dim: int | None = None,
|
||||
max_token_size: int | None = None,
|
||||
task_type: str | None = None,
|
||||
timeout: int | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
context: str = "document",
|
||||
) -> np.ndarray:
|
||||
"""Generate embeddings for a list of texts using Gemini's API.
|
||||
|
||||
This function uses Google's Gemini embedding model to generate text embeddings.
|
||||
It supports dynamic dimension control and automatic normalization for dimensions
|
||||
less than 3072.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed.
|
||||
model: The Gemini embedding model to use. Default is "gemini-embedding-001".
|
||||
base_url: Optional custom API endpoint.
|
||||
api_key: Optional Gemini API key. If None, uses environment variables.
|
||||
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||
Do NOT manually pass this parameter when calling the function directly.
|
||||
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
|
||||
or the EMBEDDING_DIM environment variable.
|
||||
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
|
||||
max_token_size: Maximum tokens per text. This parameter is automatically
|
||||
injected by the EmbeddingFunc wrapper when the underlying function
|
||||
signature supports it (via inspect.signature check). Gemini API will
|
||||
automatically truncate texts exceeding this limit (autoTruncate=True
|
||||
by default), so no client-side truncation is needed.
|
||||
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
|
||||
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
|
||||
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
|
||||
QUESTION_ANSWERING, FACT_VERIFICATION.
|
||||
timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
|
||||
token_tracker: Optional token usage tracker for monitoring API usage.
|
||||
context: The embedding context - "query" for search queries, "document" for indexed content.
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
|
||||
when supports_asymmetric=True. Default is "document".
|
||||
|
||||
Returns:
|
||||
A numpy array of embeddings, one per input text. For dimensions < 3072,
|
||||
the embeddings are L2-normalized to ensure optimal semantic similarity performance.
|
||||
|
||||
Raises:
|
||||
ValueError: If API key is not provided or configured.
|
||||
RuntimeError: If the response from Gemini is invalid or empty.
|
||||
|
||||
Note:
|
||||
- For dimension 3072: Embeddings are already normalized by the API
|
||||
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
|
||||
- Normalization ensures accurate semantic similarity via cosine distance
|
||||
- Gemini API automatically truncates texts exceeding max_token_size (autoTruncate=True)
|
||||
"""
|
||||
# Note: max_token_size is received but not used for client-side truncation.
|
||||
# Gemini API handles truncation automatically with autoTruncate=True (default).
|
||||
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
|
||||
|
||||
key = _ensure_api_key(api_key)
|
||||
# Convert timeout from seconds to milliseconds for Gemini API
|
||||
timeout_ms = timeout * 1000 if timeout else None
|
||||
client = _get_gemini_client(key, base_url, timeout_ms)
|
||||
|
||||
# Prepare embedding configuration
|
||||
config_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Add task_type to config
|
||||
if task_type is None:
|
||||
if context == "query":
|
||||
task_type = "RETRIEVAL_QUERY"
|
||||
elif context == "document":
|
||||
task_type = "RETRIEVAL_DOCUMENT"
|
||||
else:
|
||||
task_type = "RETRIEVAL_DOCUMENT" # Default for backward compatibility
|
||||
config_kwargs["task_type"] = task_type
|
||||
|
||||
# Add output_dimensionality if embedding_dim is provided
|
||||
if embedding_dim is not None:
|
||||
config_kwargs["output_dimensionality"] = embedding_dim
|
||||
|
||||
# Create config object if we have parameters
|
||||
config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
|
||||
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"contents": texts,
|
||||
}
|
||||
if config_obj is not None:
|
||||
request_kwargs["config"] = config_obj
|
||||
|
||||
# Use native async client for embedding
|
||||
response = await client.aio.models.embed_content(**request_kwargs)
|
||||
|
||||
# Extract embeddings from response
|
||||
if not hasattr(response, "embeddings") or not response.embeddings:
|
||||
raise RuntimeError("Gemini response did not contain embeddings.")
|
||||
|
||||
# Convert embeddings to numpy array
|
||||
embeddings = np.array(
|
||||
[np.array(e.values, dtype=np.float32) for e in response.embeddings]
|
||||
)
|
||||
|
||||
# Apply L2 normalization for dimensions < 3072
|
||||
# The 3072 dimension embedding is already normalized by Gemini API
|
||||
if embedding_dim and embedding_dim < 3072:
|
||||
# Normalize each embedding vector to unit length
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
# Avoid division by zero
|
||||
norms = np.where(norms == 0, 1, norms)
|
||||
embeddings = embeddings / norms
|
||||
logger.debug(
|
||||
f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
|
||||
)
|
||||
|
||||
# Track token usage if tracker is provided
|
||||
# Note: Gemini embedding API may not provide usage metadata
|
||||
if token_tracker and hasattr(response, "usage_metadata"):
|
||||
usage = response.usage_metadata
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(usage, "prompt_token_count", 0),
|
||||
"total_tokens": getattr(usage, "total_token_count", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
logger.debug(
|
||||
f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
__all__ = [
|
||||
"gemini_complete_if_cache",
|
||||
"gemini_model_complete",
|
||||
"gemini_embed",
|
||||
]
|
||||
206
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/hf.py
Normal file
206
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/hf.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import copy
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("transformers"):
|
||||
pm.install("transformers")
|
||||
if not pm.is_installed("torch"):
|
||||
pm.install("torch")
|
||||
if not pm.is_installed("numpy"):
|
||||
pm.install("numpy")
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
import torch
|
||||
import numpy as np
|
||||
from lightrag.utils import wrap_embedding_func_with_attrs
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_hf_model(model_name):
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
if hf_tokenizer.pad_token is None:
|
||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||
|
||||
return hf_model, hf_tokenizer
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def hf_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if enable_cot:
|
||||
from lightrag.utils import logger
|
||||
|
||||
logger.debug(
|
||||
"enable_cot=True is not supported for Hugging Face local models and will be ignored."
|
||||
)
|
||||
model_name = model
|
||||
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
kwargs.pop("hashing_kv", None)
|
||||
input_prompt = ""
|
||||
try:
|
||||
input_prompt = hf_tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
ori_message = copy.deepcopy(messages)
|
||||
if messages[0]["role"] == "system":
|
||||
messages[1]["content"] = (
|
||||
"<system>"
|
||||
+ messages[0]["content"]
|
||||
+ "</system>\n"
|
||||
+ messages[1]["content"]
|
||||
)
|
||||
messages = messages[1:]
|
||||
input_prompt = hf_tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
len_message = len(ori_message)
|
||||
for msgid in range(len_message):
|
||||
input_prompt = (
|
||||
input_prompt
|
||||
+ "<"
|
||||
+ ori_message[msgid]["role"]
|
||||
+ ">"
|
||||
+ ori_message[msgid]["content"]
|
||||
+ "</"
|
||||
+ ori_message[msgid]["role"]
|
||||
+ ">\n"
|
||||
)
|
||||
|
||||
input_ids = hf_tokenizer(
|
||||
input_prompt, return_tensors="pt", padding=True, truncation=True
|
||||
).to("cuda")
|
||||
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
|
||||
output = hf_model.generate(
|
||||
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
|
||||
)
|
||||
response_text = hf_tokenizer.decode(
|
||||
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
||||
)
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
async def hf_model_complete(
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
keyword_extraction=False,
|
||||
enable_cot: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
kwargs.pop("keyword_extraction", None)
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
result = await hf_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=1024,
|
||||
max_token_size=8192,
|
||||
model_name="hf_embedding_model",
|
||||
supports_asymmetric=True,
|
||||
)
|
||||
async def hf_embed(
|
||||
texts: list[str],
|
||||
tokenizer,
|
||||
embed_model,
|
||||
context: str = "document",
|
||||
query_prefix: str | None = None,
|
||||
document_prefix: str | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Generate embeddings for a list of texts using a Hugging Face model.
|
||||
|
||||
Args:
|
||||
texts (list[str]): List of input texts to embed.
|
||||
tokenizer: Hugging Face tokenizer.
|
||||
embed_model: Hugging Face model for generating embeddings.
|
||||
context (str): Context indicating whether the texts are "query" or "document".
|
||||
query_prefix (str | None): Optional prefix to add to query texts.
|
||||
document_prefix (str | None): Optional prefix to add to document texts.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of embeddings.
|
||||
"""
|
||||
# Detect the appropriate device
|
||||
if torch.cuda.is_available():
|
||||
device = next(embed_model.parameters()).device # Use CUDA if available
|
||||
elif torch.backends.mps.is_available():
|
||||
device = torch.device("mps") # Use MPS for Apple Silicon
|
||||
else:
|
||||
device = torch.device("cpu") # Fallback to CPU
|
||||
|
||||
# Move the model to the detected device
|
||||
embed_model = embed_model.to(device)
|
||||
|
||||
# Apply context-based prefixes if provided
|
||||
if context == "query" and query_prefix:
|
||||
texts = [query_prefix + text for text in texts]
|
||||
elif context == "document" and document_prefix:
|
||||
texts = [document_prefix + text for text in texts]
|
||||
|
||||
# Tokenize the input texts and move them to the same device
|
||||
encoded_texts = tokenizer(
|
||||
texts, return_tensors="pt", padding=True, truncation=True
|
||||
).to(device)
|
||||
|
||||
# Perform inference
|
||||
with torch.no_grad():
|
||||
outputs = embed_model(
|
||||
input_ids=encoded_texts["input_ids"],
|
||||
attention_mask=encoded_texts["attention_mask"],
|
||||
)
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
# Convert embeddings to NumPy
|
||||
if embeddings.dtype == torch.bfloat16:
|
||||
return embeddings.detach().to(torch.float32).cpu().numpy()
|
||||
else:
|
||||
return embeddings.detach().cpu().numpy()
|
||||
183
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/jina.py
Normal file
183
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/jina.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import os
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("aiohttp"):
|
||||
pm.install("aiohttp")
|
||||
if not pm.is_installed("tenacity"):
|
||||
pm.install("tenacity")
|
||||
|
||||
import numpy as np
|
||||
import base64
|
||||
import aiohttp
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
||||
|
||||
|
||||
async def fetch_data(url, headers, data):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
|
||||
# Check if the error response is HTML (common for 502, 503, etc.)
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
is_html_error = (
|
||||
error_text.strip().startswith("<!DOCTYPE html>")
|
||||
or "text/html" in content_type
|
||||
)
|
||||
|
||||
if is_html_error:
|
||||
# Provide clean, user-friendly error messages for HTML error pages
|
||||
if response.status == 502:
|
||||
clean_error = "Bad Gateway (502) - Jina AI service temporarily unavailable. Please try again in a few minutes."
|
||||
elif response.status == 503:
|
||||
clean_error = "Service Unavailable (503) - Jina AI service is temporarily overloaded. Please try again later."
|
||||
elif response.status == 504:
|
||||
clean_error = "Gateway Timeout (504) - Jina AI service request timed out. Please try again."
|
||||
else:
|
||||
clean_error = f"HTTP {response.status} - Jina AI service error. Please try again later."
|
||||
else:
|
||||
# Use original error text if it's not HTML
|
||||
clean_error = error_text
|
||||
|
||||
logger.error(f"Jina API error {response.status}: {clean_error}")
|
||||
raise aiohttp.ClientResponseError(
|
||||
request_info=response.request_info,
|
||||
history=response.history,
|
||||
status=response.status,
|
||||
message=f"Jina API error: {clean_error}",
|
||||
)
|
||||
response_json = await response.json()
|
||||
data_list = response_json.get("data", [])
|
||||
return data_list
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=2048,
|
||||
max_token_size=8192,
|
||||
model_name="jina-embeddings-v4",
|
||||
supports_asymmetric=True,
|
||||
)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(aiohttp.ClientError)
|
||||
| retry_if_exception_type(aiohttp.ClientResponseError)
|
||||
),
|
||||
)
|
||||
async def jina_embed(
|
||||
texts: list[str],
|
||||
model: str = "jina-embeddings-v4",
|
||||
embedding_dim: int = 2048,
|
||||
late_chunking: bool = False,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
context: str | None = None,
|
||||
task: str | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Generate embeddings for a list of texts using Jina AI's API.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed.
|
||||
model: The Jina embedding model to use (default: jina-embeddings-v4).
|
||||
Supported models: jina-embeddings-v3, jina-embeddings-v4, etc.
|
||||
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||
Do NOT manually pass this parameter when calling the function directly.
|
||||
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
|
||||
Manually passing a different value will trigger a warning and be ignored.
|
||||
When provided (by EmbeddingFunc), it will be passed to the Jina API for dimension reduction.
|
||||
late_chunking: Whether to use late chunking.
|
||||
base_url: Optional base URL for the Jina API.
|
||||
api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
|
||||
context: The embedding context - "query" for search queries, "document" for indexed content.
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
|
||||
when supports_asymmetric=True. When ``task`` is left at its default of None,
|
||||
``context`` drives the task selection.
|
||||
task: Embedding task mode. Default is None so that ``context`` (when present)
|
||||
picks the right Jina task:
|
||||
- "retrieval.query" for context="query"
|
||||
- "retrieval.passage" for context="document"
|
||||
- "text-matching" otherwise (true backward-compatible default)
|
||||
Any explicit non-None task value overrides context-based selection.
|
||||
|
||||
|
||||
Returns:
|
||||
A numpy array of embeddings, one per input text.
|
||||
|
||||
Raises:
|
||||
aiohttp.ClientError: If there is a connection error with the Jina API.
|
||||
aiohttp.ClientResponseError: If the Jina API returns an error response.
|
||||
"""
|
||||
if api_key:
|
||||
os.environ["JINA_API_KEY"] = api_key
|
||||
|
||||
if "JINA_API_KEY" not in os.environ:
|
||||
raise ValueError("JINA_API_KEY environment variable is required")
|
||||
|
||||
url = base_url or "https://api.jina.ai/v1/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
|
||||
}
|
||||
|
||||
# Determine task based on context if not explicitly provided
|
||||
if task is None:
|
||||
if context == "query":
|
||||
task = "retrieval.query"
|
||||
elif context == "document":
|
||||
task = "retrieval.passage"
|
||||
else:
|
||||
task = "text-matching" # Default for backward compatibility
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"task": task,
|
||||
"dimensions": embedding_dim,
|
||||
"embedding_type": "base64",
|
||||
"input": texts,
|
||||
}
|
||||
|
||||
# Only add optional parameters if they have non-default values
|
||||
if late_chunking:
|
||||
data["late_chunking"] = late_chunking
|
||||
|
||||
logger.debug(
|
||||
f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
|
||||
)
|
||||
|
||||
try:
|
||||
data_list = await fetch_data(url, headers, data)
|
||||
|
||||
if not data_list:
|
||||
logger.error("Jina API returned empty data list")
|
||||
raise ValueError("Jina API returned empty data list")
|
||||
|
||||
if len(data_list) != len(texts):
|
||||
logger.error(
|
||||
f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
|
||||
)
|
||||
|
||||
embeddings = np.array(
|
||||
[
|
||||
np.frombuffer(base64.b64decode(dp["embedding"]), dtype=np.float32)
|
||||
for dp in data_list
|
||||
]
|
||||
)
|
||||
logger.debug(f"Jina embeddings generated: shape {embeddings.shape}")
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Jina embedding error: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,208 @@
|
||||
import pipmaster as pm
|
||||
from llama_index.core.llms import (
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
ChatResponse,
|
||||
)
|
||||
from typing import List, Optional
|
||||
from lightrag.utils import logger
|
||||
|
||||
# Install required dependencies
|
||||
if not pm.is_installed("llama-index"):
|
||||
pm.install("llama-index")
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.settings import Settings as LlamaIndexSettings
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
)
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
|
||||
def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
|
||||
"""
|
||||
Configure LlamaIndex settings.
|
||||
|
||||
Args:
|
||||
settings: LlamaIndex Settings instance. If None, uses default settings.
|
||||
**kwargs: Additional settings to override/configure
|
||||
"""
|
||||
if settings is None:
|
||||
settings = LlamaIndexSettings()
|
||||
|
||||
# Update settings with any provided kwargs
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(settings, key):
|
||||
setattr(settings, key, value)
|
||||
else:
|
||||
logger.warning(f"Unknown LlamaIndex setting: {key}")
|
||||
|
||||
# Set as global settings
|
||||
LlamaIndexSettings.set_global(settings)
|
||||
return settings
|
||||
|
||||
|
||||
def format_chat_messages(messages):
|
||||
"""Format chat messages into LlamaIndex format."""
|
||||
formatted_messages = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.SYSTEM, content=content)
|
||||
)
|
||||
elif role == "assistant":
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.ASSISTANT, content=content)
|
||||
)
|
||||
elif role == "user":
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.USER, content=content)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unknown role {role}, treating as user message")
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.USER, content=content)
|
||||
)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def llama_index_complete_if_cache(
|
||||
model: str,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
history_messages: List[dict] = [],
|
||||
enable_cot: bool = False,
|
||||
chat_kwargs={},
|
||||
) -> str:
|
||||
"""Complete the prompt using LlamaIndex."""
|
||||
if enable_cot:
|
||||
logger.debug(
|
||||
"enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
|
||||
)
|
||||
try:
|
||||
# Format messages for chat
|
||||
formatted_messages = []
|
||||
|
||||
# Add system message if provided
|
||||
if system_prompt:
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
|
||||
)
|
||||
|
||||
# Add history messages
|
||||
for msg in history_messages:
|
||||
formatted_messages.append(
|
||||
ChatMessage(
|
||||
role=MessageRole.USER
|
||||
if msg["role"] == "user"
|
||||
else MessageRole.ASSISTANT,
|
||||
content=msg["content"],
|
||||
)
|
||||
)
|
||||
|
||||
# Add current prompt
|
||||
formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))
|
||||
|
||||
response: ChatResponse = await model.achat(
|
||||
messages=formatted_messages, **chat_kwargs
|
||||
)
|
||||
|
||||
# In newer versions, the response is in message.content
|
||||
content = response.message.content
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def llama_index_complete(
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=None,
|
||||
enable_cot: bool = False,
|
||||
keyword_extraction=False,
|
||||
settings: LlamaIndexSettings = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Main completion function for LlamaIndex
|
||||
|
||||
Args:
|
||||
prompt: Input prompt
|
||||
system_prompt: Optional system prompt
|
||||
history_messages: Optional chat history
|
||||
keyword_extraction: Whether to extract keywords from response
|
||||
settings: Optional LlamaIndex settings
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
kwargs.pop("keyword_extraction", None)
|
||||
result = await llama_index_complete_if_cache(
|
||||
kwargs.get("llm_instance"),
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def llama_index_embed(
|
||||
texts: list[str],
|
||||
embed_model: BaseEmbedding = None,
|
||||
settings: LlamaIndexSettings = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings using LlamaIndex
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
embed_model: LlamaIndex embedding model
|
||||
settings: Optional LlamaIndex settings
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
if settings:
|
||||
configure_llama_index(settings)
|
||||
|
||||
if embed_model is None:
|
||||
raise ValueError("embed_model must be provided")
|
||||
|
||||
# Use _get_text_embeddings for batch processing
|
||||
embeddings = embed_model._get_text_embeddings(texts)
|
||||
return np.array(embeddings)
|
||||
154
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/lmdeploy.py
Normal file
154
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/lmdeploy.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("lmdeploy"):
|
||||
pm.install("lmdeploy[all]")
|
||||
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_lmdeploy_pipeline(
|
||||
model,
|
||||
tp=1,
|
||||
chat_template=None,
|
||||
log_level="WARNING",
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
):
|
||||
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
|
||||
|
||||
lmdeploy_pipe = pipeline(
|
||||
model_path=model,
|
||||
backend_config=TurbomindEngineConfig(
|
||||
tp=tp, model_format=model_format, quant_policy=quant_policy
|
||||
),
|
||||
chat_template_config=(
|
||||
ChatTemplateConfig(model_name=chat_template) if chat_template else None
|
||||
),
|
||||
log_level="WARNING",
|
||||
)
|
||||
return lmdeploy_pipe
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def lmdeploy_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
chat_template=None,
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
model (str): The path to the model.
|
||||
It could be one of the following options:
|
||||
- i) A local directory path of a turbomind model which is
|
||||
converted by `lmdeploy convert` command or download
|
||||
from ii) and iii).
|
||||
- ii) The model_id of a lmdeploy-quantized model hosted
|
||||
inside a model repo on huggingface.co, such as
|
||||
"InternLM/internlm-chat-20b-4bit",
|
||||
"lmdeploy/llama2-chat-70b-4bit", etc.
|
||||
- iii) The model_id of a model hosted inside a model repo
|
||||
on huggingface.co, such as "internlm/internlm-chat-7b",
|
||||
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
|
||||
and so on.
|
||||
chat_template (str): needed when model is a pytorch model on
|
||||
huggingface.co, such as "internlm-chat-7b",
|
||||
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
|
||||
and when the model name of local path did not match the original model name in HF.
|
||||
tp (int): tensor parallel
|
||||
prompt (Union[str, List[str]]): input texts to be completed.
|
||||
do_preprocess (bool): whether pre-process the messages. Default to
|
||||
True, which means chat_template will be applied.
|
||||
skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be True.
|
||||
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
|
||||
Default to be False, which means greedy decoding will be applied.
|
||||
"""
|
||||
if enable_cot:
|
||||
from lightrag.utils import logger
|
||||
|
||||
logger.debug(
|
||||
"enable_cot=True is not supported for lmdeploy and will be ignored."
|
||||
)
|
||||
try:
|
||||
import lmdeploy
|
||||
from lmdeploy import version_info, GenerationConfig
|
||||
except Exception:
|
||||
raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
|
||||
kwargs.pop("hashing_kv", None)
|
||||
kwargs.pop("response_format", None)
|
||||
max_new_tokens = kwargs.pop("max_tokens", 512)
|
||||
tp = kwargs.pop("tp", 1)
|
||||
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
|
||||
do_preprocess = kwargs.pop("do_preprocess", True)
|
||||
do_sample = kwargs.pop("do_sample", False)
|
||||
gen_params = kwargs
|
||||
|
||||
version = version_info
|
||||
if do_sample is not None and version < (0, 6, 0):
|
||||
raise RuntimeError(
|
||||
"`do_sample` parameter is not supported by lmdeploy until "
|
||||
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
|
||||
)
|
||||
else:
|
||||
do_sample = True
|
||||
gen_params.update(do_sample=do_sample)
|
||||
|
||||
lmdeploy_pipe = initialize_lmdeploy_pipeline(
|
||||
model=model,
|
||||
tp=tp,
|
||||
chat_template=chat_template,
|
||||
model_format=model_format,
|
||||
quant_policy=quant_policy,
|
||||
log_level="WARNING",
|
||||
)
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
gen_config = GenerationConfig(
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**gen_params,
|
||||
)
|
||||
|
||||
response = ""
|
||||
async for res in lmdeploy_pipe.generate(
|
||||
messages,
|
||||
gen_config=gen_config,
|
||||
do_preprocess=do_preprocess,
|
||||
stream_response=False,
|
||||
session_id=1,
|
||||
):
|
||||
response += res.response
|
||||
return response
|
||||
177
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/lollms.py
Normal file
177
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/lollms.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
from typing import AsyncIterator
|
||||
else:
|
||||
from collections.abc import AsyncIterator
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
if not pm.is_installed("aiohttp"):
|
||||
pm.install("aiohttp")
|
||||
|
||||
import aiohttp
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
|
||||
from typing import Union, List
|
||||
import numpy as np
|
||||
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
)
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def lollms_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
base_url="http://localhost:9600",
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
"""Client implementation for lollms generation."""
|
||||
if enable_cot:
|
||||
from lightrag.utils import logger
|
||||
|
||||
logger.debug("enable_cot=True is not supported for lollms and will be ignored.")
|
||||
|
||||
stream = True if kwargs.get("stream") else False
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
headers = (
|
||||
{"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
if api_key
|
||||
else {"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
# Extract lollms specific parameters
|
||||
request_data = {
|
||||
"prompt": prompt,
|
||||
"model_name": model,
|
||||
"personality": kwargs.get("personality", -1),
|
||||
"n_predict": kwargs.get("n_predict", None),
|
||||
"stream": stream,
|
||||
"temperature": kwargs.get("temperature", 1.0),
|
||||
"top_k": kwargs.get("top_k", 50),
|
||||
"top_p": kwargs.get("top_p", 0.95),
|
||||
"repeat_penalty": kwargs.get("repeat_penalty", 0.8),
|
||||
"repeat_last_n": kwargs.get("repeat_last_n", 40),
|
||||
"seed": kwargs.get("seed", None),
|
||||
"n_threads": kwargs.get("n_threads", 8),
|
||||
}
|
||||
|
||||
# Prepare the full prompt including history
|
||||
full_prompt = ""
|
||||
if system_prompt:
|
||||
full_prompt += f"{system_prompt}\n"
|
||||
for msg in history_messages:
|
||||
full_prompt += f"{msg['role']}: {msg['content']}\n"
|
||||
full_prompt += prompt
|
||||
|
||||
request_data["prompt"] = full_prompt
|
||||
timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
|
||||
if stream:
|
||||
|
||||
async def inner():
|
||||
async with session.post(
|
||||
f"{base_url}/lollms_generate", json=request_data
|
||||
) as response:
|
||||
async for line in response.content:
|
||||
yield line.decode().strip()
|
||||
|
||||
return inner()
|
||||
else:
|
||||
async with session.post(
|
||||
f"{base_url}/lollms_generate", json=request_data
|
||||
) as response:
|
||||
return await response.text()
|
||||
|
||||
|
||||
async def lollms_model_complete(
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
keyword_extraction=False,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
"""Complete function for lollms model generation."""
|
||||
|
||||
# Extract and remove keyword_extraction from kwargs if present
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||
|
||||
# Get model name from config
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
|
||||
# If keyword extraction is needed, we might need to modify the prompt
|
||||
# or add specific parameters for JSON output (if lollms supports it)
|
||||
if keyword_extraction:
|
||||
# Note: You might need to adjust this based on how lollms handles structured output
|
||||
pass
|
||||
|
||||
return await lollms_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=1024, max_token_size=8192, model_name="lollms_embedding_model"
|
||||
)
|
||||
async def lollms_embed(
|
||||
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for a list of texts using lollms server.
|
||||
|
||||
Args:
|
||||
texts: List of strings to embed
|
||||
embed_model: Model name (not used directly as lollms uses configured vectorizer)
|
||||
base_url: URL of the lollms server
|
||||
**kwargs: Additional arguments passed to the request
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of embeddings
|
||||
"""
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
headers = (
|
||||
{"Content-Type": "application/json", "Authorization": api_key}
|
||||
if api_key
|
||||
else {"Content-Type": "application/json"}
|
||||
)
|
||||
async with aiohttp.ClientSession(headers=headers) as session:
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
request_data = {"text": text}
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/lollms_embed",
|
||||
json=request_data,
|
||||
) as response:
|
||||
result = await response.json()
|
||||
embeddings.append(result["vector"])
|
||||
|
||||
return np.array(embeddings)
|
||||
@@ -0,0 +1,68 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("openai"):
|
||||
pm.install("openai")
|
||||
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
)
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=2048, max_token_size=8192, model_name="nvidia_embedding_model"
|
||||
)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def nvidia_openai_embed(
|
||||
texts: list[str],
|
||||
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
|
||||
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
api_key: str = None,
|
||||
input_type: str = "passage", # query for retrieval, passage for embedding
|
||||
trunc: str = "NONE", # NONE or START or END
|
||||
encode: str = "float", # float or base64
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
|
||||
openai_async_client = (
|
||||
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
)
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model,
|
||||
input=texts,
|
||||
encoding_format=encode,
|
||||
extra_body={"input_type": input_type, "truncate": trunc},
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
260
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/ollama.py
Normal file
260
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/ollama.py
Normal file
@@ -0,0 +1,260 @@
|
||||
from collections.abc import AsyncIterator
|
||||
import os
|
||||
import re
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("ollama"):
|
||||
pm.install("ollama")
|
||||
|
||||
import ollama
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from lightrag.api import __api_version__
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Union
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
logger,
|
||||
)
|
||||
|
||||
|
||||
_OLLAMA_CLOUD_HOST = "https://ollama.com"
|
||||
_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
|
||||
|
||||
|
||||
def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
|
||||
if host:
|
||||
return host
|
||||
try:
|
||||
model_name_str = str(model) if model is not None else ""
|
||||
except (TypeError, ValueError, AttributeError) as e:
|
||||
logger.warning(f"Failed to convert model to string: {e}, using empty string")
|
||||
model_name_str = ""
|
||||
if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
|
||||
logger.debug(
|
||||
f"Detected cloud model '{model_name_str}', using Ollama Cloud host"
|
||||
)
|
||||
return _OLLAMA_CLOUD_HOST
|
||||
return host
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def _ollama_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
if enable_cot:
|
||||
logger.debug("enable_cot=True is not supported for ollama and will be ignored.")
|
||||
stream = True if kwargs.get("stream") else False
|
||||
|
||||
kwargs.pop("max_tokens", None)
|
||||
# kwargs.pop("response_format", None) # allow json
|
||||
host = kwargs.pop("host", None)
|
||||
timeout = kwargs.pop("timeout", None)
|
||||
if timeout == 0:
|
||||
timeout = None
|
||||
kwargs.pop("hashing_kv", None)
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
# fallback to environment variable when not provided explicitly
|
||||
if not api_key:
|
||||
api_key = os.getenv("OLLAMA_API_KEY")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"LightRAG/{__api_version__}",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
host = _coerce_host_for_cloud_model(host, model)
|
||||
|
||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||
|
||||
try:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
|
||||
if stream:
|
||||
"""cannot cache stream response and process reasoning"""
|
||||
|
||||
async def inner():
|
||||
try:
|
||||
async for chunk in response:
|
||||
yield chunk["message"]["content"]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream response: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
await ollama_client._client.aclose()
|
||||
logger.debug("Successfully closed Ollama client for streaming")
|
||||
except Exception as close_error:
|
||||
logger.warning(f"Failed to close Ollama client: {close_error}")
|
||||
|
||||
return inner()
|
||||
else:
|
||||
model_response = response["message"]["content"]
|
||||
|
||||
"""
|
||||
If the model also wraps its thoughts in a specific tag,
|
||||
this information is not needed for the final
|
||||
response and can simply be trimmed.
|
||||
"""
|
||||
|
||||
return model_response
|
||||
except Exception as e:
|
||||
try:
|
||||
await ollama_client._client.aclose()
|
||||
logger.debug("Successfully closed Ollama client after exception")
|
||||
except Exception as close_error:
|
||||
logger.warning(
|
||||
f"Failed to close Ollama client after exception: {close_error}"
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
if not stream:
|
||||
try:
|
||||
await ollama_client._client.aclose()
|
||||
logger.debug(
|
||||
"Successfully closed Ollama client for non-streaming response"
|
||||
)
|
||||
except Exception as close_error:
|
||||
logger.warning(
|
||||
f"Failed to close Ollama client in finally block: {close_error}"
|
||||
)
|
||||
|
||||
|
||||
async def ollama_model_complete(
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
enable_cot: bool = False,
|
||||
keyword_extraction=False,
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||
if keyword_extraction:
|
||||
kwargs["format"] = "json"
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await _ollama_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=1024,
|
||||
max_token_size=8192,
|
||||
model_name="bge-m3:latest",
|
||||
supports_asymmetric=True,
|
||||
)
|
||||
async def ollama_embed(
|
||||
texts: list[str],
|
||||
embed_model: str = "bge-m3:latest",
|
||||
max_token_size: int | None = None,
|
||||
context: str = "document",
|
||||
query_prefix: str | None = None,
|
||||
document_prefix: str | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""Generate embeddings using Ollama's API.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed.
|
||||
embed_model: The Ollama embedding model to use. Default is "bge-m3:latest".
|
||||
max_token_size: Maximum tokens per text. This parameter is automatically
|
||||
injected by the EmbeddingFunc wrapper when the underlying function
|
||||
signature supports it (via inspect.signature check). Ollama will
|
||||
automatically truncate texts exceeding the model's context length
|
||||
(num_ctx), so no client-side truncation is needed.
|
||||
context: The embedding context - "query" for search queries, "document" for indexed content.
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
|
||||
when supports_asymmetric=True. Default is "document".
|
||||
query_prefix: Optional prefix to prepend to texts when context="query" (e.g., "search_query: ").
|
||||
document_prefix: Optional prefix to prepend to texts when context="document" (e.g., "search_document: ").
|
||||
**kwargs: Additional arguments passed to the Ollama client.
|
||||
|
||||
Returns:
|
||||
A numpy array of embeddings, one per input text.
|
||||
|
||||
Note:
|
||||
- Ollama API automatically truncates texts exceeding the model's context length
|
||||
- The max_token_size parameter is received but not used for client-side truncation
|
||||
"""
|
||||
# Apply context-based prefixes if provided
|
||||
if context == "query" and query_prefix:
|
||||
texts = [query_prefix + text for text in texts]
|
||||
elif context == "document" and document_prefix:
|
||||
texts = [document_prefix + text for text in texts]
|
||||
|
||||
# Note: max_token_size is received but not used for client-side truncation.
|
||||
# Ollama API handles truncation automatically based on the model's num_ctx setting.
|
||||
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
if not api_key:
|
||||
api_key = os.getenv("OLLAMA_API_KEY")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"LightRAG/{__api_version__}",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
host = kwargs.pop("host", None)
|
||||
timeout = kwargs.pop("timeout", None)
|
||||
|
||||
host = _coerce_host_for_cloud_model(host, embed_model)
|
||||
|
||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||
try:
|
||||
options = kwargs.pop("options", {})
|
||||
data = await ollama_client.embed(
|
||||
model=embed_model, input=texts, options=options
|
||||
)
|
||||
return np.array(data["embeddings"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ollama_embed: {str(e)}")
|
||||
try:
|
||||
await ollama_client._client.aclose()
|
||||
logger.debug("Successfully closed Ollama client after exception in embed")
|
||||
except Exception as close_error:
|
||||
logger.warning(
|
||||
f"Failed to close Ollama client after exception in embed: {close_error}"
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
try:
|
||||
await ollama_client._client.aclose()
|
||||
logger.debug("Successfully closed Ollama client after embed")
|
||||
except Exception as close_error:
|
||||
logger.warning(f"Failed to close Ollama client after embed: {close_error}")
|
||||
1073
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/openai.py
Normal file
1073
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/openai.py
Normal file
File diff suppressed because it is too large
Load Diff
197
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/voyageai.py
Normal file
197
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/voyageai.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# Add Voyage AI import
|
||||
if not pm.is_installed("voyageai"):
|
||||
pm.install("voyageai")
|
||||
|
||||
import voyageai
|
||||
from voyageai.error import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
ServerError,
|
||||
ServiceUnavailableError,
|
||||
Timeout,
|
||||
TryAgain,
|
||||
)
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
||||
|
||||
|
||||
# Custom exceptions for VoyageAI errors
|
||||
class VoyageAIError(Exception):
|
||||
"""Generic VoyageAI API error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=1024, max_token_size=32000, supports_asymmetric=True
|
||||
)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
ServerError,
|
||||
ServiceUnavailableError,
|
||||
Timeout,
|
||||
TryAgain,
|
||||
)
|
||||
),
|
||||
)
|
||||
async def voyageai_embed(
|
||||
texts: list[str],
|
||||
model: str = "voyage-3",
|
||||
api_key: str | None = None,
|
||||
embedding_dim: int | None = None,
|
||||
input_type: str | None = None,
|
||||
truncation: bool | None = True,
|
||||
context: str | None = None,
|
||||
) -> np.ndarray:
|
||||
"""Generate embeddings for a list of texts using VoyageAI's API.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed.
|
||||
model: The VoyageAI embedding model to use. Options include:
|
||||
- "voyage-3": General purpose (1024 dims, 32K context)
|
||||
- "voyage-3-lite": Lightweight (512 dims, 32K context)
|
||||
- "voyage-3-large": Highest accuracy (1024 dims, 32K context)
|
||||
- "voyage-code-3": Code optimized (1024 dims, 32K context)
|
||||
- "voyage-law-2": Legal documents (1024 dims, 16K context)
|
||||
- "voyage-finance-2": Finance (1024 dims, 32K context)
|
||||
api_key: Optional VoyageAI API key. If None, falls back to the
|
||||
``VOYAGE_API_KEY`` environment variable (the name VoyageAI's own
|
||||
SDK uses), then to ``VOYAGEAI_API_KEY`` for backward compatibility.
|
||||
embedding_dim: Optional Matryoshka output dimension. Only honored by
|
||||
models that support dimension reduction (e.g. voyage-3-large);
|
||||
ignored otherwise. The decorator default is 1024 to match
|
||||
``voyage-3``; if you select ``voyage-3-lite`` (512 dims) override
|
||||
``EMBEDDING_DIM`` accordingly so the vector store size matches.
|
||||
input_type: Optional input type hint for the model. Options:
|
||||
- "query": For search queries
|
||||
- "document": For documents to be indexed
|
||||
- None: Let the model decide (default)
|
||||
truncation: Whether the API should truncate texts that exceed the model's
|
||||
token limit. Defaults to True (matches the VoyageAI SDK default).
|
||||
context: Optional LightRAG embedding context. When ``input_type`` is not
|
||||
set, "query" maps to ``input_type="query"`` and "document" maps to
|
||||
``input_type="document"``.
|
||||
|
||||
Returns:
|
||||
A numpy array of embeddings, one per input text.
|
||||
|
||||
Raises:
|
||||
VoyageAIError: If the API call fails or returns invalid data.
|
||||
|
||||
"""
|
||||
if not api_key:
|
||||
api_key = os.environ.get("VOYAGE_API_KEY") or os.environ.get("VOYAGEAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.error(
|
||||
"VoyageAI API key not provided and neither VOYAGE_API_KEY nor "
|
||||
"VOYAGEAI_API_KEY environment variable is set"
|
||||
)
|
||||
raise ValueError(
|
||||
"VoyageAI API key is required: pass api_key, or set the "
|
||||
"VOYAGE_API_KEY (preferred) or VOYAGEAI_API_KEY environment variable"
|
||||
)
|
||||
|
||||
if input_type is None and context in {"query", "document"}:
|
||||
input_type = context
|
||||
|
||||
try:
|
||||
client = voyageai.AsyncClient(api_key=api_key)
|
||||
|
||||
total_chars = sum(len(t) for t in texts)
|
||||
avg_chars = total_chars / len(texts) if texts else 0
|
||||
logger.debug(
|
||||
f"VoyageAI embedding request: {len(texts)} texts, "
|
||||
f"total_chars={total_chars}, avg_chars={avg_chars:.0f}, model={model}, "
|
||||
f"input_type={input_type}"
|
||||
)
|
||||
|
||||
# Prepare API call parameters
|
||||
embed_params = dict(
|
||||
texts=texts,
|
||||
model=model,
|
||||
# Optional parameters -- if None, voyageai client uses defaults
|
||||
output_dimension=embedding_dim,
|
||||
truncation=truncation,
|
||||
input_type=input_type,
|
||||
)
|
||||
# Make API call with timing
|
||||
result = await client.embed(**embed_params)
|
||||
|
||||
if not result.embeddings:
|
||||
err_msg = "VoyageAI API returned empty embeddings"
|
||||
logger.error(err_msg)
|
||||
raise VoyageAIError(err_msg)
|
||||
|
||||
if len(result.embeddings) != len(texts):
|
||||
err_msg = f"VoyageAI API returned {len(result.embeddings)} embeddings for {len(texts)} texts"
|
||||
logger.error(err_msg)
|
||||
raise VoyageAIError(err_msg)
|
||||
|
||||
# Convert to numpy array with timing
|
||||
embeddings = np.array(result.embeddings, dtype=np.float32)
|
||||
logger.debug(f"VoyageAI embeddings generated: shape {embeddings.shape}")
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VoyageAI embedding error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Optional: a helper function to get available embedding models
|
||||
def get_available_embedding_models() -> dict[str, dict]:
|
||||
"""
|
||||
Returns a dictionary of available Voyage AI embedding models and their properties.
|
||||
"""
|
||||
return {
|
||||
"voyage-3-large": {
|
||||
"context_length": 32000,
|
||||
"dimension": 1024,
|
||||
"description": "Best general-purpose and multilingual",
|
||||
},
|
||||
"voyage-3": {
|
||||
"context_length": 32000,
|
||||
"dimension": 1024,
|
||||
"description": "General-purpose and multilingual",
|
||||
},
|
||||
"voyage-3-lite": {
|
||||
"context_length": 32000,
|
||||
"dimension": 512,
|
||||
"description": "Optimized for latency and cost",
|
||||
},
|
||||
"voyage-code-3": {
|
||||
"context_length": 32000,
|
||||
"dimension": 1024,
|
||||
"description": "Optimized for code",
|
||||
},
|
||||
"voyage-finance-2": {
|
||||
"context_length": 32000,
|
||||
"dimension": 1024,
|
||||
"description": "Optimized for finance",
|
||||
},
|
||||
"voyage-law-2": {
|
||||
"context_length": 16000,
|
||||
"dimension": 1024,
|
||||
"description": "Optimized for legal",
|
||||
},
|
||||
"voyage-multimodal-3": {
|
||||
"context_length": 32000,
|
||||
"dimension": 1024,
|
||||
"description": "Multimodal text and images",
|
||||
},
|
||||
}
|
||||
248
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/zhipu.py
Normal file
248
.tmp/lightrag_inspect/lightrag_pkg/lightrag/llm/zhipu.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import sys
|
||||
import re
|
||||
import json
|
||||
from ..utils import verbose_debug
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("zhipuai"):
|
||||
pm.install("zhipuai")
|
||||
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
logger,
|
||||
)
|
||||
|
||||
from lightrag.types import GPTKeywordExtractionFormat
|
||||
|
||||
import numpy as np
|
||||
from typing import Union, List, Optional, Dict
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def zhipu_complete_if_cache(
|
||||
prompt: Union[str, List[Dict[str, str]]],
|
||||
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
|
||||
api_key: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
history_messages: List[Dict[str, str]] = [],
|
||||
enable_cot: bool = False, # LightRAG output switch: include reasoning_content as <think>...</think>
|
||||
thinking: Optional[
|
||||
Dict[str, object]
|
||||
] = None, # Zhipu request param: use {"type": "enabled"} to enable thinking
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Call Zhipu chat completions with optional official thinking support.
|
||||
|
||||
Parameter roles:
|
||||
- `thinking`: forwarded to the Zhipu API as-is. To enable thinking output,
|
||||
pass a config such as `{"type": "enabled"}`.
|
||||
- `enable_cot`: LightRAG-only formatting switch. When True and the API
|
||||
returns `reasoning_content`, it is preserved in the final string as
|
||||
`<think>...</think>`.
|
||||
"""
|
||||
# dynamically load ZhipuAI
|
||||
try:
|
||||
from zhipuai import ZhipuAI
|
||||
except ImportError:
|
||||
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
|
||||
|
||||
if api_key:
|
||||
client = ZhipuAI(api_key=api_key)
|
||||
else:
|
||||
# please set ZHIPUAI_API_KEY in your environment
|
||||
# os.environ["ZHIPUAI_API_KEY"]
|
||||
client = ZhipuAI()
|
||||
|
||||
messages = []
|
||||
|
||||
if not system_prompt:
|
||||
system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
|
||||
|
||||
# Add system prompt if provided
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Add debug logging
|
||||
logger.debug("===== Query Input to LLM =====")
|
||||
logger.debug(f"Query: {prompt}")
|
||||
verbose_debug(f"System prompt: {system_prompt}")
|
||||
|
||||
# Remove unsupported kwargs
|
||||
kwargs = {
|
||||
k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
|
||||
}
|
||||
# `thinking` is an official Zhipu request field. Example:
|
||||
# {"type": "enabled"} enables reasoning output on supported models.
|
||||
if thinking is not None:
|
||||
kwargs["thinking"] = thinking
|
||||
|
||||
response = client.chat.completions.create(model=model, messages=messages, **kwargs)
|
||||
message = response.choices[0].message
|
||||
content = message.content or ""
|
||||
reasoning_content = getattr(message, "reasoning_content", "") or ""
|
||||
|
||||
if enable_cot and reasoning_content.strip():
|
||||
if content:
|
||||
return f"<think>{reasoning_content}</think>{content}"
|
||||
return f"<think>{reasoning_content}</think>"
|
||||
|
||||
return content
|
||||
|
||||
|
||||
async def zhipu_complete(
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
keyword_extraction=False,
|
||||
enable_cot: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
|
||||
|
||||
if keyword_extraction:
|
||||
# Add a system prompt to guide the model to return JSON format
|
||||
extraction_prompt = """You are a helpful assistant that extracts keywords from text.
|
||||
Please analyze the content and extract two types of keywords:
|
||||
1. High-level keywords: Important concepts and main themes
|
||||
2. Low-level keywords: Specific details and supporting elements
|
||||
|
||||
Return your response in this exact JSON format:
|
||||
{
|
||||
"high_level_keywords": ["keyword1", "keyword2"],
|
||||
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
|
||||
}
|
||||
|
||||
Only return the JSON, no other text."""
|
||||
|
||||
# Combine with existing system prompt if any
|
||||
if system_prompt:
|
||||
system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
|
||||
else:
|
||||
system_prompt = extraction_prompt
|
||||
|
||||
try:
|
||||
response = await zhipu_complete_if_cache(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Try to parse as JSON
|
||||
try:
|
||||
data = json.loads(response)
|
||||
return GPTKeywordExtractionFormat(
|
||||
high_level_keywords=data.get("high_level_keywords", []),
|
||||
low_level_keywords=data.get("low_level_keywords", []),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# If direct JSON parsing fails, try to extract JSON from text
|
||||
match = re.search(r"\{[\s\S]*\}", response)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return GPTKeywordExtractionFormat(
|
||||
high_level_keywords=data.get("high_level_keywords", []),
|
||||
low_level_keywords=data.get("low_level_keywords", []),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If all parsing fails, log warning and return empty format
|
||||
logger.warning(
|
||||
f"Failed to parse keyword extraction response: {response}"
|
||||
)
|
||||
return GPTKeywordExtractionFormat(
|
||||
high_level_keywords=[], low_level_keywords=[]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during keyword extraction: {str(e)}")
|
||||
return GPTKeywordExtractionFormat(
|
||||
high_level_keywords=[], low_level_keywords=[]
|
||||
)
|
||||
else:
|
||||
# For non-keyword-extraction, just return the raw response string
|
||||
return await zhipu_complete_if_cache(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
enable_cot=enable_cot,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(
|
||||
embedding_dim=1024, max_token_size=8192, model_name="embedding-3"
|
||||
)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def zhipu_embedding(
|
||||
texts: list[str],
|
||||
model: str = "embedding-3",
|
||||
api_key: str = None,
|
||||
embedding_dim: int | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
# dynamically load ZhipuAI
|
||||
try:
|
||||
from zhipuai import ZhipuAI
|
||||
except ImportError:
|
||||
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
|
||||
if api_key:
|
||||
client = ZhipuAI(api_key=api_key)
|
||||
else:
|
||||
# please set ZHIPUAI_API_KEY in your environment
|
||||
# os.environ["ZHIPUAI_API_KEY"]
|
||||
client = ZhipuAI()
|
||||
|
||||
# Convert single text to list if needed
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
try:
|
||||
request_kwargs = dict(kwargs)
|
||||
if embedding_dim is not None:
|
||||
request_kwargs["dimensions"] = embedding_dim
|
||||
response = client.embeddings.create(
|
||||
model=model, input=[text], **request_kwargs
|
||||
)
|
||||
embeddings.append(response.data[0].embedding)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
|
||||
|
||||
return np.array(embeddings)
|
||||
Reference in New Issue
Block a user