"""LLM Provider base classes and implementations.""" import asyncio import json import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, AsyncGenerator logger = logging.getLogger(__name__) @dataclass class ToolCallRequest: """A tool call request from the LLM.""" id: str name: str arguments: dict[str, Any] def to_dict(self) -> dict[str, Any]: """Serialize to dict.""" return { "id": self.id, "type": "function", "function": { "name": self.name, "arguments": json.dumps(self.arguments, ensure_ascii=False), }, } @dataclass class LLMResponse: """Response from an LLM provider.""" content: str | None tool_calls: list[ToolCallRequest] = field(default_factory=list) finish_reason: str = "stop" usage: dict[str, int] = field(default_factory=dict) reasoning_content: str | None = None # For reasoning models @property def has_tool_calls(self) -> bool: """Check if response contains tool calls.""" return len(self.tool_calls) > 0 @dataclass class GenerationSettings: """Default generation parameters for LLM calls.""" temperature: float = 0.7 max_tokens: int = 4096 class LLMProvider(ABC): """Abstract base class for LLM providers.""" _CHAT_RETRY_DELAYS = (1, 2, 4) _TRANSIENT_ERROR_MARKERS = ( "429", "rate limit", "500", "502", "503", "504", "overloaded", "timeout", "timed out", "connection", "server error", "temporarily unavailable", ) def __init__( self, api_key: str | None = None, api_base: str | None = None, ): self.api_key = api_key self.api_base = api_base self.generation = GenerationSettings() @staticmethod def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Sanitize messages to remove empty content that causes provider errors.""" result = [] for msg in messages: content = msg.get("content") if isinstance(content, str) and not content: clean = dict(msg) if msg.get("role") == "assistant" and msg.get("tool_calls"): clean["content"] = None else: clean["content"] = "(empty)" result.append(clean) continue result.append(msg) return result @abstractmethod async def chat( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, stream: bool = False, ) -> LLMResponse | AsyncGenerator[str, None]: """Send a chat completion request.""" pass @classmethod def _is_transient_error(cls, content: str | None) -> bool: err = (content or "").lower() return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS) async def chat_with_retry( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int | None = None, temperature: float | None = None, ) -> LLMResponse: """Call chat() with retry on transient provider failures.""" max_tokens = max_tokens or self.generation.max_tokens temperature = temperature or self.generation.temperature messages = self._sanitize_messages(messages) for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): try: response = await self.chat( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, ) except asyncio.CancelledError: raise except Exception as exc: response = LLMResponse( content=f"Error calling LLM: {exc}", finish_reason="error", ) if response.finish_reason != "error": return response if not self._is_transient_error(response.content): return response logger.warning( "LLM transient error (attempt {}/{}), retrying in {}s", attempt, len(self._CHAT_RETRY_DELAYS), delay, ) await asyncio.sleep(delay) # Last attempt try: return await self.chat( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, ) except asyncio.CancelledError: raise except Exception as exc: return LLMResponse( content=f"Error calling LLM: {exc}", finish_reason="error", ) @abstractmethod def get_default_model(self) -> str: """Get the default model for this provider.""" pass # OpenAI Provider class OpenAIProvider(LLMProvider): """OpenAI LLM provider.""" def __init__( self, api_key: str | None = None, api_base: str | None = None, ): super().__init__(api_key, api_base) self._client = None @property def client(self): """Lazy load OpenAI client.""" if self._client is None: try: from openai import AsyncOpenAI self._client = AsyncOpenAI( api_key=self.api_key, base_url=self.api_base, ) except ImportError: raise ImportError("openai package required: pip install openai") return self._client async def chat( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, stream: bool = False, ) -> LLMResponse: model = model or self.get_default_model() params = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } if tools: params["tools"] = tools params["tool_choice"] = "auto" try: response = await self.client.chat.completions.create(**params) choice = response.choices[0] msg = choice.message tool_calls = [] if msg.tool_calls: for tc in msg.tool_calls: args = tc.function.arguments if isinstance(args, str): args = json.loads(args) tool_calls.append(ToolCallRequest( id=tc.id, name=tc.function.name, arguments=args, )) return LLMResponse( content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason, usage={ "prompt_tokens": response.usage.prompt_tokens if response.usage else 0, "completion_tokens": response.usage.completion_tokens if response.usage else 0, }, ) except Exception as exc: logger.error(f"OpenAI API error: {exc}") return LLMResponse( content=f"Error: {exc}", finish_reason="error", ) async def chat_stream( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, ) -> AsyncGenerator[str, None]: """Stream chat completions.""" model = model or self.get_default_model() params = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, "stream": True, } if tools: params["tools"] = tools try: response = await self.client.chat.completions.create(**params) async for chunk in response: if chunk.choices and chunk.choices[0].delta.content: yield chunk.choices[0].delta.content except Exception as exc: yield f"Error: {exc}" def get_default_model(self) -> str: return "gpt-4o" # Anthropic Provider class AnthropicProvider(LLMProvider): """Anthropic Claude LLM provider.""" def __init__( self, api_key: str | None = None, api_base: str | None = None, ): super().__init__(api_key, api_base) self._client = None @property def client(self): """Lazy load Anthropic client.""" if self._client is None: try: from anthropic import AsyncAnthropic self._client = AsyncAnthropic( api_key=self.api_key, base_url=self.api_base, ) except ImportError: raise ImportError("anthropic package required: pip install anthropic") return self._client def _convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert messages to Anthropic format.""" converted = [] for msg in messages: role = msg.get("role") if role == "system": # Anthropic puts system in first user message content = msg.get("content", "") if converted and converted[0].get("role") == "user": converted[0]["content"] = f"{content}\n\n{converted[0].content}" else: converted.append({"role": "user", "content": f"{content}"}) else: # Handle tool results if role == "tool": converted.append({ "role": "user", "content": [ { "type": "tool_result", "tool_use_id": msg.get("tool_call_id"), "content": msg.get("content", ""), } ], }) else: converted.append(msg) return converted def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert OpenAI-style tools to Anthropic format.""" anthropic_tools = [] for tool in tools: func = tool.get("function", {}) anthropic_tools.append({ "name": func.get("name"), "description": func.get("description"), "input_schema": func.get("parameters", {}), }) return anthropic_tools async def chat( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, stream: bool = False, ) -> LLMResponse: model = model or self.get_default_model() params = { "model": model, "max_tokens": max_tokens, "temperature": temperature, "messages": self._convert_messages(messages), } if tools: params["tools"] = self._convert_tools(tools) try: response = await self.client.messages.create(**params) tool_calls = [] for tc in response.tool_calls: args = tc.input if isinstance(args, str): args = json.loads(args) tool_calls.append(ToolCallRequest( id=tc.id, name=tc.name, arguments=args, )) # Get content text content_text = "" thinking = None if response.content: for block in response.content: if block.type == "text": content_text = block.text elif block.type == "thinking": thinking = block.thinking return LLMResponse( content=content_text, tool_calls=tool_calls, finish_reason="stop" if not tool_calls else "tool_use", reasoning_content=thinking, usage={ "input_tokens": response.usage.input_tokens, "output_tokens": response.usage.output_tokens, }, ) except Exception as exc: logger.error(f"Anthropic API error: {exc}") return LLMResponse( content=f"Error: {exc}", finish_reason="error", ) async def chat_stream( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, ) -> AsyncGenerator[str, None]: """Stream chat completions.""" model = model or self.get_default_model() params = { "model": model, "max_tokens": max_tokens, "temperature": temperature, "messages": self._convert_messages(messages), "stream": True, } if tools: params["tools"] = self._convert_tools(tools) try: async with self.client.messages.stream(**params) as stream: async for text in stream.text_stream: yield text except Exception as exc: yield f"Error: {exc}" def get_default_model(self) -> str: return "claude-sonnet-4-20250514" # Provider factory class ProviderFactory: """Factory for creating LLM providers.""" _PROVIDERS = { "openai": OpenAIProvider, "anthropic": AnthropicProvider, } @classmethod def create( cls, provider: str, api_key: str | None = None, api_base: str | None = None, ) -> LLMProvider: """Create an LLM provider instance. Args: provider: Provider name (openai, anthropic) api_key: API key api_base: Optional base URL for API Returns: LLM provider instance """ provider_cls = cls._PROVIDERS.get(provider.lower()) if not provider_cls: raise ValueError(f"Unknown provider: {provider}. Available: {list(cls._PROVIDERS.keys())}") return provider_cls(api_key=api_key, api_base=api_base)