from __future__ import annotations

from typing import Any

# Per-token rates (USD, per the constant names); estimate_token_cost multiplies
# them by the input and output token counts respectively.
INPUT_TOKEN_USD_RATE = 0.000003
OUTPUT_TOKEN_USD_RATE = 0.000015

# Limits at or above which is_cost_budget_warning reports a warning.
DEFAULT_COST_THRESHOLDS = {
    "total_tokens": 4000,
    "estimated_cost": 0.02,
}


def estimate_token_cost(input_tokens: int, output_tokens: int) -> float | None:
    """Estimate the cost of a call from its token counts.

    Negative counts are clamped to zero before pricing.

    Args:
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.

    Returns:
        The estimated cost rounded to 6 decimal places, or ``None`` when the
        clamped total token count is zero.
    """
    billable_input = max(input_tokens, 0)
    billable_output = max(output_tokens, 0)
    if billable_input + billable_output <= 0:
        return None
    return round(
        (billable_input * INPUT_TOKEN_USD_RATE)
        + (billable_output * OUTPUT_TOKEN_USD_RATE),
        6,
    )


def _coerce_token_count(value: Any) -> int:
    """Convert a raw usage value to ``int``, treating unconvertible values as 0."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # Malformed metadata (e.g. "abc", nan, inf) must not crash usage
        # extraction; it is simply counted as "no usage reported".
        return 0


def _read_token_count(usage: dict[Any, Any], keys: tuple[str, ...]) -> int:
    """Return the first truthy value found under ``keys`` as an int, else 0."""
    for key in keys:
        raw = usage.get(key)
        if raw:
            return _coerce_token_count(raw)
    return 0


def extract_token_usage(response: Any) -> tuple[int, int]:
    """Extract ``(input_tokens, output_tokens)`` from a response object.

    Two attributes of ``response`` are consulted, in order:

    1. ``usage_metadata`` -- a dict read via ``input_tokens`` /
       ``prompt_tokens`` and ``output_tokens`` / ``completion_tokens``.
    2. ``response_metadata`` -- a dict whose ``token_usage`` (or ``usage``)
       entry is a dict read via ``prompt_tokens`` / ``input_tokens`` and
       ``completion_tokens`` / ``output_tokens``.

    A source is used only if it yields at least one non-zero count. Values
    that cannot be converted to ``int`` are treated as 0 instead of raising.

    Args:
        response: Any object; missing or non-dict metadata is ignored.

    Returns:
        The ``(input_tokens, output_tokens)`` pair, or ``(0, 0)`` when no
        usage information is found.
    """
    usage_metadata = getattr(response, "usage_metadata", None) or {}
    if isinstance(usage_metadata, dict):
        input_tokens = _read_token_count(
            usage_metadata, ("input_tokens", "prompt_tokens")
        )
        output_tokens = _read_token_count(
            usage_metadata, ("output_tokens", "completion_tokens")
        )
        if input_tokens or output_tokens:
            return input_tokens, output_tokens

    response_metadata = getattr(response, "response_metadata", None) or {}
    token_usage: Any = {}
    if isinstance(response_metadata, dict):
        token_usage = (
            response_metadata.get("token_usage")
            or response_metadata.get("usage")
            or {}
        )
    if isinstance(token_usage, dict):
        input_tokens = _read_token_count(
            token_usage, ("prompt_tokens", "input_tokens")
        )
        output_tokens = _read_token_count(
            token_usage, ("completion_tokens", "output_tokens")
        )
        if input_tokens or output_tokens:
            return input_tokens, output_tokens

    return 0, 0


def coerce_cost_thresholds(raw_thresholds: Any) -> dict[str, float]:
    """Build a complete thresholds dict from untrusted input.

    Starts from a copy of ``DEFAULT_COST_THRESHOLDS`` and overrides each known
    key only when ``raw_thresholds`` supplies a positive ``int`` or ``float``
    for it. Unknown keys are ignored.

    Args:
        raw_thresholds: Candidate overrides; anything that is not a dict
            yields the defaults.

    Returns:
        A new dict containing every key of ``DEFAULT_COST_THRESHOLDS``.
    """
    thresholds: dict[str, float] = dict(DEFAULT_COST_THRESHOLDS)
    if not isinstance(raw_thresholds, dict):
        return thresholds

    for key in DEFAULT_COST_THRESHOLDS:
        value = raw_thresholds.get(key)
        # bool is a subclass of int; without this guard ``True`` would be
        # accepted as a threshold of 1.0.
        if isinstance(value, bool):
            continue
        if isinstance(value, (int, float)) and value > 0:
            thresholds[key] = float(value)
    return thresholds


def is_cost_budget_warning(
    input_tokens: int,
    output_tokens: int,
    estimated_cost: float | None,
    thresholds: dict[str, float] | None = None,
) -> bool:
    """Report whether token usage or estimated cost reaches a threshold.

    Args:
        input_tokens: Input token count; negatives are clamped to zero.
        output_tokens: Output token count; negatives are clamped to zero.
        estimated_cost: Estimated cost, or ``None`` when unknown (the cost
            check is then skipped).
        thresholds: Optional limits keyed by ``"total_tokens"`` and
            ``"estimated_cost"``. ``None`` or an empty dict selects
            ``DEFAULT_COST_THRESHOLDS``. A key that is missing or falsy in a
            non-empty dict disables the corresponding check.

    Returns:
        ``True`` when the total token count or the estimated cost is greater
        than or equal to its (positive) threshold.
    """
    effective_thresholds = thresholds or DEFAULT_COST_THRESHOLDS
    total_tokens = max(input_tokens, 0) + max(output_tokens, 0)
    token_threshold = float(effective_thresholds.get("total_tokens") or 0)
    cost_threshold = float(effective_thresholds.get("estimated_cost") or 0)

    if token_threshold > 0 and total_tokens >= token_threshold:
        return True
    return (
        cost_threshold > 0
        and estimated_cost is not None
        and estimated_cost >= cost_threshold
    )