182 lines
5.5 KiB
Python
182 lines
5.5 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import tempfile
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class HermesModelRoute:
|
|
model: str
|
|
endpoint: str
|
|
api_key: str = ""
|
|
provider_label: str = ""
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class HermesConfigSnapshot:
|
|
path: Path
|
|
existed: bool
|
|
content: bytes
|
|
|
|
|
|
def get_hermes_home() -> Path:
    """Return the Hermes home directory.

    Uses the ``HERMES_HOME`` environment variable (whitespace-stripped, with
    ``~`` expanded) when it is set to a non-blank value; otherwise falls back
    to ``~/.hermes``.
    """
    override = os.environ.get("HERMES_HOME", "").strip()
    return Path(override).expanduser() if override else Path.home() / ".hermes"
|
|
|
|
|
|
def get_hermes_config_path() -> Path:
    """Return the path of ``config.yaml`` inside the Hermes home directory."""
    hermes_home = get_hermes_home()
    return hermes_home.joinpath("config.yaml")
|
|
|
|
|
|
def capture_hermes_config_snapshot(config_path: Path | None = None) -> HermesConfigSnapshot:
    """Record the current bytes of a Hermes config file so it can be restored later.

    Args:
        config_path: File to snapshot; defaults to ``get_hermes_config_path()``.

    Returns:
        ``existed=False`` with empty content when the file is absent; otherwise
        ``existed=True`` with the file's raw bytes.

    Raises:
        OSError: Any read error other than the file having disappeared
            (e.g. permission denied, path is a directory).
    """
    target_path = config_path or get_hermes_config_path()
    if target_path.exists():
        try:
            content = target_path.read_bytes()
        except FileNotFoundError:
            # The file was removed between exists() and the read; treat it as
            # absent instead of propagating the race as an error.
            pass
        else:
            return HermesConfigSnapshot(path=target_path, existed=True, content=content)
    return HermesConfigSnapshot(path=target_path, existed=False, content=b"")
|
|
|
|
|
|
def restore_hermes_config_snapshot(snapshot: HermesConfigSnapshot) -> None:
    """Return the config file to the state recorded in *snapshot*.

    The parent directory of ``snapshot.path`` is created if missing. If the
    file existed at capture time, its bytes are written back verbatim;
    otherwise any file now at ``snapshot.path`` is deleted.

    Args:
        snapshot: A snapshot produced by ``capture_hermes_config_snapshot``.
    """
    snapshot.path.parent.mkdir(parents=True, exist_ok=True)

    if snapshot.existed:
        snapshot.path.write_bytes(snapshot.content)
        return

    if snapshot.path.exists():
        # missing_ok closes the exists()/unlink() race: if something else
        # deletes the file in between, there is nothing left to restore.
        snapshot.path.unlink(missing_ok=True)
|
|
|
|
|
|
def sync_hermes_model_settings(
    primary_route: HermesModelRoute,
    fallback_route: HermesModelRoute | None = None,
    config_path: Path | None = None,
) -> Path:
    """Write the primary (and optional fallback) model routes into the Hermes config.

    The existing config is loaded, its ``model`` section is rebuilt from
    *primary_route*, and ``fallback_model`` is rebuilt from *fallback_route*
    or removed when no fallback is given. Other top-level keys are left as
    they were. The result is written back atomically.

    Args:
        primary_route: Route written to the ``model`` section.
        fallback_route: Route written to ``fallback_model``; ``None`` removes it.
        config_path: Config file to update; defaults to ``get_hermes_config_path()``.

    Returns:
        The path of the config file that was written.

    Raises:
        ValueError: If a route has a blank model/endpoint or the existing
            config is not a YAML mapping.
    """
    destination = config_path or get_hermes_config_path()
    destination.parent.mkdir(parents=True, exist_ok=True)

    settings = _load_existing_config(destination)
    settings["model"] = _build_primary_model_config(
        primary_route,
        existing_model_config=settings.get("model"),
    )

    if fallback_route is not None:
        settings["fallback_model"] = _build_fallback_model_config(
            fallback_route,
            existing_fallback_config=settings.get("fallback_model"),
        )
    else:
        settings.pop("fallback_model", None)

    _atomic_yaml_write(destination, settings)
    return destination
