Files
X-Agents/core/agents/providers/openai_provider.py

151 lines
4.7 KiB
Python
Raw Normal View History

"""OpenAI LLM provider implementation."""
import json
import secrets
import string
from typing import Any
import aiohttp
from loguru import logger
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
# Alphabet used for generated tool-call IDs: ASCII letters followed by digits.
_ALNUM = string.ascii_letters + string.digits


def _short_tool_id() -> str:
    """Return a random 9-character ID drawn from ASCII letters and digits.

    Characters are picked independently with ``secrets.choice``.
    """
    picked = [secrets.choice(_ALNUM) for _ in range(9)]
    return "".join(picked)
class OpenAIProvider(LLMProvider):
    """LLM provider that talks to the OpenAI chat-completions HTTP API.

    One ``aiohttp.ClientSession`` is created lazily on first use and reused
    by later calls; call :meth:`close` to release it.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "gpt-4o",
    ):
        """Configure the provider.

        Args:
            api_key: Passed to the base class; ``chat`` reads it back as
                ``self.api_key`` (presumably stored there by the base
                class) and sends it as a Bearer token when it is truthy.
            api_base: Passed to the base class; ``chat`` reads it back as
                ``self.api_base`` and falls back to
                ``https://api.openai.com/v1`` when it is falsy.
            default_model: Model name used when ``chat`` is called without
                an explicit ``model``.
        """
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self):
        """Close the HTTP session if one is open and drop the reference."""
        if self._session and not self._session.closed:
            await self._session.close()
        # Forget the closed session; _get_session() creates a new one on demand.
        self._session = None

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send a chat completion request to the OpenAI API.

        Args:
            messages: Chat messages; passed through
                ``self._sanitize_empty_content`` before being sent.
            tools: Optional tool definitions. When non-empty they are sent
                with ``tool_choice`` set to ``"auto"``.
            model: Model name; defaults to ``self.default_model``.
            max_tokens: Sent as ``max_tokens`` in the request payload.
            temperature: Sent as ``temperature`` in the request payload.

        Returns:
            The parsed response. A non-200 status, an ``aiohttp`` client
            error, or any other exception raised while sending the request
            or parsing the reply is logged and reported as an
            ``LLMResponse`` with ``finish_reason="error"`` instead of being
            raised.
        """
        model = model or self.default_model
        # Strip trailing slashes so a configured base such as
        # "https://host/v1/" does not produce ".../v1//chat/completions".
        api_base = (self.api_base or "https://api.openai.com/v1").rstrip("/")
        url = f"{api_base}/chat/completions"

        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # Sanitize messages
        messages = self._sanitize_empty_content(messages)

        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        try:
            session = await self._get_session()
            async with session.post(url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.warning(
                        "OpenAI API error (status {}): {}", resp.status, error_text
                    )
                    return LLMResponse(
                        content=f"OpenAI API error (status {resp.status}): {error_text}",
                        finish_reason="error",
                    )
                data = await resp.json()
                return self._parse_response(data)
        except aiohttp.ClientError as e:
            logger.warning("OpenAI API connection error: {}", e)
            return LLMResponse(
                content=f"OpenAI API connection error: {str(e)}",
                finish_reason="error",
            )
        except Exception as e:
            # Best-effort boundary: callers get an error response rather than
            # an exception, so record the traceback here or it is lost.
            logger.exception("Error calling OpenAI")
            return LLMResponse(
                content=f"Error calling OpenAI: {str(e)}",
                finish_reason="error",
            )

    @staticmethod
    def _parse_tool_arguments(raw_args: Any, tool_name: str) -> Any:
        """Decode the ``arguments`` field of one tool call.

        A string is parsed as JSON (``{}`` when it is not valid JSON); any
        other value is passed through unchanged. A missing or ``null``
        value, including a JSON string that decodes to ``null``, becomes
        ``{}``.
        """
        if isinstance(raw_args, str):
            try:
                raw_args = json.loads(raw_args)
            except json.JSONDecodeError:
                logger.warning(
                    "Tool call {!r} has arguments that are not valid JSON; "
                    "using empty arguments",
                    tool_name,
                )
                return {}
        return {} if raw_args is None else raw_args

    def _parse_response(self, data: dict[str, Any]) -> LLMResponse:
        """Parse an OpenAI chat-completions response into an ``LLMResponse``.

        Only the first entry of ``choices`` is used. Fields that are absent
        or explicitly ``null`` in the JSON (``choices``, ``message``,
        ``tool_calls``, ``usage``, ``finish_reason``, a tool call's ``id``,
        ``function`` or ``arguments``) are treated the same way, so a
        ``null`` value does not raise.

        Args:
            data: Decoded JSON body of the API response.

        Returns:
            An ``LLMResponse`` carrying the message content, the parsed
            tool calls, the finish reason (``"stop"`` when not given) and
            token usage counts (0 when not given). With no choices, an
            empty-content response with ``finish_reason="stop"``.
        """
        # ``dict.get(key, default)`` only applies the default when the key is
        # absent; ``or`` also covers an explicit JSON null.
        choices = data.get("choices") or []
        if not choices:
            return LLMResponse(content="", finish_reason="stop")

        choice = choices[0]
        message = choice.get("message") or {}
        content = message.get("content")
        finish_reason = choice.get("finish_reason") or "stop"

        # Parse tool calls
        tool_calls = []
        for tc in message.get("tool_calls") or []:
            func = tc.get("function") or {}
            name = func.get("name") or ""
            tool_calls.append(ToolCallRequest(
                # Generate an ID only when the API did not supply a usable one.
                id=tc.get("id") or _short_tool_id(),
                name=name,
                arguments=self._parse_tool_arguments(func.get("arguments"), name),
            ))

        # Parse usage
        usage = data.get("usage") or {}
        usage_dict = {
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        }

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            usage=usage_dict,
        )

    def get_default_model(self) -> str:
        """Return the model name used when ``chat`` is given no ``model``."""
        return self.default_model