"""OpenAI LLM provider implementation."""

import json
import secrets
import string
from typing import Any

import aiohttp
from loguru import logger

from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest

_ALNUM = string.ascii_letters + string.digits
_DEFAULT_API_BASE = "https://api.openai.com/v1"


def _short_tool_id() -> str:
    """Generate a 9-char alphanumeric ID for tool calls."""
    return "".join(secrets.choice(_ALNUM) for _ in range(9))


class OpenAIProvider(LLMProvider):
    """LLM provider that talks to the OpenAI Chat Completions HTTP API.

    One ``aiohttp.ClientSession`` is created lazily and reused across
    requests; call :meth:`close` to release it.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "gpt-4o",
    ):
        """Create a provider.

        Args:
            api_key: Sent as a ``Bearer`` token in the ``Authorization`` header;
                the header is omitted when this is falsy.
            api_base: Base URL for requests; ``https://api.openai.com/v1`` is
                used by :meth:`chat` when this is falsy.
            default_model: Model used when :meth:`chat` is called without one.
        """
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the HTTP session, if one is open."""
        if self._session is not None and not self._session.closed:
            await self._session.close()
        # Drop the reference so a later request lazily opens a fresh session.
        self._session = None

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send a chat completion request to the OpenAI API.

        Args:
            messages: Chat messages; passed through ``_sanitize_empty_content``
                before being sent.
            tools: Optional tool definitions; when non-empty, ``tool_choice``
                is set to ``"auto"``.
            model: Model name; defaults to ``self.default_model``.
            max_tokens: Value sent as ``max_tokens``.
            temperature: Value sent as ``temperature``.

        Returns:
            The parsed response. Failures are not raised: a non-200 status, an
            ``aiohttp.ClientError`` or any other exception during the request
            or parsing yields an ``LLMResponse`` with ``finish_reason="error"``
            and a description of the failure in ``content``.
        """
        model = model or self.default_model
        # Strip trailing slashes so a base like ".../v1/" doesn't yield "//chat".
        api_base = (self.api_base or _DEFAULT_API_BASE).rstrip("/")
        url = f"{api_base}/chat/completions"

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload: dict[str, Any] = {
            "model": model,
            "messages": self._sanitize_empty_content(messages),
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        try:
            session = await self._get_session()
            async with session.post(url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.warning(
                        "OpenAI API returned status {}: {}", resp.status, error_text
                    )
                    return LLMResponse(
                        content=f"OpenAI API error (status {resp.status}): {error_text}",
                        finish_reason="error",
                    )
                # content_type=None: decode the body as JSON even if the server
                # mislabels its Content-Type; otherwise aiohttp raises
                # ContentTypeError (a ClientError) and it would be misreported
                # below as a connection error.
                data = await resp.json(content_type=None)
            return self._parse_response(data)
        except aiohttp.ClientError as e:
            logger.warning("OpenAI API connection error: {}", e)
            return LLMResponse(
                content=f"OpenAI API connection error: {str(e)}",
                finish_reason="error",
            )
        except Exception as e:
            # Deliberate best-effort boundary: report as an error response,
            # but keep the traceback in the logs instead of discarding it.
            logger.exception("Unexpected error calling OpenAI")
            return LLMResponse(
                content=f"Error calling OpenAI: {str(e)}",
                finish_reason="error",
            )

    def _parse_response(self, data: dict[str, Any]) -> LLMResponse:
        """Convert an OpenAI chat-completion JSON body into an ``LLMResponse``.

        Only the first choice is used. Keys that are missing *or* explicitly
        ``null`` fall back to defaults:
        - no choices gives an empty ``"stop"`` response;
        - ``finish_reason`` defaults to ``"stop"``;
        - ``tool_calls`` defaults to an empty list;
        - each usage counter defaults to ``0``.
        """
        choices = data.get("choices") or []
        if not choices:
            return LLMResponse(content="", finish_reason="stop")

        choice = choices[0]
        message = choice.get("message") or {}

        # `or []` also covers an explicit "tool_calls": null.
        tool_calls = [
            self._parse_tool_call(tc) for tc in message.get("tool_calls") or []
        ]

        usage = data.get("usage") or {}
        usage_dict = {
            "prompt_tokens": usage.get("prompt_tokens") or 0,
            "completion_tokens": usage.get("completion_tokens") or 0,
            "total_tokens": usage.get("total_tokens") or 0,
        }

        return LLMResponse(
            content=message.get("content"),
            tool_calls=tool_calls,
            finish_reason=choice.get("finish_reason") or "stop",
            usage=usage_dict,
        )

    @staticmethod
    def _parse_tool_call(tc: dict[str, Any]) -> ToolCallRequest:
        """Build a ``ToolCallRequest`` from one raw ``tool_calls`` entry.

        ``function.arguments`` is decoded when it is a JSON string. Missing,
        empty, undecodable or non-object arguments become ``{}`` so one
        malformed call does not abort the whole response.
        """
        func = tc.get("function") or {}
        name = func.get("name") or ""
        raw_args = func.get("arguments")

        if isinstance(raw_args, str):
            try:
                args = json.loads(raw_args) if raw_args.strip() else {}
            except json.JSONDecodeError:
                logger.warning(
                    "Invalid JSON arguments for tool {!r}: {!r}", name, raw_args
                )
                args = {}
        else:
            args = raw_args

        # NOTE(review): assumes ToolCallRequest.arguments expects a dict.
        if not isinstance(args, dict):
            if args is not None:
                logger.warning(
                    "Non-object arguments for tool {!r}: {!r}", name, args
                )
            args = {}

        return ToolCallRequest(
            # Use `or` rather than a .get() default: a null id is replaced too,
            # and a random id is generated only when it is actually needed.
            id=tc.get("id") or _short_tool_id(),
            name=name,
            arguments=args,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without one."""
        return self.default_model