# Runtime chat helpers: send chat / tool-call requests to the configured model
# slots, falling back from one slot to the next and recording a trace of every
# attempt.
from __future__ import annotations

import json
from collections.abc import Callable
from dataclasses import dataclass
from http import HTTPStatus
from time import monotonic, sleep
from typing import Any

from sqlalchemy.orm import Session

from app.core.logging import get_logger
from app.services.model_connectivity import (
    AZURE_API_VERSION,
    ConnectivityCheckError,
    _build_azure_deployment_base,
    _build_headers,
    _ensure_path,
    _normalize_endpoint,
    _send_json_request,
)
from app.services.settings import SettingsService

logger = get_logger("app.services.runtime_chat")

DEFAULT_RUNTIME_CHAT_TIMEOUT_SECONDS = 45
DEFAULT_RUNTIME_CHAT_RETRY_ATTEMPTS = 2
DEFAULT_RUNTIME_CHAT_RETRY_DELAY_SECONDS = 0.6
DEFAULT_RUNTIME_CHAT_FAILURE_COOLDOWN_SECONDS = 90

# Maps a slot cache key (see RuntimeChatService._build_slot_cache_key) to the
# monotonic() timestamp before which that slot is skipped after a failure.
_slot_failure_until: dict[str, float] = {}


def clear_runtime_chat_failure_cache() -> int:
    """Forget all recorded slot failures and return how many were dropped."""
    cleared_count = len(_slot_failure_until)
    _slot_failure_until.clear()
    return cleared_count


@dataclass(slots=True)
class RuntimeChatCallTrace:
    """Record of one slot visit made while serving a runtime chat request.

    ``status`` is one of ``"skipped"``, ``"succeeded"``, ``"empty"`` or
    ``"failed"`` as produced by ``RuntimeChatService``; ``skipped_reason`` is
    ``"not_configured"`` or ``"cooldown"`` for skipped entries.
    """

    slot: str
    provider: str
    model: str
    attempt: int
    status: str
    duration_ms: int = 0
    error_message: str | None = None
    skipped_reason: str | None = None

    def model_dump(self) -> dict[str, Any]:
        """Return the trace as a plain dict keyed by field name."""
        return {
            "slot": self.slot,
            "provider": self.provider,
            "model": self.model,
            "attempt": self.attempt,
            "status": self.status,
            "duration_ms": self.duration_ms,
            "error_message": self.error_message,
            "skipped_reason": self.skipped_reason,
        }


@dataclass(slots=True)
class RuntimeChatResult:
    """Text completion outcome: the text (or ``None``) plus the call trace."""

    text: str | None
    calls: list[RuntimeChatCallTrace]

    def calls_as_dicts(self) -> list[dict[str, Any]]:
        """Return every trace entry converted with ``model_dump``."""
        return [item.model_dump() for item in self.calls]


@dataclass(slots=True)
class RuntimeChatToolCall:
    """A single tool/function call extracted from a model response."""

    name: str
    arguments: dict[str, Any]
    call_id: str | None = None
    raw_arguments: str = ""


@dataclass(slots=True)
class RuntimeToolCallResult:
    """Tool-call outcome: the tool call (or ``None``) plus the call trace."""

    tool_call: RuntimeChatToolCall | None
    calls: list[RuntimeChatCallTrace]

    def calls_as_dicts(self) -> list[dict[str, Any]]:
        """Return every trace entry converted with ``model_dump``."""
        return [item.model_dump() for item in self.calls]


