2026-03-21 10:13:29 +08:00
|
|
|
"""
|
|
|
|
|
LLM 服务 - 支持多种 LLM 提供商
|
|
|
|
|
OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from typing import AsyncIterator
|
2026-03-22 13:47:34 +08:00
|
|
|
from sqlalchemy import select
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2026-03-21 10:13:29 +08:00
|
|
|
from langchain_core.messages import BaseMessage, AIMessage
|
|
|
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
|
from langchain_anthropic import ChatAnthropic
|
|
|
|
|
from langchain_ollama import ChatOllama
|
|
|
|
|
from app.config import settings
|
2026-03-22 13:47:34 +08:00
|
|
|
from app.models.user import User
|
2026-03-21 10:13:29 +08:00
|
|
|
import httpx
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
# Create the configured data directories when this module is imported.
# exist_ok=True makes each call a no-op if the directory is already present.
# NOTE(review): neither directory is used by the code visible in this module
# (CHROMA_PERSIST_DIR presumably belongs to a vector store) — consider moving
# this import-time side effect to application startup.
os.makedirs(settings.DATA_DIR, exist_ok=True)
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LLMService(ABC):
    """Abstract interface shared by every chat-model backend in this module.

    Implementations expose a one-shot call, an incremental text stream, and
    the name of the model they are configured to use.
    """

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the identifier of the model this service is configured for."""
        raise NotImplementedError

    @abstractmethod
    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Send ``messages`` to the model and return its complete reply."""
        raise NotImplementedError

    @abstractmethod
    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Send ``messages`` to the model and yield the reply as text pieces."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAICompatibleService(LLMService):
    """LLM service for any endpoint that speaks the OpenAI chat API.

    Supports OpenAI, DeepSeek, SiliconFlow, and any other OpenAI-API-compatible
    service.

    Args:
        api_key: API key; falls back to ``settings.OPENAI_API_KEY`` when falsy.
        model: Model name; falls back to ``settings.OPENAI_MODEL`` when falsy.
        base_url: API base URL; falls back to ``settings.OPENAI_BASE_URL``
            when falsy.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        # `or` rather than an `is None` test, so empty strings also fall back.
        self.api_key = api_key or settings.OPENAI_API_KEY
        self.model = model or settings.OPENAI_MODEL
        self.base_url = base_url or settings.OPENAI_BASE_URL

        # 60 s default for read/write/pool, but only 10 s to open a connection.
        request_timeout = httpx.Timeout(60.0, connect=10.0)
        client_options = {
            "api_key": self.api_key,
            "model": self.model,
            "base_url": self.base_url,
            "timeout": request_timeout,
        }
        self._llm = ChatOpenAI(**client_options)

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run the chat model once and return the full assistant message."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the content of each streamed chunk, skipping empty ones."""
        async for piece in self._llm.astream(messages):
            text = piece.content
            if not text:
                continue
            yield text

    def get_model_name(self) -> str:
        """Return the resolved model name (argument or settings default)."""
        return self.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClaudeService(LLMService):
    """LLM service backed by Anthropic Claude models through ``ChatAnthropic``.

    Args:
        api_key: API key; falls back to ``settings.ANTHROPIC_API_KEY`` when
            falsy.
        model: Model name; falls back to ``settings.CLAUDE_MODEL`` when falsy.
        max_tokens: Maximum number of tokens generated per reply.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_tokens: int = 8192,
    ):
        self.api_key = api_key or settings.ANTHROPIC_API_KEY
        self.model = model or settings.CLAUDE_MODEL
        # ChatAnthropic declares `timeout` (alias of `default_request_timeout`)
        # as a plain float, so an httpx.Timeout object fails pydantic
        # validation at construction time. Pass the timeout in seconds instead.
        self._llm = ChatAnthropic(
            api_key=self.api_key,
            model=self.model,
            max_tokens=max_tokens,
            timeout=60.0,
        )

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run the chat model once and return the full assistant message."""
        return await self._llm.ainvoke(messages)

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the text of each streamed chunk, skipping chunks with no text.

        Chunk content is normalised to ``str`` so callers always receive text,
        even when the model emits a list of content blocks.
        """
        async for chunk in self._llm.astream(messages):
            text = self._chunk_text(chunk.content)
            if text:
                yield text

    @staticmethod
    def _chunk_text(content: str | list) -> str:
        """Return the plain text carried by a message chunk's ``content``.

        LangChain message content is either a string or a list of content
        blocks, where each block is a string or a dict such as
        ``{"type": "text", "text": ...}``. Non-text blocks, such as tool-use
        input, are dropped.
        """
        if isinstance(content, str):
            return content
        parts: list[str] = []
        for block in content or ():
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("text") or "")
        return "".join(parts)

    def get_model_name(self) -> str:
        """Return the resolved model name (argument or settings default)."""
        return self.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OllamaService(LLMService):
    """LLM service backed by an Ollama server through ``ChatOllama``.

    Args:
        base_url: Server URL; falls back to ``settings.OLLAMA_BASE_URL`` when
            falsy.
        model: Model name; falls back to ``settings.OLLAMA_MODEL`` when falsy.
    """

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
    ):
        self.base_url = base_url or settings.OLLAMA_BASE_URL
        self.model = model or settings.OLLAMA_MODEL
        # ChatOllama has no `timeout` field, and unknown keyword arguments are
        # ignored rather than forwarded, so the timeout never reached the HTTP
        # client. `client_kwargs` is passed to the underlying ollama (httpx)
        # clients. Keep a 120 s read/write budget with a 10 s connect limit.
        self._llm = ChatOllama(
            base_url=self.base_url,
            model=self.model,
            client_kwargs={"timeout": httpx.Timeout(120.0, connect=10.0)},
        )

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run the chat model once and return the full assistant message."""
        return await self._llm.ainvoke(messages)

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the content of each streamed chunk, skipping empty ones."""
        async for chunk in self._llm.astream(messages):
            if chunk.content:
                yield chunk.content

    def get_model_name(self) -> str:
        """Return the resolved model name (argument or settings default)."""
        return self.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Singleton cache: populated lazily by get_llm() on its first call and
# returned unchanged by every later call.
_llm_instance: LLMService | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_llm() -> LLMService:
|
|
|
|
|
"""根据配置获取 LLM 实例"""
|
|
|
|
|
global _llm_instance
|
|
|
|
|
if _llm_instance is None:
|
|
|
|
|
provider = settings.LLM_PROVIDER
|
|
|
|
|
if provider == "openai":
|
|
|
|
|
_llm_instance = OpenAICompatibleService()
|
|
|
|
|
elif provider == "deepseek":
|
|
|
|
|
_llm_instance = OpenAICompatibleService(
|
|
|
|
|
base_url="https://api.deepseek.com/v1",
|
|
|
|
|
model="deepseek-chat",
|
|
|
|
|
)
|
|
|
|
|
elif provider == "custom":
|
|
|
|
|
_llm_instance = OpenAICompatibleService()
|
|
|
|
|
elif provider == "claude":
|
|
|
|
|
_llm_instance = ClaudeService()
|
|
|
|
|
elif provider == "ollama":
|
|
|
|
|
_llm_instance = OllamaService()
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unknown LLM provider: {provider}")
|
|
|
|
|
return _llm_instance
|