from __future__ import annotations from http import HTTPStatus from typing import Any from sqlalchemy.orm import Session from app.core.logging import get_logger from app.services.model_connectivity import ( AZURE_API_VERSION, ConnectivityCheckError, _build_azure_deployment_base, _build_headers, _ensure_path, _normalize_endpoint, _send_json_request, ) from app.services.settings import SettingsService logger = get_logger("app.services.runtime_chat") class RuntimeChatService: def __init__(self, db: Session) -> None: self.db = db self.settings_service = SettingsService(db) def complete( self, messages: list[dict[str, Any]], *, slot_priority: tuple[str, ...] = ("main", "backup"), max_tokens: int = 500, temperature: float = 0.2, ) -> str | None: for slot in slot_priority: config = self._load_chat_slot(slot) if config is None: continue try: response_text = self._request_chat_completion( config, messages, max_tokens=max_tokens, temperature=temperature, ) except Exception as exc: logger.warning( "Runtime chat request failed slot=%s provider=%s: %s", slot, config["provider"], exc, ) continue if response_text: return response_text.strip() return None def _load_chat_slot(self, slot: str) -> dict[str, str] | None: try: config = self.settings_service.get_runtime_model_config(slot) except ValueError: return None if config["capability"] != "chat": return None provider = str(config["provider"] or "").strip() endpoint = str(config["endpoint"] or "").strip() model = str(config["model"] or "").strip() api_key = str(config["apiKey"] or "").strip() if not provider or not endpoint or not model: return None if provider != "Ollama" and not api_key: logger.info("Skip runtime chat slot=%s because api key is empty", slot) return None return { "slot": slot, "provider": provider, "endpoint": endpoint, "model": model, "apiKey": api_key, } def _request_chat_completion( self, config: dict[str, str], messages: list[dict[str, Any]], *, max_tokens: int, temperature: float, ) -> str: provider = config["provider"] endpoint = config["endpoint"] model = config["model"] api_key = config["apiKey"] if provider == "Azure OpenAI": return self._request_azure_openai( endpoint=endpoint, model=model, api_key=api_key, messages=messages, max_tokens=max_tokens, temperature=temperature, ) if provider == "Ollama": return self._request_ollama( endpoint=endpoint, model=model, api_key=api_key, messages=messages, max_tokens=max_tokens, temperature=temperature, ) return self._request_openai_compatible( endpoint=endpoint, model=model, api_key=api_key, messages=messages, max_tokens=max_tokens, temperature=temperature, ) def _request_openai_compatible( self, *, endpoint: str, model: str, api_key: str, messages: list[dict[str, Any]], max_tokens: int, temperature: float, ) -> str: url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions") status_code, payload = _send_json_request( "POST", url, headers=_build_headers(api_key=api_key, use_bearer=True), payload={ "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, }, ) if status_code >= HTTPStatus.BAD_REQUEST: raise ConnectivityCheckError( f"模型接口返回异常状态 {status_code}。", status_code=status_code, ) return self._extract_openai_text(payload) def _request_ollama( self, *, endpoint: str, model: str, api_key: str, messages: list[dict[str, Any]], max_tokens: int, temperature: float, ) -> str: url = _ensure_path(_normalize_endpoint(endpoint), "api/chat") status_code, payload = _send_json_request( "POST", url, headers=_build_headers(api_key=api_key, use_bearer=False), payload={ "model": model, "messages": messages, "stream": False, "options": { "num_predict": max_tokens, "temperature": temperature, }, }, ) if status_code >= HTTPStatus.BAD_REQUEST: raise ConnectivityCheckError( f"Ollama 返回异常状态 {status_code}。", status_code=status_code, ) return str((payload or {}).get("message", {}).get("content", "")).strip() def _request_azure_openai( self, *, endpoint: str, model: str, api_key: str, messages: list[dict[str, Any]], max_tokens: int, temperature: float, ) -> str: deployment_base = _build_azure_deployment_base(endpoint, model) url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}" status_code, payload = _send_json_request( "POST", url, headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True), payload={ "messages": messages, "max_tokens": max_tokens, "temperature": temperature, }, ) if status_code >= HTTPStatus.BAD_REQUEST: raise ConnectivityCheckError( f"Azure OpenAI 返回异常状态 {status_code}。", status_code=status_code, ) return self._extract_openai_text(payload) @staticmethod def _extract_openai_text(payload: Any) -> str: if not isinstance(payload, dict): return "" choices = payload.get("choices") if not isinstance(choices, list) or not choices: return "" first_choice = choices[0] if not isinstance(first_choice, dict): return "" message = first_choice.get("message") if isinstance(message, dict): content = message.get("content", "") if isinstance(content, str): return content.strip() if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, dict) and item.get("type") == "text": parts.append(str(item.get("text", ""))) return "\n".join(part.strip() for part in parts if part.strip()).strip() text = first_choice.get("text") if isinstance(text, str): return text.strip() return ""