feat: enhance agent orchestration, knowledge flow and UI refinements
This commit is contained in:
@@ -4,7 +4,8 @@ OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator, Literal
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
@@ -16,8 +17,131 @@ from app.models.user import User
|
||||
import httpx
|
||||
import os
|
||||
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
||||
|
||||
|
||||
ToolStrategy = Literal["native", "json_fallback"]
|
||||
|
||||
|
||||
def _resolve_effective_base_url(config: dict | None) -> str:
|
||||
provider = str((config or {}).get("provider") or settings.LLM_PROVIDER or "openai").strip().lower()
|
||||
base_url = str((config or {}).get("base_url") or "").strip()
|
||||
if base_url:
|
||||
return base_url
|
||||
if provider in {"openai", "custom", "deepseek"}:
|
||||
return settings.OPENAI_BASE_URL
|
||||
if provider == "ollama":
|
||||
return settings.OLLAMA_BASE_URL
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderCapabilities:
|
||||
provider: str
|
||||
supports_native_tools: bool
|
||||
preferred_tool_strategy: ToolStrategy
|
||||
|
||||
|
||||
def default_provider_capabilities() -> ProviderCapabilities:
|
||||
return resolve_provider_capabilities({"provider": settings.LLM_PROVIDER})
|
||||
|
||||
|
||||
def normalize_provider_name(config: dict | None) -> str:
|
||||
provider_raw = str((config or {}).get("provider") or "").strip().lower()
|
||||
provider = provider_raw or str(settings.LLM_PROVIDER or "openai").strip().lower()
|
||||
model = str((config or {}).get("model") or "").strip().lower()
|
||||
base_url = _resolve_effective_base_url(config).strip().lower()
|
||||
|
||||
# base_url-first inference (provider may be omitted in user config)
|
||||
if base_url:
|
||||
if any(key in base_url for key in {"localhost:11434", "127.0.0.1:11434"}):
|
||||
return "ollama"
|
||||
if any(key in base_url for key in {"api.anthropic.com", "anthropic"}):
|
||||
return "claude"
|
||||
if "api.deepseek.com" in base_url:
|
||||
return "deepseek"
|
||||
|
||||
# Many "openai-compatible" endpoints are configured as provider=openai.
|
||||
# We treat them as distinct providers so capability routing can stay conservative.
|
||||
if provider in {"openai", "custom"}:
|
||||
if any(key in model or key in base_url for key in {"minimax", "abab"}):
|
||||
return "minimax"
|
||||
if any(key in model or key in base_url for key in {"kimi", "moonshot"}):
|
||||
return "kimi"
|
||||
if any(key in model or key in base_url for key in {"qwen", "dashscope", "aliyuncs"}):
|
||||
return "qwen"
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
def resolve_provider_capabilities(config: dict | None) -> ProviderCapabilities:
|
||||
provider = normalize_provider_name(config)
|
||||
|
||||
# Conservative default: only treat official OpenAI + DeepSeek + Claude as reliable native tool providers.
|
||||
# Many OpenAI-compatible endpoints reject tool / response_format / other chat params.
|
||||
native_tool_providers = {"openai", "deepseek", "claude"}
|
||||
|
||||
base_url = _resolve_effective_base_url(config).strip().lower()
|
||||
is_official_openai = (
|
||||
provider != "openai"
|
||||
or not base_url
|
||||
or "api.openai.com" in base_url
|
||||
or "openai.azure.com" in base_url
|
||||
)
|
||||
|
||||
if provider in native_tool_providers and is_official_openai:
|
||||
return ProviderCapabilities(
|
||||
provider=provider,
|
||||
supports_native_tools=True,
|
||||
preferred_tool_strategy="native",
|
||||
)
|
||||
|
||||
return ProviderCapabilities(
|
||||
provider=provider,
|
||||
supports_native_tools=False,
|
||||
preferred_tool_strategy="json_fallback",
|
||||
)
|
||||
|
||||
|
||||
def create_llm_from_config(config: dict | None):
|
||||
"""根据用户模型配置创建底层 LangChain LLM 实例"""
|
||||
if not config:
|
||||
return get_llm()
|
||||
|
||||
provider = normalize_provider_name(config)
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
|
||||
if provider in {"openai", "deepseek", "custom", "minimax", "kimi", "qwen"}:
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
llm = ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
llm = ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
|
||||
setattr(llm, "_jarvis_user_llm_config", config)
|
||||
setattr(llm, "_jarvis_provider_capabilities", resolve_provider_capabilities(config))
|
||||
return llm
|
||||
|
||||
|
||||
class LLMService(ABC):
|
||||
@@ -145,4 +269,7 @@ def get_llm() -> LLMService:
|
||||
_llm_instance = OllamaService()
|
||||
else:
|
||||
raise ValueError(f"Unknown LLM provider: {provider}")
|
||||
setattr(_llm_instance, "_jarvis_provider_capabilities", default_provider_capabilities())
|
||||
return _llm_instance
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user