242 lines
8.7 KiB
Python
242 lines
8.7 KiB
Python
|
|
"""Anthropic LLM provider implementation."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import secrets
|
||
|
|
import string
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import aiohttp
|
||
|
|
from loguru import logger
|
||
|
|
|
||
|
|
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||
|
|
|
||
|
|
_ALNUM = string.ascii_letters + string.digits
|
||
|
|
|
||
|
|
|
||
|
|
def _short_tool_id() -> str:
|
||
|
|
"""Generate a 9-char alphanumeric ID for tool calls."""
|
||
|
|
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||
|
|
|
||
|
|
|
||
|
|
class AnthropicProvider(LLMProvider):
    """Anthropic LLM provider using the Claude Messages API over aiohttp.

    Conversations arrive as OpenAI-style message dicts (``role`` / ``content``
    plus optional ``tool_calls`` / ``tool_call_id``). They are translated into
    the Anthropic request shape, and API responses are translated back into
    ``LLMResponse`` / ``ToolCallRequest`` objects.
    """

    # Base URL used when no ``api_base`` was configured.
    _DEFAULT_API_BASE = "https://api.anthropic.com"
    # Value sent in the ``anthropic-version`` request header.
    _API_VERSION = "2023-06-01"
    # Stand-in for empty message text, so no empty content is ever sent.
    _EMPTY_PLACEHOLDER = "(empty)"

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "claude-sonnet-4-20250514",
    ):
        """Store credentials and defaults; the HTTP session is created lazily.

        Args:
            api_key: Value for the ``x-api-key`` header; the header is omitted
                when this is falsy.
            api_base: Base URL override; the public Anthropic endpoint is used
                when this is falsy.
            default_model: Model name used when ``chat`` gets no ``model``.
        """
        # NOTE(review): the base class is assumed to store these as
        # ``self.api_key`` / ``self.api_base`` (both are read in ``chat``).
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self):
        """Close the HTTP session, if one is open, and drop the reference."""
        session, self._session = self._session, None
        if session is not None and not session.closed:
            await session.close()

    def _content_to_text(self, content: Any) -> str:
        """Flatten message content into plain text.

        Strings are returned unchanged. Lists are flattened item by item:
        plain strings and ``text`` blocks contribute their text, ``tool_use``
        blocks a ``[tool_use: <name>]`` marker, and ``tool_result`` blocks
        their (recursively flattened) content; other items are ignored.
        Anything falsy yields ``""`` and any other value is passed to ``str``.
        """
        if not content:
            return ""
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            return str(content)

        parts: list[str] = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict):
                kind = item.get("type")
                if kind == "text":
                    parts.append(str(item.get("text") or ""))
                elif kind == "tool_use":
                    # Not expected in input, but keep a visible trace of it.
                    parts.append(f"[tool_use: {item.get('name')}]")
                elif kind == "tool_result":
                    # Nested content may itself be a string or a block list.
                    parts.append(self._content_to_text(item.get("content")))
        return "\n".join(parts)

    def _tool_call_to_block(self, tool_call: dict[str, Any]) -> dict[str, Any]:
        """Build an Anthropic ``tool_use`` block from a stored tool call.

        Accepts both the flat shape (``name`` / ``arguments`` at top level) and
        the OpenAI shape (nested under ``function``, with ``arguments``
        possibly encoded as a JSON string). A missing ID is replaced with a
        freshly generated one.
        """
        func = tool_call.get("function")
        if not isinstance(func, dict):
            func = tool_call

        arguments = func.get("arguments")
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments) if arguments.strip() else {}
            except json.JSONDecodeError:
                # Keep the undecodable payload rather than dropping it.
                arguments = {"raw": arguments}
        if not isinstance(arguments, dict):
            # ``input`` has to be a JSON object, so wrap any other value.
            arguments = {} if arguments is None else {"value": arguments}

        return {
            "type": "tool_use",
            "id": tool_call.get("id") or _short_tool_id(),
            "name": func.get("name") or "",
            "input": arguments,
        }

    def _convert_messages_to_anthropic(
        self,
        messages: list[dict[str, Any]],
        native_tools: bool = True,
    ) -> list[dict[str, Any]]:
        """Convert messages to Anthropic API format.

        System messages are skipped here (see ``_get_system_message``). Every
        returned entry has a ``role`` and non-empty ``content``.

        Args:
            messages: OpenAI-style message dicts.
            native_tools: When true, assistant tool calls become ``tool_use``
                blocks and ``tool`` messages become ``tool_result`` blocks in a
                user turn. When false, both are rendered as plain text, which
                is meant for requests that define no tools.

        Returns:
            The converted message list, in the original order.
        """
        converted: list[dict[str, Any]] = []
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")

            if role == "system":
                continue

            if role == "tool":
                result_text = self._content_to_text(content) or self._EMPTY_PLACEHOLDER
                if not native_tools:
                    converted.append({"role": "user", "content": f"Tool result: {result_text}"})
                    continue
                block = {
                    "type": "tool_result",
                    "tool_use_id": msg.get("tool_call_id") or _short_tool_id(),
                    "content": result_text,
                }
                previous = converted[-1] if converted else None
                # Only tool-result turns carry list content under the user
                # role, so results of parallel calls share one user turn.
                if (
                    previous is not None
                    and previous.get("role") == "user"
                    and isinstance(previous.get("content"), list)
                ):
                    previous["content"].append(block)
                else:
                    converted.append({"role": "user", "content": [block]})
                continue

            tool_calls = msg.get("tool_calls") if role == "assistant" else None
            if tool_calls:
                text = self._content_to_text(content)
                blocks = [self._tool_call_to_block(tc) for tc in tool_calls]
                if native_tools:
                    lead = [{"type": "text", "text": text}] if text else []
                    converted.append({"role": "assistant", "content": lead + blocks})
                else:
                    calls_text = "\n".join(
                        f"Tool call: {b['name']}({json.dumps(b['input'])})" for b in blocks
                    )
                    converted.append({
                        "role": "assistant",
                        "content": f"{text}\n\n{calls_text}" if text else calls_text,
                    })
                continue

            converted.append({
                "role": role,
                "content": self._content_to_text(content) or self._EMPTY_PLACEHOLDER,
            })

        return converted

    def _get_system_message(self, messages: list[dict[str, Any]]) -> str | None:
        """Extract the system prompt from messages.

        The text of every ``system`` message is joined with blank lines, so
        later system messages are not lost. Returns ``None`` when there is no
        non-empty system text.
        """
        parts = []
        for msg in messages:
            if msg.get("role") != "system":
                continue
            text = self._content_to_text(msg.get("content"))
            if text:
                parts.append(text)
        return "\n\n".join(parts) or None

    @staticmethod
    def _extract_error_message(error_text: str) -> str:
        """Return ``error.message`` from a JSON error body, else the raw text."""
        try:
            parsed = json.loads(error_text)
        except json.JSONDecodeError:
            return error_text
        if isinstance(parsed, dict):
            error = parsed.get("error")
            if isinstance(error, dict):
                return str(error.get("message") or error_text)
            if isinstance(error, str) and error:
                return error
        return error_text

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send a chat completion request to Anthropic API.

        Args:
            messages: OpenAI-style conversation, system messages included.
            tools: Optional tool definitions; converted with ``_convert_tools``.
            model: Model name; falls back to ``self.default_model``.
            max_tokens: Value sent as ``max_tokens``.
            temperature: Value sent as ``temperature``.

        Returns:
            The parsed response. HTTP, connection and unexpected errors raised
            while sending are not propagated; they are reported as an
            ``LLMResponse`` with ``finish_reason="error"`` and a descriptive
            ``content``.
        """
        model = model or self.default_model
        # Strip trailing slashes so a configured base never yields "//v1/...".
        api_base = (self.api_base or self._DEFAULT_API_BASE).rstrip("/")
        url = f"{api_base}/v1/messages"

        headers = {
            "Content-Type": "application/json",
            "anthropic-version": self._API_VERSION,
        }
        if self.api_key:
            headers["x-api-key"] = self.api_key

        # One truthiness check drives the payload, the message format and
        # response parsing, so an empty tool list behaves like no tools.
        has_tools = bool(tools)

        payload: dict[str, Any] = {
            "model": model,
            # Tool blocks are only emitted when the request defines tools
            # (the API is understood to require that - confirm against the
            # Messages API reference); otherwise tool traffic becomes text.
            "messages": self._convert_messages_to_anthropic(messages, native_tools=has_tools),
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        system = self._get_system_message(messages)
        if system:
            payload["system"] = system

        if tools:
            payload["tools"] = self._convert_tools(tools)

        try:
            session = await self._get_session()
            async with session.post(url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_msg = self._extract_error_message(await resp.text())
                    return LLMResponse(
                        content=f"Anthropic API error (status {resp.status}): {error_msg}",
                        finish_reason="error",
                    )

                data = await resp.json()
                return self._parse_response(data, has_tools)

        except aiohttp.ClientError as e:
            return LLMResponse(
                content=f"Anthropic API connection error: {str(e)}",
                finish_reason="error",
            )
        except Exception as e:
            # Top-level boundary: keep the traceback in the log, since the
            # caller only receives a short message.
            logger.exception("Unexpected error calling Anthropic API")
            # Some exceptions (e.g. timeouts) have an empty str().
            detail = str(e) or type(e).__name__
            return LLMResponse(
                content=f"Error calling Anthropic: {detail}",
                finish_reason="error",
            )

    def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert OpenAI-style tools to Anthropic format.

        Each tool's ``function`` entry supplies ``name``, ``description`` and
        ``parameters`` (sent as ``input_schema``). Tools without a ``function``
        dict are read from their top level, so definitions already in
        Anthropic shape pass through. A missing or empty schema is replaced by
        an empty object schema.
        """
        anthropic_tools = []
        for tool in tools:
            func = tool.get("function")
            if not isinstance(func, dict):
                func = tool
            schema = (
                func.get("parameters")
                or func.get("input_schema")
                or {"type": "object", "properties": {}}
            )
            anthropic_tools.append({
                "name": func.get("name", ""),
                "description": func.get("description", ""),
                "input_schema": schema,
            })
        return anthropic_tools

    def _parse_response(self, data: dict[str, Any], has_tools: bool = False) -> LLMResponse:
        """Parse Anthropic API response into our standard format.

        Args:
            data: Decoded JSON body of a successful response.
            has_tools: Whether the request carried tools; ``tool_use`` blocks
                are ignored when this is false.

        Returns:
            An ``LLMResponse`` with the concatenated text (``None`` if there is
            none), the tool calls, a mapped finish reason (``tool_calls``,
            ``length`` or ``stop``) and OpenAI-style usage counters.
        """
        text_parts: list[str] = []
        tool_calls = []
        # ``or []`` also covers an explicit null in the body.
        for block in data.get("content") or []:
            kind = block.get("type")
            if kind == "text":
                text_parts.append(block.get("text") or "")
            elif kind == "tool_use" and has_tools:
                tool_calls.append(ToolCallRequest(
                    id=block.get("id") or _short_tool_id(),
                    name=block.get("name", ""),
                    arguments=block.get("input") or {},
                ))
        text_content = "".join(text_parts)

        # Map Anthropic stop reasons onto the finish reasons used by callers;
        # anything unrecognised counts as a normal stop.
        stop_reason = data.get("stop_reason", "end_turn")
        finish_reason = {
            "tool_use": "tool_calls",
            "max_tokens": "length",
        }.get(stop_reason, "stop")

        usage = data.get("usage") or {}
        prompt_tokens = usage.get("input_tokens", 0) or 0
        completion_tokens = usage.get("output_tokens", 0) or 0
        usage_dict = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

        return LLMResponse(
            content=text_content if text_content else None,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            usage=usage_dict,
        )

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model
|