217 lines
7.1 KiB
Python
217 lines
7.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from http import HTTPStatus
|
|
from typing import Any
|
|
from urllib.error import HTTPError, URLError
|
|
from urllib.parse import quote
|
|
from urllib.request import Request, urlopen
|
|
|
|
from app.schemas.settings import ModelConnectivityTestRead, ModelConnectivityTestRequest
|
|
|
|
# Value sent as the ``api-version`` query parameter on the Azure OpenAI probe URLs.
AZURE_API_VERSION = "2024-10-21"
|
|
# Timeout, in seconds, passed to ``urlopen`` for every probe request.
DEFAULT_TIMEOUT_SECONDS = 12
|
|
|
|
|
|
class ConnectivityCheckError(Exception):
    """Failure raised by the connectivity probe helpers in this module.

    The message becomes the exception text; ``status_code`` carries the HTTP
    status associated with the failure, or ``None`` when there is none.
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        """Record *status_code* on the instance and pass *message* to ``Exception``."""
        self.status_code = status_code
        super().__init__(message)
|
|
|
|
|
|
def probe_model_connectivity(payload: ModelConnectivityTestRequest) -> ModelConnectivityTestRead:
    """Run a connectivity probe for the provider named in *payload*.

    ``"Azure OpenAI"`` and ``"Ollama"`` are routed to their dedicated probes;
    every other provider value is probed as an OpenAI-compatible endpoint.

    Returns:
        A ``ModelConnectivityTestRead`` echoing the request's provider, model,
        endpoint and capability. ``ok`` is ``True`` with the probe's HTTP
        status on success; on ``ConnectivityCheckError`` it is ``False`` with
        the error text as ``detail`` and the error's ``status_code``.

    Only ``ConnectivityCheckError`` is converted into a result; any other
    exception raised by a probe propagates to the caller.
    """
    # Timestamp is taken before the request is sent, and shared by both outcomes.
    checked_at = datetime.now(timezone.utc)

    # Fields that are identical in the success and the failure result.
    echoed = {
        "provider": payload.provider,
        "model": payload.model,
        "endpoint": payload.endpoint,
        "capability": payload.capability,
        "checked_at": checked_at,
    }

    if payload.provider == "Azure OpenAI":
        run_probe = _probe_azure_openai
    elif payload.provider == "Ollama":
        run_probe = _probe_ollama
    else:
        run_probe = _probe_openai_compatible

    try:
        status_code = run_probe(payload)
    except ConnectivityCheckError as exc:
        return ModelConnectivityTestRead(
            ok=False,
            detail=str(exc),
            status_code=exc.status_code,
            **echoed,
        )

    return ModelConnectivityTestRead(
        ok=True,
        detail=f"{payload.provider} 已连接,模型 {payload.model} 可正常访问。",
        status_code=status_code,
        **echoed,
    )
|
|
|
|
|
|
def _probe_openai_compatible(payload: ModelConnectivityTestRequest) -> int:
    """POST a minimal request to an OpenAI-compatible endpoint.

    For the ``"embedding"`` capability the request goes to ``embeddings``;
    for anything else a one-token chat request goes to ``chat/completions``.
    The API key, when present, is sent as a ``Bearer`` token.

    Returns:
        The HTTP status code reported by ``_send_json_request``.

    Raises:
        ConnectivityCheckError: Propagated from the endpoint normalisation or
            from the request helper.
    """
    base_url = _normalize_endpoint(payload.endpoint)
    auth_headers = _build_headers(api_key=payload.api_key, use_bearer=True)

    if payload.capability == "embedding":
        target_url = _ensure_path(base_url, "embeddings")
        request_body = {"model": payload.model, "input": "connectivity test"}
    else:
        target_url = _ensure_path(base_url, "chat/completions")
        request_body = {
            "model": payload.model,
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        }

    # Only the status matters here; the parsed response body is discarded.
    status_code, _ = _send_json_request(
        "POST",
        target_url,
        headers=auth_headers,
        payload=request_body,
    )
    return status_code
|
|
|
|
|
|
def _probe_ollama(payload: ModelConnectivityTestRequest) -> int:
    """POST a minimal request to an Ollama endpoint.

    For the ``"embedding"`` capability the request goes to ``api/embed``;
    for anything else a non-streaming chat request goes to ``api/chat``.

    Returns:
        The HTTP status code reported by ``_send_json_request``.

    Raises:
        ConnectivityCheckError: Propagated from the endpoint normalisation or
            from the request helper.
    """
    normalized_endpoint = _normalize_endpoint(payload.endpoint)
    # Neither bearer nor api-key mode is requested, so _build_headers adds no
    # credential header even when an API key is configured.
    headers = _build_headers(api_key=payload.api_key, use_bearer=False)

    if payload.capability == "embedding":
        url = _ensure_path(normalized_endpoint, "api/embed")
        body = {"model": payload.model, "input": "connectivity test"}
    else:
        url = _ensure_path(normalized_endpoint, "api/chat")
        body = {
            "model": payload.model,
            "messages": [{"role": "user", "content": "ping"}],
            "stream": False,
            # Cap generation at one token, mirroring ``max_tokens: 1`` in the
            # other probes. Without a cap the probe waits for a complete
            # answer and a slow model can exceed the request timeout.
            "options": {"num_predict": 1},
        }

    status_code, _ = _send_json_request("POST", url, headers=headers, payload=body)
    return status_code
|
|
|
|
|
|
def _probe_azure_openai(payload: ModelConnectivityTestRequest) -> int:
    """POST a minimal request to an Azure OpenAI deployment.

    The deployment URL is derived from the endpoint and model by
    ``_build_azure_deployment_base``; the request body carries no ``model``
    field. The API key, when present, is sent in the ``api-key`` header.

    Returns:
        The HTTP status code reported by ``_send_json_request``.

    Raises:
        ConnectivityCheckError: Propagated from the endpoint normalisation or
            from the request helper.
    """
    base = _build_azure_deployment_base(payload.endpoint, payload.model)
    auth_headers = _build_headers(
        api_key=payload.api_key,
        use_bearer=False,
        use_api_key=True,
    )
    version_query = f"?api-version={AZURE_API_VERSION}"

    if payload.capability == "embedding":
        target_url = f"{base}/embeddings{version_query}"
        request_body = {"input": "connectivity test"}
    else:
        target_url = f"{base}/chat/completions{version_query}"
        request_body = {
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        }

    # Only the status matters here; the parsed response body is discarded.
    status_code, _ = _send_json_request(
        "POST",
        target_url,
        headers=auth_headers,
        payload=request_body,
    )
    return status_code
|
|
|
|
|
|
def _build_azure_deployment_base(endpoint: str, model: str) -> str:
    """Return the ``.../openai/deployments/<model>`` base URL for *endpoint*.

    Handled endpoint shapes, checked in this order:

    * already contains ``/openai/deployments/``: returned as normalised;
    * contains ``/openai/v1``: everything from that marker on is replaced by
      ``/openai/deployments/<model>``;
    * ends with ``/openai``: ``/deployments/<model>`` is appended;
    * anything else: ``/openai/deployments/<model>`` is appended.

    *model* is percent-encoded with no characters treated as safe.

    Raises:
        ConnectivityCheckError: Propagated from ``_normalize_endpoint``.
    """
    base = _normalize_endpoint(endpoint)
    deployment = quote(model, safe="")

    if "/openai/deployments/" in base:
        return base

    # partition() yields an empty separator when the marker is absent.
    resource_root, v1_marker, _ = base.partition("/openai/v1")
    if v1_marker:
        return f"{resource_root}/openai/deployments/{deployment}"

    openai_root = base if base.endswith("/openai") else f"{base}/openai"
    return f"{openai_root}/deployments/{deployment}"
|
|
|
|
|
|
def _build_headers(
|
|
api_key: str | None,
|
|
*,
|
|
use_bearer: bool,
|
|
use_api_key: bool = False,
|
|
) -> dict[str, str]:
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"Accept": "application/json",
|
|
}
|
|
|
|
if api_key:
|
|
if use_api_key:
|
|
headers["api-key"] = api_key
|
|
elif use_bearer:
|
|
headers["Authorization"] = f"Bearer {api_key}"
|
|
|
|
return headers
|
|
|
|
|
|
def _normalize_endpoint(endpoint: str) -> str:
|
|
normalized = endpoint.strip()
|
|
if not normalized:
|
|
raise ConnectivityCheckError("接口地址不能为空。", status_code=HTTPStatus.BAD_REQUEST)
|
|
return normalized.rstrip("/")
|
|
|
|
|
|
def _ensure_path(endpoint: str, suffix: str) -> str:
|
|
suffix = suffix.lstrip("/")
|
|
if endpoint.endswith(suffix):
|
|
return endpoint
|
|
return f"{endpoint}/{suffix}"
|
|
|
|
|
|
def _send_json_request(
    method: str,
    url: str,
    *,
    headers: dict[str, str],
    payload: dict[str, Any],
) -> tuple[int, Any]:
    """Send *payload* as a JSON body to *url* and return the parsed response.

    Args:
        method: HTTP method for the request.
        url: Absolute ``http://`` or ``https://`` URL.
        headers: Request headers.
        payload: Object serialised with ``json.dumps`` as the request body.

    Returns:
        ``(status, body)`` where *body* is the result of ``_parse_json_body``
        applied to the response text.

    Raises:
        ConnectivityCheckError: For every transport-level failure: a rejected
            URL (status 400), an HTTP error response (status set to the
            response code), a timeout, or any other connection/protocol error.
    """
    # Imported locally so the block does not depend on a new file-level import.
    from http.client import HTTPException

    # The URL is built from user-supplied settings. urllib would otherwise
    # also accept schemes such as file://, so restrict it to HTTP(S).
    if not url.lower().startswith(("http://", "https://")):
        raise ConnectivityCheckError(
            "接口地址必须以 http:// 或 https:// 开头。",
            status_code=HTTPStatus.BAD_REQUEST,
        )

    data = json.dumps(payload).encode("utf-8")

    try:
        # Request() itself raises ValueError for a malformed URL, so it has to
        # live inside the try block together with urlopen().
        request = Request(url=url, data=data, headers=headers, method=method)
        with urlopen(request, timeout=DEFAULT_TIMEOUT_SECONDS) as response:
            status = response.status
            raw_body = response.read()
    except HTTPError as exc:
        body = exc.read().decode("utf-8", errors="ignore")
        message = _extract_error_message(_parse_json_body(body)) or f"模型接口返回 {exc.code}。"
        raise ConnectivityCheckError(message, status_code=exc.code) from exc
    except URLError as exc:
        reason = getattr(exc, "reason", exc)
        # Timeouts while connecting arrive wrapped in URLError; report them
        # with the dedicated timeout message.
        if isinstance(reason, TimeoutError):
            raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
        raise ConnectivityCheckError(f"无法连接到模型接口:{reason}") from exc
    except TimeoutError as exc:
        raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
    except (OSError, HTTPException, ValueError) as exc:
        # Errors urlopen does not wrap: connection reset while reading the
        # response, malformed status line, invalid URL characters, etc.
        raise ConnectivityCheckError(f"无法连接到模型接口:{exc}") from exc

    # Decode leniently, like the error path, so a non-UTF-8 body cannot turn a
    # successful probe into a crash.
    return status, _parse_json_body(raw_body.decode("utf-8", errors="ignore"))
|
|
|
|
|
|
def _parse_json_body(body: str) -> Any:
|
|
if not body:
|
|
return None
|
|
|
|
try:
|
|
return json.loads(body)
|
|
except json.JSONDecodeError:
|
|
return {"message": body}
|
|
|
|
|
|
def _extract_error_message(payload: Any) -> str | None:
|
|
if payload is None:
|
|
return None
|
|
|
|
if isinstance(payload, dict):
|
|
if isinstance(payload.get("detail"), str):
|
|
return payload["detail"]
|
|
if isinstance(payload.get("message"), str):
|
|
return payload["message"]
|
|
error_payload = payload.get("error")
|
|
if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str):
|
|
return error_payload["message"]
|
|
|
|
if isinstance(payload, str):
|
|
return payload
|
|
|
|
return None
|