|
|
|
|
|
|
def _load_existing_config(config_path: Path) -> dict[str, Any]:
|
|
if not config_path.exists():
|
|
return {}
|
|
|
|
raw_content = config_path.read_text(encoding="utf-8")
|
|
if not raw_content.strip():
|
|
return {}
|
|
|
|
loaded = yaml.safe_load(raw_content)
|
|
if loaded is None:
|
|
return {}
|
|
if not isinstance(loaded, dict):
|
|
raise ValueError(f"Hermes 配置文件格式无效: {config_path}")
|
|
return dict(loaded)
|
|
|
|
|
|
def _build_primary_model_config(
    route: HermesModelRoute,
    *,
    existing_model_config: Any = None,
) -> dict[str, Any]:
    """Build the ``model`` section of the Hermes config for *route*.

    Args:
        route: The primary model route.
        existing_model_config: The current ``model`` section (any YAML value);
            its ``api_key`` is reused when *route* carries none.

    Returns:
        A new mapping with ``provider`` ("custom"), ``default`` (model name),
        ``base_url`` (endpoint without trailing "/"), plus ``api_key`` when one
        is available and ``api_mode`` when it is not "chat_completions".

    Raises:
        ValueError: If the model name or endpoint is blank after stripping.
    """
    normalized_model = route.model.strip()
    normalized_endpoint = route.endpoint.strip().rstrip("/")
    if not normalized_model or not normalized_endpoint:
        raise ValueError("Hermes 主模型同步失败:模型名称或接口地址为空。")

    payload: dict[str, Any] = {
        "provider": "custom",
        "default": normalized_model,
        "base_url": normalized_endpoint,
    }

    # A key on the route wins; otherwise keep the key already stored so that
    # re-syncing without a key does not wipe it.
    # NOTE(review): the stored key is reused even if base_url changed, so it may
    # be sent to a different endpoint; confirm this is intended.
    api_key = route.api_key.strip() or _extract_existing_api_key(existing_model_config)
    if api_key:
        payload["api_key"] = api_key

    # api_mode is only written when it differs from "chat_completions"
    # (presumably Hermes' default; confirm against Hermes docs).
    api_mode = _infer_api_mode(route)
    if api_mode != "chat_completions":
        payload["api_mode"] = api_mode

    return payload
|
|
|
|
|
|
def _build_fallback_model_config(
    route: HermesModelRoute,
    *,
    existing_fallback_config: Any = None,
) -> dict[str, Any]:
    """Build the ``fallback_model`` section of the Hermes config for *route*.

    Args:
        route: The fallback (backup) model route.
        existing_fallback_config: The current ``fallback_model`` section (any
            YAML value); its ``api_key`` is reused when *route* carries none.

    Returns:
        A new mapping with ``provider`` ("custom"), ``model`` (model name),
        ``base_url`` (endpoint without trailing "/"), plus ``api_key`` when one
        is available and ``api_mode`` when it is not "chat_completions".

    Raises:
        ValueError: If the model name or endpoint is blank after stripping.
    """
    normalized_model = route.model.strip()
    normalized_endpoint = route.endpoint.strip().rstrip("/")
    if not normalized_model or not normalized_endpoint:
        raise ValueError("Hermes 备份模型同步失败:模型名称或接口地址为空。")

    payload: dict[str, Any] = {
        "provider": "custom",
        "model": normalized_model,
        "base_url": normalized_endpoint,
    }

    # A key on the route wins; otherwise keep the key already stored so that
    # re-syncing without a key does not wipe it.
    # NOTE(review): the stored key is reused even if base_url changed, so it may
    # be sent to a different endpoint; confirm this is intended.
    api_key = route.api_key.strip() or _extract_existing_api_key(existing_fallback_config)
    if api_key:
        payload["api_key"] = api_key

    # api_mode is only written when it differs from "chat_completions"
    # (presumably Hermes' default; confirm against Hermes docs).
    api_mode = _infer_api_mode(route)
    if api_mode != "chat_completions":
        payload["api_mode"] = api_mode

    return payload
|
|
|
|
|
|
def _extract_existing_api_key(config_section: Any) -> str:
|
|
if not isinstance(config_section, dict):
|
|
return ""
|
|
api_key = config_section.get("api_key")
|
|
if not isinstance(api_key, str):
|
|
return ""
|
|
return api_key.strip()
|
|
|
|
|
|
def _infer_api_mode(route: HermesModelRoute) -> str:
|
|
provider_label = route.provider_label.strip().casefold()
|
|
endpoint = route.endpoint.strip().lower().rstrip("/")
|
|
|
|
if provider_label == "claude" or "anthropic.com" in endpoint or endpoint.endswith("/anthropic"):
|
|
return "anthropic_messages"
|
|
if "api.openai.com" in endpoint or "api.x.ai" in endpoint:
|
|
return "codex_responses"
|
|
return "chat_completions"
|
|
|
|
|
|
def _atomic_yaml_write(target_path: Path, payload: dict[str, Any]) -> None:
    """Serialize *payload* as YAML and replace *target_path* with it atomically.

    The YAML (via ``yaml.safe_dump``, insertion order kept, non-ASCII written
    as-is) goes to a hidden temporary file in the same directory. That file is
    flushed and fsynced, then moved over *target_path* with ``Path.replace``,
    which is an atomic rename on POSIX. Readers therefore see either the old
    file or the complete new one.

    Args:
        target_path: Destination file; its parent directory must already exist.
        payload: Mapping to serialize.

    Raises:
        Errors from serialization or file I/O propagate. On failure the
        temporary file is removed rather than left behind.
    """
    serialized = yaml.safe_dump(payload, sort_keys=False, allow_unicode=True)

    temp_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=str(target_path.parent),
            prefix=f".{target_path.name}.",
            delete=False,
        ) as temp_file:
            temp_path = Path(temp_file.name)
            temp_file.write(serialized)
            # Force the data to disk before the rename; otherwise a crash right
            # after replace() can leave an empty or truncated config.
            temp_file.flush()
            os.fsync(temp_file.fileno())
        temp_path.replace(target_path)
    except BaseException:
        # delete=False means nothing else cleans the temp file up on failure.
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
        raise
|