"""Base LLM provider interface.

Defines the provider-agnostic data carriers (``ToolCallRequest``,
``LLMResponse``, ``GenerationSettings``) and the abstract ``LLMProvider``
base class that concrete providers implement.
"""

import asyncio
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

from loguru import logger


@dataclass
class ToolCallRequest:
    """A tool call request from the LLM."""

    id: str
    name: str
    arguments: dict[str, Any]
    provider_specific_fields: dict[str, Any] | None = None

    def to_openai_tool_call(self) -> dict[str, Any]:
        """Serialize to an OpenAI-style tool_call payload.

        Returns:
            A dict with ``id``, ``type`` (always ``"function"``) and a
            ``function`` entry holding the tool name and the arguments encoded
            as a JSON string. ``provider_specific_fields`` is included only
            when it is set and non-empty.
        """
        tool_call: dict[str, Any] = {
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                # Arguments travel as a JSON *string*; keep non-ASCII text readable.
                "arguments": json.dumps(self.arguments, ensure_ascii=False),
            },
        }
        if self.provider_specific_fields:
            tool_call["provider_specific_fields"] = self.provider_specific_fields
        return tool_call


@dataclass
class LLMResponse:
    """Response from an LLM provider."""

    content: str | None
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    reasoning_content: str | None = None  # For reasoning models

    @property
    def has_tool_calls(self) -> bool:
        """Check if response contains tool calls."""
        return len(self.tool_calls) > 0


@dataclass(frozen=True)
class GenerationSettings:
    """Default generation parameters for LLM calls."""

    temperature: float = 0.7
    max_tokens: int = 4096


