87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
|
|
# Per-token prices in US dollars, written as a price per million tokens divided
# down to a single token: 3 USD per 1M input tokens, 15 USD per 1M output tokens.
INPUT_TOKEN_USD_RATE = 3 / 1_000_000
OUTPUT_TOKEN_USD_RATE = 15 / 1_000_000

# Budget limits used when no overrides are supplied. "total_tokens" is a
# combined input + output token count; "estimated_cost" is in US dollars.
# Reaching either limit (>=) is what triggers a budget warning in this module.
DEFAULT_COST_THRESHOLDS = dict(
    total_tokens=4000,
    estimated_cost=0.02,
)
|
|
|
|
|
|
def estimate_token_cost(input_tokens: int, output_tokens: int) -> float | None:
    """Estimate the US-dollar cost of a request from its token counts.

    Negative counts are treated as zero. The input and output counts are
    priced separately using INPUT_TOKEN_USD_RATE and OUTPUT_TOKEN_USD_RATE.

    Args:
        input_tokens: Number of prompt/input tokens.
        output_tokens: Number of completion/output tokens.

    Returns:
        The cost rounded to 6 decimal places, or None when no tokens were
        used (both counts zero or negative).
    """
    billable_in = max(input_tokens, 0)
    billable_out = max(output_tokens, 0)

    # No usage at all is reported as "unknown" rather than a cost of 0.0.
    if billable_in + billable_out <= 0:
        return None

    cost = billable_in * INPUT_TOKEN_USD_RATE + billable_out * OUTPUT_TOKEN_USD_RATE
    return round(cost, 6)
|
|
|
|
|
|
def _to_token_count(value: Any) -> int:
    """Convert one raw usage value to an int token count; 0 if missing or unusable."""
    if not value:
        return 0
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # Malformed provider metadata (e.g. "n/a", a nested object, NaN or inf)
        # must not abort best-effort usage extraction; treat it as "not reported".
        return 0


def _first_token_count(usage: dict, keys: tuple[str, ...]) -> int:
    """Return the first non-zero token count found under *keys* (in order), else 0."""
    for key in keys:
        count = _to_token_count(usage.get(key))
        if count:
            return count
    return 0


def _read_usage_pair(
    usage: Any,
    input_keys: tuple[str, ...],
    output_keys: tuple[str, ...],
) -> tuple[int, int]:
    """Read an (input, output) token pair from *usage*; (0, 0) if it is not a dict."""
    if not isinstance(usage, dict):
        return 0, 0
    return _first_token_count(usage, input_keys), _first_token_count(usage, output_keys)


def extract_token_usage(response: Any) -> tuple[int, int]:
    """Best-effort extraction of (input_tokens, output_tokens) from a model response.

    Two locations are tried in order, and the first one reporting any non-zero
    count wins:

    1. ``response.usage_metadata``: a dict keyed by ``input_tokens`` /
       ``output_tokens``, falling back to ``prompt_tokens`` /
       ``completion_tokens``.
    2. ``response.response_metadata["token_usage"]`` (or ``["usage"]`` when
       that is missing or empty): a dict keyed by ``prompt_tokens`` /
       ``completion_tokens``, falling back to ``input_tokens`` /
       ``output_tokens``.

    Values that are missing, falsy or not convertible with ``int()`` count as 0.

    Args:
        response: Any object; the attributes above are read with ``getattr``, so
            objects without them are accepted.

    Returns:
        ``(input_tokens, output_tokens)``, or ``(0, 0)`` when no usage is found.
        Counts are passed through unclamped (negative values are not altered).
    """
    usage_metadata = getattr(response, "usage_metadata", None) or {}
    input_tokens, output_tokens = _read_usage_pair(
        usage_metadata,
        ("input_tokens", "prompt_tokens"),
        ("output_tokens", "completion_tokens"),
    )
    if input_tokens or output_tokens:
        return input_tokens, output_tokens

    response_metadata = getattr(response, "response_metadata", None) or {}
    token_usage: Any = {}
    if isinstance(response_metadata, dict):
        token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") or {}

    # Falls through to (0, 0) when this source reports nothing either.
    return _read_usage_pair(
        token_usage,
        ("prompt_tokens", "input_tokens"),
        ("completion_tokens", "output_tokens"),
    )
|
|
|
|
|
|
def coerce_cost_thresholds(raw_thresholds: Any) -> dict[str, float]:
    """Build a complete cost-threshold mapping from optional user overrides.

    Starts from DEFAULT_COST_THRESHOLDS. When *raw_thresholds* is a dict, each
    key that also appears in the defaults is overridden if its value is a
    positive ``int`` or ``float``. Unknown keys, non-numeric values, booleans
    and values that are not > 0 (including NaN) are ignored.

    Args:
        raw_thresholds: Candidate overrides; anything that is not a dict yields
            the defaults unchanged.

    Returns:
        A new dict containing every key of DEFAULT_COST_THRESHOLDS, with every
        value converted to ``float`` (matching the declared return type).
    """
    thresholds: dict[str, float] = {
        key: float(value) for key, value in DEFAULT_COST_THRESHOLDS.items()
    }
    if not isinstance(raw_thresholds, dict):
        return thresholds

    for key in DEFAULT_COST_THRESHOLDS:
        value = raw_thresholds.get(key)
        # bool is a subclass of int: a flag such as True must not be accepted
        # as a 1.0 threshold.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            continue
        if value > 0:
            thresholds[key] = float(value)
    return thresholds
|
|
|
|
|
|
def is_cost_budget_warning(
    input_tokens: int,
    output_tokens: int,
    estimated_cost: float | None,
    thresholds: dict[str, float] | None = None,
) -> bool:
    """Report whether a request's usage has reached a budget threshold.

    The limits come from *thresholds*, or from DEFAULT_COST_THRESHOLDS when
    *thresholds* is None or empty. A limit that is missing, None, or not
    positive disables its check.

    Args:
        input_tokens: Input token count; negative values are treated as 0.
        output_tokens: Output token count; negative values are treated as 0.
        estimated_cost: Estimated cost in the same unit as the
            ``"estimated_cost"`` limit. None skips the cost check.
        thresholds: Optional mapping with ``"total_tokens"`` and/or
            ``"estimated_cost"`` limits.

    Returns:
        True if the combined token count, or the estimated cost, is at or
        above its enabled limit; otherwise False.
    """
    used_tokens = max(input_tokens, 0) + max(output_tokens, 0)

    limits = thresholds if thresholds else DEFAULT_COST_THRESHOLDS
    token_limit = float(limits.get("total_tokens") or 0)
    cost_limit = float(limits.get("estimated_cost") or 0)

    if token_limit > 0 and used_tokens >= token_limit:
        return True
    if cost_limit <= 0 or estimated_cost is None:
        return False
    return estimated_cost >= cost_limit
|