""" LLM 服务 - 支持多种 LLM 提供商 OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口 """ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import AsyncIterator, Literal from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from langchain_core.messages import BaseMessage, AIMessage from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_ollama import ChatOllama from app.config import settings from app.models.user import User import httpx import os ToolStrategy = Literal["native", "json_fallback"] def _resolve_effective_base_url(config: dict | None) -> str: provider = str((config or {}).get("provider") or settings.LLM_PROVIDER or "openai").strip().lower() base_url = str((config or {}).get("base_url") or "").strip() if base_url: return base_url if provider in {"openai", "custom", "deepseek"}: return settings.OPENAI_BASE_URL if provider == "ollama": return settings.OLLAMA_BASE_URL return "" @dataclass(frozen=True) class ProviderCapabilities: provider: str supports_native_tools: bool preferred_tool_strategy: ToolStrategy def default_provider_capabilities() -> ProviderCapabilities: return resolve_provider_capabilities({"provider": settings.LLM_PROVIDER}) def normalize_provider_name(config: dict | None) -> str: provider_raw = str((config or {}).get("provider") or "").strip().lower() provider = provider_raw or str(settings.LLM_PROVIDER or "openai").strip().lower() model = str((config or {}).get("model") or "").strip().lower() base_url = _resolve_effective_base_url(config).strip().lower() # base_url-first inference (provider may be omitted in user config) if base_url: if any(key in base_url for key in {"localhost:11434", "127.0.0.1:11434"}): return "ollama" if any(key in base_url for key in {"api.anthropic.com", "anthropic"}): return "claude" if "api.deepseek.com" in base_url: return "deepseek" # Many "openai-compatible" endpoints are configured as provider=openai. # We treat them as distinct providers so capability routing can stay conservative. if provider in {"openai", "custom"}: if any(key in model or key in base_url for key in {"minimax", "abab"}): return "minimax" if any(key in model or key in base_url for key in {"kimi", "moonshot"}): return "kimi" if any(key in model or key in base_url for key in {"qwen", "dashscope", "aliyuncs"}): return "qwen" return provider def resolve_provider_capabilities(config: dict | None) -> ProviderCapabilities: provider = normalize_provider_name(config) # Conservative default: only treat official OpenAI + DeepSeek + Claude as reliable native tool providers. # Many OpenAI-compatible endpoints reject tool / response_format / other chat params. native_tool_providers = {"openai", "deepseek", "claude"} base_url = _resolve_effective_base_url(config).strip().lower() is_official_openai = ( provider != "openai" or not base_url or "api.openai.com" in base_url or "openai.azure.com" in base_url ) if provider in native_tool_providers and is_official_openai: return ProviderCapabilities( provider=provider, supports_native_tools=True, preferred_tool_strategy="native", ) return ProviderCapabilities( provider=provider, supports_native_tools=False, preferred_tool_strategy="json_fallback", ) def create_llm_from_config(config: dict | None): """根据用户模型配置创建底层 LangChain LLM 实例""" if not config: return get_llm() provider = normalize_provider_name(config) model = config.get("model", "") api_key = config.get("api_key", "") base_url = config.get("base_url", "") if provider in {"openai", "deepseek", "custom", "minimax", "kimi", "qwen"}: llm = ChatOpenAI( api_key=api_key, model=model, base_url=base_url or None, timeout=httpx.Timeout(60.0, connect=10.0), ) elif provider == "claude": llm = ChatAnthropic( api_key=api_key, model=model, timeout=httpx.Timeout(60.0, connect=10.0), ) elif provider == "ollama": llm = ChatOllama( base_url=base_url or "http://localhost:11434", model=model, timeout=httpx.Timeout(120.0, connect=10.0), ) else: llm = ChatOpenAI( api_key=api_key, model=model, base_url=base_url or None, timeout=httpx.Timeout(60.0, connect=10.0), ) setattr(llm, "_jarvis_user_llm_config", config) setattr(llm, "_jarvis_provider_capabilities", resolve_provider_capabilities(config)) return llm class LLMService(ABC): @abstractmethod async def invoke(self, messages: list[BaseMessage]) -> AIMessage: raise NotImplementedError @abstractmethod async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]: raise NotImplementedError @abstractmethod def get_model_name(self) -> str: raise NotImplementedError class OpenAICompatibleService(LLMService): """ OpenAI 兼容接口 支持 OpenAI、DeepSeek、硅基流动、任意 OpenAI API 兼容服务 """ def __init__( self, api_key: str | None = None, model: str | None = None, base_url: str | None = None, ): self.api_key = api_key or settings.OPENAI_API_KEY self.model = model or settings.OPENAI_MODEL self.base_url = base_url or settings.OPENAI_BASE_URL self._llm = ChatOpenAI( api_key=self.api_key, model=self.model, base_url=self.base_url, timeout=httpx.Timeout(60.0, connect=10.0), ) async def invoke(self, messages: list[BaseMessage]) -> AIMessage: return await self._llm.ainvoke(messages) async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]: async for chunk in self._llm.astream(messages): if chunk.content: yield chunk.content def get_model_name(self) -> str: return self.model class ClaudeService(LLMService): def __init__( self, api_key: str | None = None, model: str | None = None, max_tokens: int = 8192, ): self.api_key = api_key or settings.ANTHROPIC_API_KEY self.model = model or settings.CLAUDE_MODEL self._llm = ChatAnthropic( api_key=self.api_key, model=self.model, max_tokens=max_tokens, timeout=httpx.Timeout(60.0, connect=10.0), ) async def invoke(self, messages: list[BaseMessage]) -> AIMessage: return await self._llm.ainvoke(messages) async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]: async for chunk in self._llm.astream(messages): if chunk.content: yield chunk.content def get_model_name(self) -> str: return self.model class OllamaService(LLMService): def __init__( self, base_url: str | None = None, model: str | None = None, ): self.base_url = base_url or settings.OLLAMA_BASE_URL self.model = model or settings.OLLAMA_MODEL self._llm = ChatOllama( base_url=self.base_url, model=self.model, timeout=httpx.Timeout(120.0, connect=10.0), ) async def invoke(self, messages: list[BaseMessage]) -> AIMessage: return await self._llm.ainvoke(messages) async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]: async for chunk in self._llm.astream(messages): if chunk.content: yield chunk.content def get_model_name(self) -> str: return self.model # 单例缓存 _llm_instance: LLMService | None = None def get_llm() -> LLMService: """根据配置获取 LLM 实例""" global _llm_instance if _llm_instance is None: provider = settings.LLM_PROVIDER if provider == "openai": _llm_instance = OpenAICompatibleService() elif provider == "deepseek": _llm_instance = OpenAICompatibleService( base_url="https://api.deepseek.com/v1", model="deepseek-chat", ) elif provider == "custom": _llm_instance = OpenAICompatibleService() elif provider == "claude": _llm_instance = ClaudeService() elif provider == "ollama": _llm_instance = OllamaService() else: raise ValueError(f"Unknown LLM provider: {provider}") setattr(_llm_instance, "_jarvis_provider_capabilities", default_provider_capabilities()) return _llm_instance