from __future__ import annotations from http import HTTPStatus from time import monotonic, sleep from typing import Any from sqlalchemy.orm import Session from app.core.logging import get_logger from app.services.model_connectivity import ( AZURE_API_VERSION, ConnectivityCheckError, _build_azure_deployment_base, _build_headers, _ensure_path, _normalize_endpoint, _send_json_request, ) from app.services.settings import SettingsService logger = get_logger("app.services.runtime_chat") DEFAULT_RUNTIME_CHAT_TIMEOUT_SECONDS = 45 DEFAULT_RUNTIME_CHAT_RETRY_ATTEMPTS = 2 DEFAULT_RUNTIME_CHAT_RETRY_DELAY_SECONDS = 0.6 DEFAULT_RUNTIME_CHAT_FAILURE_COOLDOWN_SECONDS = 90 _slot_failure_until: dict[str, float] = {} class RuntimeChatService: def __init__(self, db: Session) -> None: self.db = db self.settings_service = SettingsService(db) def complete( self, messages: list[dict[str, Any]], *, slot_priority: tuple[str, ...] = ("main", "backup"), max_tokens: int = 500, temperature: float = 0.2, timeout_seconds: int | None = None, slot_timeouts: dict[str, int] | None = None, max_attempts: int | None = None, ) -> str | None: configs = [ config for slot in slot_priority if (config := self._load_chat_slot(slot)) is not None ] resolved_timeout_seconds = timeout_seconds or DEFAULT_RUNTIME_CHAT_TIMEOUT_SECONDS resolved_slot_timeouts = dict(slot_timeouts or {}) resolved_max_attempts = max_attempts or DEFAULT_RUNTIME_CHAT_RETRY_ATTEMPTS for attempt in range(1, resolved_max_attempts + 1): for config in configs: cache_key = self._build_slot_cache_key(config) if _slot_failure_until.get(cache_key, 0.0) > monotonic(): logger.info( "Skip runtime chat slot=%s provider=%s because it is in cooldown", config["slot"], config["provider"], ) continue try: response_text = self._request_chat_completion( config, messages, max_tokens=max_tokens, temperature=temperature, timeout_seconds=resolved_slot_timeouts.get( config["slot"], resolved_timeout_seconds, ), ) if response_text: _slot_failure_until.pop(cache_key, None) return response_text.strip() except Exception as exc: _slot_failure_until[cache_key] = ( monotonic() + DEFAULT_RUNTIME_CHAT_FAILURE_COOLDOWN_SECONDS ) logger.warning( "Runtime chat request failed slot=%s provider=%s attempt=%s/%s: %s", config["slot"], config["provider"], attempt, resolved_max_attempts, exc, ) if attempt < resolved_max_attempts: sleep(DEFAULT_RUNTIME_CHAT_RETRY_DELAY_SECONDS) return None @staticmethod def _build_slot_cache_key(config: dict[str, str]) -> str: return "|".join( [ str(config.get("slot") or ""), str(config.get("provider") or ""), str(config.get("endpoint") or ""), str(config.get("model") or ""), ] ) def _load_chat_slot(self, slot: str) -> dict[str, str] | None: try: config = self.settings_service.get_runtime_model_config(slot) except ValueError: return None if config["capability"] != "chat": return None provider = str(config["provider"] or "").strip() endpoint = str(config["endpoint"] or "").strip() model = str(config["model"] or "").strip() api_key = str(config["apiKey"] or "").strip() if not provider or not endpoint or not model: return None if provider != "Ollama" and not api_key: logger.info("Skip runtime chat slot=%s because api key is empty", slot) return None return { "slot": slot, "provider": provider, "endpoint": endpoint, "model": model, "apiKey": api_key, } def _request_chat_completion( self, config: dict[str, str], messages: list[dict[str, Any]], *, max_tokens: int, temperature: float, timeout_seconds: int, ) -> str: provider = config["provider"] endpoint = config["endpoint"] model = config["model"] api_key = config["apiKey"] if provider == "Azure OpenAI": return self._request_azure_openai( endpoint=endpoint, model=model, api_key=api_key, messages=messages, max_tokens=max_tokens, temperature=temperature, timeout_seconds=timeout_seconds, ) if provider == "Ollama": return self._request_ollama( endpoint=endpoint, model=model, api_key=api_key, messages=messages, max_tokens=max_tokens, temperature=temperature, timeout_seconds=timeout_seconds, ) return self._request_openai_compatible( provider=provider, endpoint=endpoint, model=model, api_key=api_key, messages=messages, max_tokens=max_tokens, temperature=temperature, timeout_seconds=timeout_seconds, ) def _request_openai_compatible( self, *, provider: str, endpoint: str, model: str, api_key: str, messages: list[dict[str, Any]], max_tokens: int, temperature: float, timeout_seconds: int, ) -> str: url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions") request_payload: dict[str, Any] = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } if provider == "GLM": request_payload["thinking"] = {"type": "disabled"} status_code, payload = _send_json_request( "POST", url, headers=_build_headers(api_key=api_key, use_bearer=True), payload=request_payload, timeout_seconds=timeout_seconds, ) if status_code >= HTTPStatus.BAD_REQUEST: raise ConnectivityCheckError( f"模型接口返回异常状态 {status_code}。", status_code=status_code, ) return self._extract_openai_text(payload) def _request_ollama( self, *, endpoint: str, model: str, api_key: str, messages: list[dict[str, Any]], max_tokens: int, temperature: float, timeout_seconds: int, ) -> str: url = _ensure_path(_normalize_endpoint(endpoint), "api/chat") status_code, payload = _send_json_request( "POST", url, headers=_build_headers(api_key=api_key, use_bearer=False), payload={ "model": model, "messages": messages, "stream": False, "options": { "num_predict": max_tokens, "temperature": temperature, }, }, timeout_seconds=timeout_seconds, ) if status_code >= HTTPStatus.BAD_REQUEST: raise ConnectivityCheckError( f"Ollama 返回异常状态 {status_code}。", status_code=status_code, ) return str((payload or {}).get("message", {}).get("content", "")).strip() def _request_azure_openai( self, *, endpoint: str, model: str, api_key: str, messages: list[dict[str, Any]], max_tokens: int, temperature: float, timeout_seconds: int, ) -> str: deployment_base = _build_azure_deployment_base(endpoint, model) url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}" status_code, payload = _send_json_request( "POST", url, headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True), payload={ "messages": messages, "max_tokens": max_tokens, "temperature": temperature, }, timeout_seconds=timeout_seconds, ) if status_code >= HTTPStatus.BAD_REQUEST: raise ConnectivityCheckError( f"Azure OpenAI 返回异常状态 {status_code}。", status_code=status_code, ) return self._extract_openai_text(payload) @staticmethod def _extract_openai_text(payload: Any) -> str: if not isinstance(payload, dict): return "" choices = payload.get("choices") if not isinstance(choices, list) or not choices: return "" first_choice = choices[0] if not isinstance(first_choice, dict): return "" message = first_choice.get("message") if isinstance(message, dict): content = message.get("content", "") if isinstance(content, str): return content.strip() if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, dict) and item.get("type") == "text": parts.append(str(item.get("text", ""))) return "\n".join(part.strip() for part in parts if part.strip()).strip() text = first_choice.get("text") if isinstance(text, str): return text.strip() return ""