feat: add system settings with model connectivity and encrypted storage
This commit is contained in:
216
server/src/app/services/model_connectivity.py
Normal file
216
server/src/app/services/model_connectivity.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from app.schemas.settings import ModelConnectivityTestRead, ModelConnectivityTestRequest
|
||||
|
||||
# Value sent as the ``api-version`` query parameter on the Azure OpenAI probe URLs.
AZURE_API_VERSION = "2024-10-21"
|
||||
# Timeout, in seconds, passed to ``urlopen`` for every probe request.
DEFAULT_TIMEOUT_SECONDS = 12
|
||||
|
||||
|
||||
class ConnectivityCheckError(Exception):
    """Failure raised while probing a model endpoint.

    The exception's string form is the human-readable message.
    ``status_code`` holds the HTTP status tied to the failure, or ``None``
    when no status is available.
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        self.status_code = status_code
        super().__init__(message)
|
||||
|
||||
|
||||
def probe_model_connectivity(payload: ModelConnectivityTestRequest) -> ModelConnectivityTestRead:
    """Run one connectivity probe for the model described by *payload*.

    The providers ``"Azure OpenAI"`` and ``"Ollama"`` use their dedicated
    probes; every other provider value is probed as an OpenAI-compatible API.
    A ``ConnectivityCheckError`` raised by the probe is not propagated: it is
    converted into a result with ``ok=False`` carrying the error's message
    and status code.
    """
    checked_at = datetime.now(timezone.utc)

    # Pick the probe first so the try block only wraps the network call.
    if payload.provider == "Azure OpenAI":
        run_probe = _probe_azure_openai
    elif payload.provider == "Ollama":
        run_probe = _probe_ollama
    else:
        run_probe = _probe_openai_compatible

    try:
        status_code = run_probe(payload)
    except ConnectivityCheckError as exc:
        return ModelConnectivityTestRead(
            ok=False,
            provider=payload.provider,
            model=payload.model,
            endpoint=payload.endpoint,
            capability=payload.capability,
            detail=str(exc),
            status_code=exc.status_code,
            checked_at=checked_at,
        )

    return ModelConnectivityTestRead(
        ok=True,
        provider=payload.provider,
        model=payload.model,
        endpoint=payload.endpoint,
        capability=payload.capability,
        detail=f"{payload.provider} 已连接,模型 {payload.model} 可正常访问。",
        status_code=status_code,
        checked_at=checked_at,
    )
|
||||
|
||||
|
||||
def _probe_openai_compatible(payload: ModelConnectivityTestRequest) -> int:
    """Probe an OpenAI-compatible API and return the HTTP status of the reply.

    An ``"embedding"`` capability POSTs a short input to ``embeddings``; any
    other capability POSTs a one-token "ping" chat request to
    ``chat/completions``. The API key, when present, is sent as a bearer
    token.
    """
    base_url = _normalize_endpoint(payload.endpoint)
    request_headers = _build_headers(api_key=payload.api_key, use_bearer=True)

    if payload.capability != "embedding":
        target = _ensure_path(base_url, "chat/completions")
        request_body = {
            "model": payload.model,
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        }
    else:
        target = _ensure_path(base_url, "embeddings")
        request_body = {"model": payload.model, "input": "connectivity test"}

    # Only the status matters for a connectivity check; the parsed body is dropped.
    return _send_json_request("POST", target, headers=request_headers, payload=request_body)[0]
|
||||
|
||||
|
||||
def _probe_ollama(payload: ModelConnectivityTestRequest) -> int:
    """Probe an Ollama server and return the HTTP status of the reply.

    An ``"embedding"`` capability POSTs a short input to ``api/embed``; any
    other capability POSTs a single non-streaming "ping" message to
    ``api/chat``.
    """
    base_url = _normalize_endpoint(payload.endpoint)
    request_headers = _build_headers(api_key=payload.api_key, use_bearer=False)

    if payload.capability != "embedding":
        target = _ensure_path(base_url, "api/chat")
        request_body = {
            "model": payload.model,
            "messages": [{"role": "user", "content": "ping"}],
            "stream": False,
        }
    else:
        target = _ensure_path(base_url, "api/embed")
        request_body = {"model": payload.model, "input": "connectivity test"}

    # Only the status matters for a connectivity check; the parsed body is dropped.
    return _send_json_request("POST", target, headers=request_headers, payload=request_body)[0]
|
||||
|
||||
|
||||
def _probe_azure_openai(payload: ModelConnectivityTestRequest) -> int:
    """Probe an Azure OpenAI deployment and return the HTTP status of the reply.

    The request goes to the deployment base built from the endpoint and model,
    followed by ``embeddings`` or ``chat/completions`` and the module's
    ``api-version`` query parameter. The API key, when present, is sent in the
    ``api-key`` header. The request body carries no ``model`` field.
    """
    deployment_url = _build_azure_deployment_base(payload.endpoint, payload.model)
    request_headers = _build_headers(api_key=payload.api_key, use_bearer=False, use_api_key=True)

    if payload.capability != "embedding":
        operation = "chat/completions"
        request_body = {
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        }
    else:
        operation = "embeddings"
        request_body = {"input": "connectivity test"}

    target = f"{deployment_url}/{operation}?api-version={AZURE_API_VERSION}"
    # Only the status matters for a connectivity check; the parsed body is dropped.
    return _send_json_request("POST", target, headers=request_headers, payload=request_body)[0]
|
||||
|
||||
|
||||
def _build_azure_deployment_base(endpoint: str, model: str) -> str:
    """Return the Azure OpenAI deployment base URL for *endpoint* and *model*.

    The result has the form ``<resource>/openai/deployments/<deployment>``
    with no operation path and no query string, so callers can append
    ``/chat/completions?api-version=...`` or similar.

    Accepted endpoint shapes:

    * already containing ``/openai/deployments/<name>``: the deployment named
      in the endpoint is kept, and anything after it (operation path, query
      string) is dropped so it is not duplicated by the caller;
    * ending in ``/openai/deployments``: *model* is appended;
    * containing ``/openai/v1``: everything from that segment on is replaced
      by ``/openai/deployments/<model>``;
    * ending in ``/openai``: ``/deployments/<model>`` is appended;
    * anything else: treated as the resource root.

    *model* is percent-encoded before being placed in the path.

    Raises:
        ConnectivityCheckError: propagated from ``_normalize_endpoint`` when
            the endpoint is blank.
    """
    normalized_endpoint = _normalize_endpoint(endpoint)
    quoted_model = quote(model, safe="")

    resource_root, marker, remainder = normalized_endpoint.partition("/openai/deployments/")
    if marker:
        # Keep only the deployment segment; a pasted full request URL would
        # otherwise get its operation path and query appended a second time.
        deployment = remainder.split("?", 1)[0].split("/", 1)[0]
        return f"{resource_root}{marker}{deployment or quoted_model}"

    if normalized_endpoint.endswith("/openai/deployments"):
        return f"{normalized_endpoint}/{quoted_model}"

    if "/openai/v1" in normalized_endpoint:
        resource_root = normalized_endpoint.split("/openai/v1", maxsplit=1)[0]
        return f"{resource_root}/openai/deployments/{quoted_model}"

    if normalized_endpoint.endswith("/openai"):
        return f"{normalized_endpoint}/deployments/{quoted_model}"

    return f"{normalized_endpoint}/openai/deployments/{quoted_model}"
|
||||
|
||||
|
||||
def _build_headers(
|
||||
api_key: str | None,
|
||||
*,
|
||||
use_bearer: bool,
|
||||
use_api_key: bool = False,
|
||||
) -> dict[str, str]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
if api_key:
|
||||
if use_api_key:
|
||||
headers["api-key"] = api_key
|
||||
elif use_bearer:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def _normalize_endpoint(endpoint: str) -> str:
|
||||
normalized = endpoint.strip()
|
||||
if not normalized:
|
||||
raise ConnectivityCheckError("接口地址不能为空。", status_code=HTTPStatus.BAD_REQUEST)
|
||||
return normalized.rstrip("/")
|
||||
|
||||
|
||||
def _ensure_path(endpoint: str, suffix: str) -> str:
|
||||
suffix = suffix.lstrip("/")
|
||||
if endpoint.endswith(suffix):
|
||||
return endpoint
|
||||
return f"{endpoint}/{suffix}"
|
||||
|
||||
|
||||
def _send_json_request(
    method: str,
    url: str,
    *,
    headers: dict[str, str],
    payload: dict[str, Any],
) -> tuple[int, Any]:
    """Send *payload* as a JSON request and return ``(status, parsed_body)``.

    Args:
        method: HTTP method to use.
        url: Absolute ``http://`` or ``https://`` URL to call.
        headers: Headers sent with the request.
        payload: Object serialised with ``json.dumps`` as the request body.

    Returns:
        The response status and the body as parsed by ``_parse_json_body``.

    Raises:
        ConnectivityCheckError: for a non-HTTP(S) URL (status 400), an HTTP
            error response (status taken from the response), a timeout, or
            any other connection or protocol failure.
    """
    from http.client import HTTPException

    # The URL comes from user-entered settings; urlopen would otherwise also
    # accept schemes such as file: or ftp:.
    if not url.lower().startswith(("http://", "https://")):
        raise ConnectivityCheckError(
            "接口地址必须以 http:// 或 https:// 开头。",
            status_code=HTTPStatus.BAD_REQUEST,
        )

    data = json.dumps(payload).encode("utf-8")

    try:
        # Request() itself raises ValueError for malformed URLs, so it is
        # built inside the try block.
        request = Request(url=url, data=data, headers=headers, method=method)
        with urlopen(request, timeout=DEFAULT_TIMEOUT_SECONDS) as response:
            # A non-UTF-8 success body must not fail an otherwise good probe.
            body = response.read().decode("utf-8", errors="replace")
            return response.status, _parse_json_body(body)
    except HTTPError as exc:
        body = exc.read().decode("utf-8", errors="ignore")
        message = _extract_error_message(_parse_json_body(body)) or f"模型接口返回 {exc.code}。"
        raise ConnectivityCheckError(message, status_code=exc.code) from exc
    except URLError as exc:
        reason = getattr(exc, "reason", exc)
        # Connect-phase timeouts arrive wrapped in URLError.
        if isinstance(reason, TimeoutError):
            raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
        raise ConnectivityCheckError(f"无法连接到模型接口:{reason}") from exc
    except TimeoutError as exc:
        raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
    except (OSError, HTTPException, ValueError) as exc:
        # Connection resets, malformed URLs or headers, truncated responses.
        raise ConnectivityCheckError(f"无法连接到模型接口:{exc}") from exc
|
||||
|
||||
|
||||
def _parse_json_body(body: str) -> Any:
|
||||
if not body:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
return {"message": body}
|
||||
|
||||
|
||||
def _extract_error_message(payload: Any) -> str | None:
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
if isinstance(payload.get("detail"), str):
|
||||
return payload["detail"]
|
||||
if isinstance(payload.get("message"), str):
|
||||
return payload["message"]
|
||||
error_payload = payload.get("error")
|
||||
if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str):
|
||||
return error_payload["message"]
|
||||
|
||||
if isinstance(payload, str):
|
||||
return payload
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user