Files
JARVIS/backend/app/services/llm_service.py

276 lines
8.9 KiB
Python

"""
LLM 服务 - 支持多种 LLM 提供商
OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncIterator, Literal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from langchain_core.messages import BaseMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from app.config import settings
from app.models.user import User
import httpx
import os
# How tool calls are routed for a provider: "native" relies on the provider's
# own tool-calling support, "json_fallback" is used when that support is not
# trusted (see resolve_provider_capabilities).
ToolStrategy = Literal[
    "native",
    "json_fallback",
]
def _resolve_effective_base_url(config: dict | None) -> str:
provider = str((config or {}).get("provider") or settings.LLM_PROVIDER or "openai").strip().lower()
base_url = str((config or {}).get("base_url") or "").strip()
if base_url:
return base_url
if provider in {"openai", "custom", "deepseek"}:
return settings.OPENAI_BASE_URL
if provider == "ollama":
return settings.OLLAMA_BASE_URL
return ""
@dataclass(eq=True, frozen=True)
class ProviderCapabilities:
    """Immutable routing decision for tool calls, built by resolve_provider_capabilities()."""

    # Canonical provider name, as produced by normalize_provider_name().
    provider: str
    # True when the provider is trusted to accept native tool-calling params.
    supports_native_tools: bool
    # "native" when native tools are trusted, otherwise "json_fallback".
    preferred_tool_strategy: ToolStrategy
def default_provider_capabilities() -> ProviderCapabilities:
    """Resolve capabilities for the server-wide provider in ``settings.LLM_PROVIDER``."""
    server_default_config = {"provider": settings.LLM_PROVIDER}
    return resolve_provider_capabilities(server_default_config)
def normalize_provider_name(config: dict | None) -> str:
    """Infer the canonical provider name for a user model config.

    The declared ``provider`` (falling back to ``settings.LLM_PROVIDER``, then
    "openai") is lower-cased. The effective base URL is checked first and
    overrides the declared provider:

    * ``localhost:11434`` / ``127.0.0.1:11434`` -> "ollama"
    * any URL containing "anthropic"          -> "claude"
    * ``api.deepseek.com``                      -> "deepseek"

    For a declared provider of "openai" or "custom", the model name and base
    URL are then scanned for known OpenAI-compatible vendors ("minimax",
    "kimi", "qwen") so capability routing can treat them conservatively.
    Otherwise the declared provider is returned unchanged.
    """
    cfg = config or {}
    declared = str(cfg.get("provider") or "").strip().lower()
    if not declared:
        declared = str(settings.LLM_PROVIDER or "openai").strip().lower()
    model_name = str(cfg.get("model") or "").strip().lower()
    endpoint = _resolve_effective_base_url(config).strip().lower()

    # The endpoint is the strongest signal: users often leave `provider`
    # unset or generic, so a recognizable host wins.
    if endpoint:
        if "localhost:11434" in endpoint or "127.0.0.1:11434" in endpoint:
            return "ollama"
        if "anthropic" in endpoint:  # also covers "api.anthropic.com"
            return "claude"
        if "api.deepseek.com" in endpoint:
            return "deepseek"

    if declared not in {"openai", "custom"}:
        return declared

    # Many "openai-compatible" endpoints are configured as provider=openai;
    # split them out as distinct providers. Order of checks matters.
    vendor_markers = (
        ("minimax", ("minimax", "abab")),
        ("kimi", ("kimi", "moonshot")),
        ("qwen", ("qwen", "dashscope", "aliyuncs")),
    )
    for vendor, markers in vendor_markers:
        if any(marker in model_name or marker in endpoint for marker in markers):
            return vendor
    return declared
def resolve_provider_capabilities(config: dict | None) -> ProviderCapabilities:
    """Decide whether native tool calling can be used for *config*'s provider.

    Conservative by design: only official OpenAI endpoints (no base URL,
    ``api.openai.com`` or ``openai.azure.com``), DeepSeek and Claude are trusted
    with native tools, because many OpenAI-compatible endpoints reject tool /
    response_format / other chat params. Everything else gets "json_fallback".
    """
    provider = normalize_provider_name(config)
    endpoint = _resolve_effective_base_url(config).strip().lower()

    if provider == "openai":
        # "openai" may be any OpenAI-compatible host; trust only official ones.
        trusted = (
            not endpoint
            or "api.openai.com" in endpoint
            or "openai.azure.com" in endpoint
        )
    else:
        trusted = provider in {"deepseek", "claude"}

    return ProviderCapabilities(
        provider=provider,
        supports_native_tools=trusted,
        preferred_tool_strategy="native" if trusted else "json_fallback",
    )
def create_llm_from_config(config: dict | None):
    """Create the underlying LangChain chat model for a user's model config.

    The provider is inferred with ``normalize_provider_name``. The returned
    model is tagged with ``_jarvis_user_llm_config`` (the config itself) and
    ``_jarvis_provider_capabilities`` (from ``resolve_provider_capabilities``).

    Args:
        config: user model config. ``model``, ``api_key`` and ``base_url`` are
            read here; ``provider`` is read by the normalization helpers. When
            falsy, the server-wide default from ``get_llm()`` is returned.

    Returns:
        A ChatOpenAI / ChatAnthropic / ChatOllama instance, or, for an empty
        config, the LLMService singleton returned by ``get_llm()``.
    """
    if not config:
        # NOTE(review): this is an LLMService wrapper, not a bare LangChain chat
        # model like the branches below; confirm every caller handles both.
        return get_llm()

    provider = normalize_provider_name(config)
    model = config.get("model", "")
    api_key = config.get("api_key", "")
    # Stripped so a whitespace-only value counts as "not set".
    base_url = str(config.get("base_url") or "").strip()

    if provider == "claude":
        llm = ChatAnthropic(
            api_key=api_key,
            model=model,
            # ChatAnthropic declares `timeout` as a float, so an httpx.Timeout
            # object fails its validation.
            timeout=60.0,
        )
    elif provider == "ollama":
        # The native Ollama API lives under /api/*. normalize_provider_name also
        # routes OpenAI-compatible URLs such as "localhost:11434/v1" here, so a
        # trailing "/v1" must be dropped or requests would hit /v1/api/*.
        ollama_url = (base_url or settings.OLLAMA_BASE_URL or "http://localhost:11434").rstrip("/")
        llm = ChatOllama(
            base_url=ollama_url.removesuffix("/v1"),
            model=model,
            # ChatOllama has no `timeout` field (the kwarg is silently ignored);
            # the timeout has to be handed to the ollama client instead.
            client_kwargs={"timeout": httpx.Timeout(120.0, connect=10.0)},
        )
    else:
        # openai / deepseek / custom / minimax / kimi / qwen and any unrecognized
        # provider all use the OpenAI-compatible client.
        llm = ChatOpenAI(
            api_key=api_key,
            model=model,
            base_url=base_url or None,
            timeout=httpx.Timeout(60.0, connect=10.0),
        )

    setattr(llm, "_jarvis_user_llm_config", config)
    setattr(llm, "_jarvis_provider_capabilities", resolve_provider_capabilities(config))
    return llm
class LLMService(ABC):
    """Minimal async chat interface implemented by the provider services in this module.

    OpenAICompatibleService, ClaudeService and OllamaService subclass it, and
    get_llm() returns one of them.
    """

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the model identifier this service sends requests to."""
        raise NotImplementedError

    @abstractmethod
    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Send *messages* and return the complete assistant reply."""
        raise NotImplementedError

    @abstractmethod
    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Send *messages* and yield the reply incrementally as text pieces."""
        raise NotImplementedError