class LLMProvider(ABC):
    """
    Abstract base class for LLM providers.

    Implementations should handle the specifics of each provider's API
    while maintaining a consistent interface.
    """

    # Seconds to sleep after each failed attempt in chat_with_retry(); one
    # further, final attempt is made once these delays are exhausted.
    _CHAT_RETRY_DELAYS = (1, 2, 4)

    # Lower-case fragments that mark an error message as transient. All-digit
    # markers are treated as status codes and must not be part of a longer
    # number; every other marker is matched as a plain substring.
    _TRANSIENT_ERROR_MARKERS = (
        "429",
        "rate limit",
        "500",
        "502",
        "503",
        "504",
        "overloaded",
        "timeout",
        "timed out",
        "connection",
        "server error",
        "temporarily unavailable",
    )

    # Distinguishes "argument not passed" from any explicitly passed value.
    _SENTINEL = object()

    def __init__(self, api_key: str | None = None, api_base: str | None = None):
        """Store the credentials and install default generation settings.

        Args:
            api_key: Optional API key, stored as ``self.api_key``.
            api_base: Optional API base URL, stored as ``self.api_base``.
        """
        self.api_key = api_key
        self.api_base = api_base
        self.generation: GenerationSettings = GenerationSettings()

    @staticmethod
    def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Replace empty text content that causes provider 400 errors.

        The input list and its message dicts are never mutated; any message
        that needs a change is shallow-copied first.

        Rules applied to each message's ``content``:

        * empty string -> ``None`` for an assistant message that carries
          ``tool_calls``, otherwise the placeholder ``"(empty)"``;
        * list -> dict items of type ``text`` / ``input_text`` /
          ``output_text`` with a missing or empty ``text`` are dropped; if
          nothing is left, the same ``None`` / ``"(empty)"`` rule applies;
        * dict -> wrapped in a single-element list;
        * anything else -> message passed through as-is.

        Args:
            messages: Chat messages as dicts.

        Returns:
            A new list with sanitized messages in the original order.
        """
        result: list[dict[str, Any]] = []
        for msg in messages:
            content = msg.get("content")

            if isinstance(content, str) and not content:
                clean = dict(msg)
                clean["content"] = (
                    None
                    if (msg.get("role") == "assistant" and msg.get("tool_calls"))
                    else "(empty)"
                )
                result.append(clean)
                continue

            if isinstance(content, list):
                filtered = [
                    item
                    for item in content
                    if not (
                        isinstance(item, dict)
                        and item.get("type") in ("text", "input_text", "output_text")
                        and not item.get("text")
                    )
                ]
                # Only copy the message when something was actually removed.
                if len(filtered) != len(content):
                    clean = dict(msg)
                    if filtered:
                        clean["content"] = filtered
                    elif msg.get("role") == "assistant" and msg.get("tool_calls"):
                        clean["content"] = None
                    else:
                        clean["content"] = "(empty)"
                    result.append(clean)
                    continue

            if isinstance(content, dict):
                clean = dict(msg)
                clean["content"] = [content]
                result.append(clean)
                continue

            result.append(msg)
        return result

    @abstractmethod
    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """
        Send a chat completion request.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions.
            model: Model identifier (provider-specific).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.

        Returns:
            LLMResponse with content and/or tool calls.
        """
        pass

    @classmethod
    def _is_transient_error(cls, content: str | None) -> bool:
        """Return True when an error message looks like a transient failure.

        The message is lower-cased and checked against
        ``_TRANSIENT_ERROR_MARKERS``. All-digit markers (status codes such as
        ``"500"``) only count when they are not embedded in a longer run of
        digits, so a message like "context length 15000 exceeded" is not
        mistaken for an HTTP 500. Other markers are matched as substrings.

        Args:
            content: Error text, or ``None``.

        Returns:
            True if any marker matches, False otherwise (including for
            ``None`` or empty text).
        """
        err = (content or "").lower()
        if not err:
            return False
        for marker in cls._TRANSIENT_ERROR_MARKERS:
            if marker.isdigit():
                # Status code: reject matches that are a slice of a bigger number.
                if re.search(rf"(?<!\d){re.escape(marker)}(?!\d)", err):
                    return True
            elif marker in err:
                return True
        return False

    async def _chat_once(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: Any,
        temperature: Any,
    ) -> LLMResponse:
        """Call chat() once, turning any raised Exception into an error response.

        ``asyncio.CancelledError`` is always re-raised so task cancellation is
        never swallowed.
        """
        try:
            return await self.chat(
                messages=messages,
                tools=tools,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            return LLMResponse(
                content=f"Error calling LLM: {exc}",
                finish_reason="error",
            )

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
    ) -> LLMResponse:
        """Call chat() with retry on transient provider failures.

        One attempt is made per entry of ``_CHAT_RETRY_DELAYS``; after a
        transient failure the matching delay is slept before the next attempt.
        When the delays are exhausted a final attempt is made and its result
        is returned regardless of outcome.

        Args:
            messages: List of message dicts passed through to chat().
            tools: Optional list of tool definitions.
            model: Model identifier (provider-specific).
            max_tokens: Maximum tokens in response; defaults to
                ``self.generation.max_tokens`` when omitted.
            temperature: Sampling temperature; defaults to
                ``self.generation.temperature`` when omitted.

        Returns:
            The first response that is not a transient error, or the result of
            the final attempt. Exceptions raised by chat() are returned as an
            ``LLMResponse`` with ``finish_reason="error"``.

        Raises:
            asyncio.CancelledError: Propagated unchanged if the task is cancelled.
        """
        if max_tokens is self._SENTINEL:
            max_tokens = self.generation.max_tokens
        if temperature is self._SENTINEL:
            temperature = self.generation.temperature

        for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
            response = await self._chat_once(
                messages=messages,
                tools=tools,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            # Success, or a failure that retrying will not fix: hand it back.
            if response.finish_reason != "error":
                return response
            if not self._is_transient_error(response.content):
                return response

            err = (response.content or "").lower()
            logger.warning(
                "LLM transient error (attempt {}/{}), retrying in {}s: {}",
                attempt,
                len(self._CHAT_RETRY_DELAYS),
                delay,
                err[:120],
            )
            await asyncio.sleep(delay)

        # Delays exhausted: last try, returned whatever the outcome.
        return await self._chat_once(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    @abstractmethod
    def get_default_model(self) -> str:
        """Get the default model for this provider."""
        pass