class RuntimeChatService:
    """Send chat requests to configured model slots with retry and failover.

    Slots are tried in ``slot_priority`` order. A slot whose request raises is
    placed in a module-wide cooldown and skipped until the cooldown expires.
    """

    def __init__(self, db: Session) -> None:
        self.db = db
        self.settings_service = SettingsService(db)

    def complete(
        self,
        messages: list[dict[str, Any]],
        *,
        slot_priority: tuple[str, ...] = ("main", "backup"),
        max_tokens: int = 500,
        temperature: float = 0.2,
        timeout_seconds: int | None = None,
        slot_timeouts: dict[str, int] | None = None,
        max_attempts: int | None = None,
    ) -> str | None:
        """Return only the completion text of ``complete_with_trace``."""
        return self.complete_with_trace(
            messages,
            slot_priority=slot_priority,
            max_tokens=max_tokens,
            temperature=temperature,
            timeout_seconds=timeout_seconds,
            slot_timeouts=slot_timeouts,
            max_attempts=max_attempts,
        ).text

    def complete_with_trace(
        self,
        messages: list[dict[str, Any]],
        *,
        slot_priority: tuple[str, ...] = ("main", "backup"),
        max_tokens: int = 500,
        temperature: float = 0.2,
        timeout_seconds: int | None = None,
        slot_timeouts: dict[str, int] | None = None,
        max_attempts: int | None = None,
    ) -> RuntimeChatResult:
        """Request a text completion, trying each usable slot in order.

        Args:
            messages: Chat messages forwarded unchanged to the provider.
            slot_priority: Slot names to try, in order of preference.
            max_tokens: Output token limit forwarded to the provider.
            temperature: Sampling temperature forwarded to the provider.
            timeout_seconds: Default per-request timeout; a falsy value selects
                ``DEFAULT_RUNTIME_CHAT_TIMEOUT_SECONDS``.
            slot_timeouts: Optional per-slot timeout overrides keyed by slot.
            max_attempts: Number of rounds over the slots; a falsy value
                selects ``DEFAULT_RUNTIME_CHAT_RETRY_ATTEMPTS``.

        Returns:
            A ``RuntimeChatResult`` whose ``text`` is the stripped response of
            the first slot that answered with non-empty content, or ``None``
            when no slot did. ``calls`` lists every slot visit.
        """
        configs, calls = self._resolve_slot_configs(slot_priority)
        if not configs:
            return RuntimeChatResult(None, calls)

        def invoke(config: dict[str, str], slot_timeout: int) -> str | None:
            response_text = self._request_chat_completion(
                config,
                messages,
                max_tokens=max_tokens,
                temperature=temperature,
                timeout_seconds=slot_timeout,
            )
            # A falsy response counts as "empty" so the next slot is tried.
            return response_text.strip() if response_text else None

        text = self._run_with_failover(
            configs,
            calls,
            invoke=invoke,
            skip_log_format=(
                "Skip runtime chat slot=%s provider=%s because it is in cooldown"
            ),
            failure_log_format=(
                "Runtime chat request failed slot=%s provider=%s attempt=%s/%s: %s"
            ),
            empty_message="模型返回空内容。",
            timeout_seconds=timeout_seconds,
            slot_timeouts=slot_timeouts,
            max_attempts=max_attempts,
        )
        return RuntimeChatResult(text, calls)

    def complete_with_tool_call(
        self,
        messages: list[dict[str, Any]],
        *,
        tools: list[dict[str, Any]],
        tool_choice: dict[str, Any] | str | None = None,
        slot_priority: tuple[str, ...] = ("main", "backup"),
        max_tokens: int = 1200,
        temperature: float = 0.1,
        timeout_seconds: int | None = None,
        slot_timeouts: dict[str, int] | None = None,
        max_attempts: int | None = None,
    ) -> RuntimeToolCallResult:
        """Request a tool call, trying each usable slot in order.

        Args:
            messages: Chat messages forwarded unchanged to the provider.
            tools: Tool definitions forwarded unchanged to the provider.
            tool_choice: Forwarded to the provider; a falsy value is sent as
                ``"auto"`` by the request helpers.
            slot_priority: Slot names to try, in order of preference.
            max_tokens: Output token limit forwarded to the provider.
            temperature: Sampling temperature forwarded to the provider.
            timeout_seconds: Default per-request timeout; a falsy value selects
                ``DEFAULT_RUNTIME_CHAT_TIMEOUT_SECONDS``.
            slot_timeouts: Optional per-slot timeout overrides keyed by slot.
            max_attempts: Number of rounds over the slots; a falsy value
                selects ``DEFAULT_RUNTIME_CHAT_RETRY_ATTEMPTS``.

        Returns:
            A ``RuntimeToolCallResult`` whose ``tool_call`` is the first tool
            call any slot produced, or ``None`` when no slot produced one.
            ``calls`` lists every slot visit.
        """
        configs, calls = self._resolve_slot_configs(slot_priority)
        if not configs:
            return RuntimeToolCallResult(None, calls)

        def invoke(
            config: dict[str, str],
            slot_timeout: int,
        ) -> RuntimeChatToolCall | None:
            return self._request_chat_tool_call(
                config,
                messages,
                tools=tools,
                tool_choice=tool_choice,
                max_tokens=max_tokens,
                temperature=temperature,
                timeout_seconds=slot_timeout,
            )

        tool_call = self._run_with_failover(
            configs,
            calls,
            invoke=invoke,
            skip_log_format=(
                "Skip runtime chat tool slot=%s provider=%s because it is in cooldown"
            ),
            failure_log_format=(
                "Runtime chat tool request failed slot=%s provider=%s attempt=%s/%s: %s"
            ),
            empty_message="模型未返回工具调用。",
            timeout_seconds=timeout_seconds,
            slot_timeouts=slot_timeouts,
            max_attempts=max_attempts,
        )
        return RuntimeToolCallResult(tool_call, calls)

    def _resolve_slot_configs(
        self,
        slot_priority: tuple[str, ...],
    ) -> tuple[list[dict[str, str]], list[RuntimeChatCallTrace]]:
        """Load the usable slot configs and trace the slots that are not usable."""
        configs: list[dict[str, str]] = []
        calls: list[RuntimeChatCallTrace] = []
        for slot in slot_priority:
            config = self._load_chat_slot(slot)
            if config is None:
                calls.append(
                    RuntimeChatCallTrace(
                        slot=slot,
                        provider="",
                        model="",
                        attempt=0,
                        status="skipped",
                        skipped_reason="not_configured",
                    )
                )
                continue
            configs.append(config)
        return configs, calls

    @staticmethod
    def _make_trace(
        config: dict[str, str],
        attempt: int,
        status: str,
        *,
        duration_ms: int = 0,
        error_message: str | None = None,
        skipped_reason: str | None = None,
    ) -> RuntimeChatCallTrace:
        """Build a trace entry for ``config`` with the given outcome."""
        return RuntimeChatCallTrace(
            slot=config["slot"],
            provider=config["provider"],
            model=config["model"],
            attempt=attempt,
            status=status,
            duration_ms=duration_ms,
            error_message=error_message,
            skipped_reason=skipped_reason,
        )

    def _run_with_failover(
        self,
        configs: list[dict[str, str]],
        calls: list[RuntimeChatCallTrace],
        *,
        invoke: Callable[[dict[str, str], int], Any],
        skip_log_format: str,
        failure_log_format: str,
        empty_message: str,
        timeout_seconds: int | None,
        slot_timeouts: dict[str, int] | None,
        max_attempts: int | None,
    ) -> Any:
        """Run ``invoke`` against each slot until one yields a non-``None`` result.

        Appends one trace entry to ``calls`` per slot visit (mutating the list
        in place) and returns the first non-``None`` result, or ``None`` when
        every attempt was skipped, empty or failed.

        Args:
            configs: Usable slot configs in priority order.
            calls: Trace list to append to.
            invoke: Called as ``invoke(config, timeout_seconds)``; returns the
                result, or ``None`` for an "empty" answer.
            skip_log_format: Log format used when a slot is in cooldown.
            failure_log_format: Log format used when ``invoke`` raises.
            empty_message: ``error_message`` recorded for an empty answer.
            timeout_seconds: Default per-request timeout (falsy -> default).
            slot_timeouts: Optional per-slot timeout overrides keyed by slot.
            max_attempts: Number of rounds over ``configs`` (falsy -> default).
        """
        default_timeout = timeout_seconds or DEFAULT_RUNTIME_CHAT_TIMEOUT_SECONDS
        per_slot_timeouts = dict(slot_timeouts or {})
        attempts = max_attempts or DEFAULT_RUNTIME_CHAT_RETRY_ATTEMPTS

        for attempt in range(1, attempts + 1):
            for config in configs:
                cache_key = self._build_slot_cache_key(config)
                if _slot_failure_until.get(cache_key, 0.0) > monotonic():
                    logger.info(skip_log_format, config["slot"], config["provider"])
                    calls.append(
                        self._make_trace(
                            config,
                            attempt,
                            "skipped",
                            skipped_reason="cooldown",
                        )
                    )
                    continue

                started = monotonic()
                try:
                    result = invoke(
                        config,
                        per_slot_timeouts.get(config["slot"], default_timeout),
                    )
                except Exception as exc:
                    # Failover boundary: any provider error must move on to the
                    # next slot instead of aborting the whole request.
                    duration_ms = int((monotonic() - started) * 1000)
                    # NOTE(review): the cooldown starts immediately, so later
                    # attempts in this same call skip the slot; retries only
                    # re-query slots that returned an empty result. Confirm
                    # this is intended.
                    _slot_failure_until[cache_key] = (
                        monotonic() + DEFAULT_RUNTIME_CHAT_FAILURE_COOLDOWN_SECONDS
                    )
                    calls.append(
                        self._make_trace(
                            config,
                            attempt,
                            "failed",
                            duration_ms=duration_ms,
                            error_message=str(exc),
                        )
                    )
                    logger.warning(
                        failure_log_format,
                        config["slot"],
                        config["provider"],
                        attempt,
                        attempts,
                        exc,
                    )
                    continue

                duration_ms = int((monotonic() - started) * 1000)
                if result is not None:
                    _slot_failure_until.pop(cache_key, None)
                    calls.append(
                        self._make_trace(
                            config,
                            attempt,
                            "succeeded",
                            duration_ms=duration_ms,
                        )
                    )
                    return result
                calls.append(
                    self._make_trace(
                        config,
                        attempt,
                        "empty",
                        duration_ms=duration_ms,
                        error_message=empty_message,
                    )
                )

            if attempt < attempts:
                sleep(DEFAULT_RUNTIME_CHAT_RETRY_DELAY_SECONDS)
        return None

    @staticmethod
    def _build_slot_cache_key(config: dict[str, str]) -> str:
        """Return the cooldown key: slot, provider, endpoint and model joined by ``|``."""
        return "|".join(
            [
                str(config.get("slot") or ""),
                str(config.get("provider") or ""),
                str(config.get("endpoint") or ""),
                str(config.get("model") or ""),
            ]
        )

    def _load_chat_slot(self, slot: str) -> dict[str, str] | None:
        """Load and validate the runtime model config for ``slot``.

        Returns ``None`` when the settings service raises ``ValueError``, the
        capability is not ``"chat"``, provider/endpoint/model is blank, or a
        non-Ollama provider has no API key.
        """
        try:
            config = self.settings_service.get_runtime_model_config(slot)
        except ValueError:
            return None
        if config["capability"] != "chat":
            return None

        provider = str(config["provider"] or "").strip()
        endpoint = str(config["endpoint"] or "").strip()
        model = str(config["model"] or "").strip()
        api_key = str(config["apiKey"] or "").strip()
        if not provider or not endpoint or not model:
            return None
        if provider != "Ollama" and not api_key:
            logger.info("Skip runtime chat slot=%s because api key is empty", slot)
            return None

        return {
            "slot": slot,
            "provider": provider,
            "endpoint": endpoint,
            "model": model,
            "apiKey": api_key,
        }

    def _request_chat_completion(
        self,
        config: dict[str, str],
        messages: list[dict[str, Any]],
        *,
        max_tokens: int,
        temperature: float,
        timeout_seconds: int,
    ) -> str:
        """Dispatch a text completion to the request helper for the provider.

        ``"Azure OpenAI"`` and ``"Ollama"`` have dedicated helpers; every other
        provider is treated as OpenAI-compatible.
        """
        provider = config["provider"]
        endpoint = config["endpoint"]
        model = config["model"]
        api_key = config["apiKey"]

        if provider == "Azure OpenAI":
            return self._request_azure_openai(
                endpoint=endpoint,
                model=model,
                api_key=api_key,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                timeout_seconds=timeout_seconds,
            )
        if provider == "Ollama":
            return self._request_ollama(
                endpoint=endpoint,
                model=model,
                api_key=api_key,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                timeout_seconds=timeout_seconds,
            )
        return self._request_openai_compatible(
            provider=provider,
            endpoint=endpoint,
            model=model,
            api_key=api_key,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            timeout_seconds=timeout_seconds,
        )

    def _request_chat_tool_call(
        self,
        config: dict[str, str],
        messages: list[dict[str, Any]],
        *,
        tools: list[dict[str, Any]],
        tool_choice: dict[str, Any] | str | None,
        max_tokens: int,
        temperature: float,
        timeout_seconds: int,
    ) -> RuntimeChatToolCall | None:
        """Dispatch a tool-call request to the request helper for the provider.

        Raises:
            ConnectivityCheckError: For the ``"Ollama"`` provider, which this
                service does not support for tool calls.
        """
        provider = config["provider"]
        endpoint = config["endpoint"]
        model = config["model"]
        api_key = config["apiKey"]

        if provider == "Azure OpenAI":
            return self._request_azure_openai_tool_call(
                endpoint=endpoint,
                model=model,
                api_key=api_key,
                messages=messages,
                tools=tools,
                tool_choice=tool_choice,
                max_tokens=max_tokens,
                temperature=temperature,
                timeout_seconds=timeout_seconds,
            )
        if provider == "Ollama":
            raise ConnectivityCheckError("Ollama 暂不支持小财管家 function calling。")
        return self._request_openai_compatible_tool_call(
            provider=provider,
            endpoint=endpoint,
            model=model,
            api_key=api_key,
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            max_tokens=max_tokens,
            temperature=temperature,
            timeout_seconds=timeout_seconds,
        )

    def _request_openai_compatible(
        self,
        *,
        provider: str,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, Any]],
        max_tokens: int,
        temperature: float,
        timeout_seconds: int,
    ) -> str:
        """POST to ``chat/completions`` under ``endpoint`` and return the reply text.

        Raises:
            ConnectivityCheckError: When the response status is 400 or higher.
        """
        url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions")
        request_payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if provider == "GLM":
            # Presumably turns off GLM's reasoning mode -- confirm against the
            # provider's API documentation.
            request_payload["thinking"] = {"type": "disabled"}

        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=True),
            payload=request_payload,
            timeout_seconds=timeout_seconds,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"模型接口返回异常状态 {status_code}。",
                status_code=status_code,
            )
        return self._extract_openai_text(payload)

    def _request_openai_compatible_tool_call(
        self,
        *,
        provider: str,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        tool_choice: dict[str, Any] | str | None,
        max_tokens: int,
        temperature: float,
        timeout_seconds: int,
    ) -> RuntimeChatToolCall | None:
        """POST a tool-enabled request to ``chat/completions`` and parse the tool call.

        A falsy ``tool_choice`` is sent as ``"auto"``.

        Raises:
            ConnectivityCheckError: When the response status is 400 or higher.
        """
        url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions")
        request_payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "tools": tools,
            "tool_choice": tool_choice or "auto",
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if provider == "GLM":
            # Presumably turns off GLM's reasoning mode -- confirm against the
            # provider's API documentation.
            request_payload["thinking"] = {"type": "disabled"}

        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=True),
            payload=request_payload,
            timeout_seconds=timeout_seconds,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"模型接口返回异常状态 {status_code}。",
                status_code=status_code,
            )
        return self._extract_openai_tool_call(payload)

    def _request_ollama(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, Any]],
        max_tokens: int,
        temperature: float,
        timeout_seconds: int,
    ) -> str:
        """POST a non-streaming request to ``api/chat`` and return the reply text.

        Returns an empty string when the response carries no usable content.

        Raises:
            ConnectivityCheckError: When the response status is 400 or higher.
        """
        url = _ensure_path(_normalize_endpoint(endpoint), "api/chat")
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=False),
            payload={
                "model": model,
                "messages": messages,
                "stream": False,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature,
                },
            },
            timeout_seconds=timeout_seconds,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"Ollama 返回异常状态 {status_code}。",
                status_code=status_code,
            )
        return self._extract_ollama_text(payload)

    def _request_azure_openai(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, Any]],
        max_tokens: int,
        temperature: float,
        timeout_seconds: int,
    ) -> str:
        """POST to the Azure deployment's ``chat/completions`` and return the reply text.

        Raises:
            ConnectivityCheckError: When the response status is 400 or higher.
        """
        deployment_base = _build_azure_deployment_base(endpoint, model)
        url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}"
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True),
            payload={
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
            timeout_seconds=timeout_seconds,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"Azure OpenAI 返回异常状态 {status_code}。",
                status_code=status_code,
            )
        return self._extract_openai_text(payload)

    def _request_azure_openai_tool_call(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        tool_choice: dict[str, Any] | str | None,
        max_tokens: int,
        temperature: float,
        timeout_seconds: int,
    ) -> RuntimeChatToolCall | None:
        """POST a tool-enabled request to the Azure deployment and parse the tool call.

        A falsy ``tool_choice`` is sent as ``"auto"``.

        Raises:
            ConnectivityCheckError: When the response status is 400 or higher.
        """
        deployment_base = _build_azure_deployment_base(endpoint, model)
        url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}"
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True),
            payload={
                "messages": messages,
                "tools": tools,
                "tool_choice": tool_choice or "auto",
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
            timeout_seconds=timeout_seconds,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"Azure OpenAI 返回异常状态 {status_code}。",
                status_code=status_code,
            )
        return self._extract_openai_tool_call(payload)

    @staticmethod
    def _extract_ollama_text(payload: Any) -> str:
        """Return the stripped ``message.content`` of an Ollama chat response.

        Returns ``""`` when ``payload`` or its ``message`` is not a dict, or
        when ``content`` is missing or ``None``. A ``None`` content must not be
        passed through ``str()``: that would yield the literal text ``"None"``,
        which callers would treat as a successful completion.
        """
        if not isinstance(payload, dict):
            return ""
        message = payload.get("message")
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        if content is None:
            return ""
        return str(content).strip()

    @staticmethod
    def _extract_openai_text(payload: Any) -> str:
        """Return the text of the first choice in an OpenAI-style response.

        Accepts ``message.content`` as a string or as a list of parts (only
        dict parts with ``type == "text"`` are used, joined by newlines), and
        falls back to the choice's ``text`` field. Returns ``""`` for any
        other shape.
        """
        if not isinstance(payload, dict):
            return ""
        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return ""

        message = first_choice.get("message")
        if isinstance(message, dict):
            content = message.get("content", "")
            if isinstance(content, str):
                return content.strip()
            if isinstance(content, list):
                parts: list[str] = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        parts.append(str(item.get("text", "")))
                return "\n".join(part.strip() for part in parts if part.strip()).strip()

        text = first_choice.get("text")
        if isinstance(text, str):
            return text.strip()
        return ""

    @staticmethod
    def _extract_openai_tool_call(payload: Any) -> RuntimeChatToolCall | None:
        """Return the first tool call of the first choice, if there is one.

        Reads ``message.tool_calls[0].function`` first and falls back to
        ``message.function_call``. Returns ``None`` for any other shape.
        """
        if not isinstance(payload, dict):
            return None
        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            return None
        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return None
        message = first_choice.get("message")
        if not isinstance(message, dict):
            return None

        tool_calls = message.get("tool_calls")
        if isinstance(tool_calls, list) and tool_calls:
            first_tool = tool_calls[0]
            if isinstance(first_tool, dict):
                function_payload = first_tool.get("function")
                if isinstance(function_payload, dict):
                    return RuntimeChatService._build_runtime_tool_call(
                        name=function_payload.get("name"),
                        arguments=function_payload.get("arguments"),
                        call_id=first_tool.get("id"),
                    )

        function_call = message.get("function_call")
        if isinstance(function_call, dict):
            return RuntimeChatService._build_runtime_tool_call(
                name=function_call.get("name"),
                arguments=function_call.get("arguments"),
                call_id=None,
            )
        return None

    @staticmethod
    def _build_runtime_tool_call(
        *,
        name: Any,
        arguments: Any,
        call_id: Any,
    ) -> RuntimeChatToolCall | None:
        """Normalize raw tool-call fields into a ``RuntimeChatToolCall``.

        ``arguments`` may already be a dict or may be a JSON string; a blank
        value becomes ``{}``. Returns ``None`` when ``name`` is blank.

        Raises:
            ValueError: When the arguments string is not valid JSON
                (``json.JSONDecodeError``) or does not decode to a JSON object.
        """
        tool_name = str(name or "").strip()
        if not tool_name:
            return None

        if isinstance(arguments, dict):
            parsed_arguments = arguments
            raw_arguments = json.dumps(arguments, ensure_ascii=False)
        else:
            raw_arguments = str(arguments or "").strip()
            if not raw_arguments:
                parsed_arguments = {}
            else:
                parsed = json.loads(raw_arguments)
                if not isinstance(parsed, dict):
                    raise ValueError("工具调用参数必须是 JSON object。")
                parsed_arguments = parsed

        return RuntimeChatToolCall(
            name=tool_name,
            arguments=parsed_arguments,
            call_id=str(call_id).strip() if call_id else None,
            raw_arguments=raw_arguments,
        )