class OpenAICompatibleService(LLMService):
    """LLMService backed by LangChain's ChatOpenAI.

    Covers OpenAI, DeepSeek, SiliconFlow and any other service exposing an
    OpenAI-compatible API. Arguments left as None/empty fall back to
    settings.OPENAI_API_KEY, settings.OPENAI_MODEL and settings.OPENAI_BASE_URL.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        self.api_key = settings.OPENAI_API_KEY if not api_key else api_key
        self.model = settings.OPENAI_MODEL if not model else model
        self.base_url = settings.OPENAI_BASE_URL if not base_url else base_url
        request_timeout = httpx.Timeout(60.0, connect=10.0)
        self._llm = ChatOpenAI(
            model=self.model,
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=request_timeout,
        )

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Return the full reply from ChatOpenAI.ainvoke."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield each non-empty content piece as ChatOpenAI streams it."""
        async for piece in self._llm.astream(messages):
            text = piece.content
            if text:
                yield text

    def get_model_name(self) -> str:
        """Return the resolved model name."""
        return self.model
class ClaudeService(LLMService):
    """LLMService backed by LangChain's ChatAnthropic.

    Arguments left as None/empty fall back to settings.ANTHROPIC_API_KEY and
    settings.CLAUDE_MODEL. ``max_tokens`` is passed through to ChatAnthropic.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_tokens: int = 8192,
    ):
        self.api_key = api_key or settings.ANTHROPIC_API_KEY
        self.model = model or settings.CLAUDE_MODEL
        self._llm = ChatAnthropic(
            api_key=self.api_key,
            model=self.model,
            max_tokens=max_tokens,
            # ChatAnthropic declares `timeout` (default_request_timeout) as a
            # float; an httpx.Timeout object fails its validation.
            timeout=60.0,
        )

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Return the full reply from ChatAnthropic.ainvoke."""
        return await self._llm.ainvoke(messages)

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield each non-empty content piece as ChatAnthropic streams it."""
        async for chunk in self._llm.astream(messages):
            if chunk.content:
                yield chunk.content

    def get_model_name(self) -> str:
        """Return the resolved model name."""
        return self.model
class OllamaService(LLMService):
    """LLMService backed by LangChain's ChatOllama.

    Arguments left as None/empty fall back to settings.OLLAMA_BASE_URL and
    settings.OLLAMA_MODEL.
    """

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
    ):
        self.base_url = base_url or settings.OLLAMA_BASE_URL
        self.model = model or settings.OLLAMA_MODEL
        self._llm = ChatOllama(
            base_url=self.base_url,
            model=self.model,
            # ChatOllama has no `timeout` field (the kwarg is silently ignored);
            # the timeout has to be handed to the ollama client instead.
            client_kwargs={"timeout": httpx.Timeout(120.0, connect=10.0)},
        )

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Return the full reply from ChatOllama.ainvoke."""
        return await self._llm.ainvoke(messages)

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield each non-empty content piece as ChatOllama streams it."""
        async for chunk in self._llm.astream(messages):
            if chunk.content:
                yield chunk.content

    def get_model_name(self) -> str:
        """Return the resolved model name."""
        return self.model
# Process-wide singleton cache for get_llm(): None until the first call builds it.
_llm_instance: LLMService | None
_llm_instance = None
def get_llm() -> LLMService:
    """Return the process-wide default LLMService built from server settings.

    The service is created on first call from ``settings.LLM_PROVIDER``
    (case/whitespace-insensitive, empty means "openai"), tagged with
    ``_jarvis_provider_capabilities`` from ``default_provider_capabilities()``,
    cached in ``_llm_instance`` and reused afterwards.

    Raises:
        ValueError: if ``settings.LLM_PROVIDER`` names an unsupported provider.
    """
    global _llm_instance
    if _llm_instance is None:
        raw_provider = settings.LLM_PROVIDER
        # Normalize like the other helpers in this module so "OpenAI" or
        # " openai " does not raise and an empty value defaults to "openai".
        provider = str(raw_provider or "openai").strip().lower()
        if provider in {"openai", "custom"}:
            instance = OpenAICompatibleService()
        elif provider == "deepseek":
            instance = OpenAICompatibleService(
                base_url="https://api.deepseek.com/v1",
                model="deepseek-chat",
            )
        elif provider == "claude":
            instance = ClaudeService()
        elif provider == "ollama":
            instance = OllamaService()
        else:
            raise ValueError(f"Unknown LLM provider: {raw_provider}")
        setattr(instance, "_jarvis_provider_capabilities", default_provider_capabilities())
        # Publish only once fully tagged, so a failure above never leaves a
        # half-initialized singleton cached.
        _llm_instance = instance
    return _llm_instance