64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
|
|
"""
|
||
|
|
LLM 工厂 - 创建不同提供商的 LLM 实例
|
||
|
|
"""
|
||
|
|
from typing import Optional
|
||
|
|
from langchain_openai import ChatOpenAI
|
||
|
|
from langchain_anthropic import ChatAnthropic
|
||
|
|
|
||
|
|
|
||
|
|
class LLMFactory:
    """Factory that lazily builds and caches a LangChain chat model.

    Supported providers are ``"openai"`` (``ChatOpenAI``) and ``"anthropic"``
    (``ChatAnthropic``). The model instance is created on the first call to
    ``get_llm`` and the same instance is returned on later calls until a
    setter such as ``set_model`` discards it.
    """

    # Model used for each provider when the caller does not pass ``model``.
    # A single shared "gpt-3.5-turbo" default used to be applied to every
    # provider, which is not a valid model name for the Anthropic API.
    DEFAULT_MODELS = {
        "openai": "gpt-3.5-turbo",
        "anthropic": "claude-3-haiku-20240307",
    }

    def __init__(
        self,
        provider: str = "openai",
        openai_api_key: Optional[str] = None,
        anthropic_api_key: Optional[str] = None,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 2000,
    ):
        """Store the configuration; no client is created here.

        Args:
            provider: ``"openai"`` or ``"anthropic"``. Other values are
                accepted here but make ``get_llm`` raise ``ValueError``.
            openai_api_key: Passed as ``api_key`` to ``ChatOpenAI``.
            anthropic_api_key: Passed as ``anthropic_api_key`` to
                ``ChatAnthropic``.
            model: Model name. When omitted, the provider's entry in
                ``DEFAULT_MODELS`` is used (``"gpt-3.5-turbo"`` for OpenAI).
            temperature: Sampling temperature passed to the client.
            max_tokens: Maximum number of tokens passed to the client.
        """
        self.provider = provider
        self.openai_api_key = openai_api_key
        self.anthropic_api_key = anthropic_api_key
        if model is None:
            # Unknown providers keep the historical OpenAI default;
            # get_llm() rejects them anyway.
            model = self.DEFAULT_MODELS.get(provider, "gpt-3.5-turbo")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Cached chat model instance; built lazily by get_llm().
        self._llm = None

    def get_llm(self):
        """Return the chat model, building and caching it on first use.

        Note that once built, the cached instance is returned as-is even if
        attributes such as ``provider`` or ``max_tokens`` were reassigned
        directly afterwards; only ``set_model`` clears the cache.

        Returns:
            A ``ChatOpenAI`` or ``ChatAnthropic`` instance configured from
            this factory's settings.

        Raises:
            ValueError: If ``provider`` is neither ``"openai"`` nor
                ``"anthropic"``.
        """
        if self._llm is not None:
            return self._llm

        if self.provider == "openai":
            self._llm = ChatOpenAI(
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                api_key=self.openai_api_key,
            )
        elif self.provider == "anthropic":
            self._llm = ChatAnthropic(
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                anthropic_api_key=self.anthropic_api_key,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

        return self._llm

    def set_model(self, model: str):
        """Change the model name and discard any cached chat model.

        The next ``get_llm`` call builds a fresh instance with the new model.
        """
        self.model = model
        self._llm = None  # force get_llm() to rebuild with the new model

    def set_temperature(self, temperature: float):
        """Change the sampling temperature.

        Unlike ``set_model``, an already-built chat model is kept and its
        ``temperature`` attribute is updated in place, so existing references
        to it see the new value.
        """
        self.temperature = temperature
        if self._llm is not None:
            self._llm.temperature = temperature