Files
X-Financial/server/src/app/services/model_connectivity.py

217 lines
7.1 KiB
Python
Raw Normal View History

from __future__ import annotations
import json
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import quote
from urllib.request import Request, urlopen
from app.schemas.settings import ModelConnectivityTestRead, ModelConnectivityTestRequest
# Value sent as the ``api-version`` query parameter on Azure OpenAI probe requests.
AZURE_API_VERSION = "2024-10-21"
# Timeout, in seconds, passed to ``urlopen`` for every probe request.
DEFAULT_TIMEOUT_SECONDS = 12
class ConnectivityCheckError(Exception):
    """Failure raised while probing a model endpoint.

    The message is stored through ``Exception`` (so ``str(exc)`` yields it) and
    an optional status code is kept on ``status_code``.
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        # ``None`` means no status code is associated with the failure.
        self.status_code = status_code
        super().__init__(message)
def probe_model_connectivity(payload: ModelConnectivityTestRequest) -> ModelConnectivityTestRead:
    """Probe the configured model endpoint and report the outcome.

    The provider name selects the probe: ``"Azure OpenAI"`` and ``"Ollama"``
    have dedicated probes, every other provider is treated as
    OpenAI-compatible. A ``ConnectivityCheckError`` raised by the probe is
    converted into an ``ok=False`` result instead of propagating.

    Args:
        payload: The connectivity test request (provider, model, endpoint,
            capability and optional API key).

    Returns:
        A ``ModelConnectivityTestRead`` echoing the request fields together
        with the outcome, a detail message, the status code and the time the
        check started (UTC).
    """
    checked_at = datetime.now(timezone.utc)
    # Fields that are copied into the result regardless of the outcome.
    echoed_fields = {
        "provider": payload.provider,
        "model": payload.model,
        "endpoint": payload.endpoint,
        "capability": payload.capability,
        "checked_at": checked_at,
    }
    try:
        if payload.provider == "Azure OpenAI":
            run_probe = _probe_azure_openai
        elif payload.provider == "Ollama":
            run_probe = _probe_ollama
        else:
            run_probe = _probe_openai_compatible
        status_code = run_probe(payload)
        success_detail = f"{payload.provider} 已连接,模型 {payload.model} 可正常访问。"
        return ModelConnectivityTestRead(
            ok=True,
            detail=success_detail,
            status_code=status_code,
            **echoed_fields,
        )
    except ConnectivityCheckError as failure:
        return ModelConnectivityTestRead(
            ok=False,
            detail=str(failure),
            status_code=failure.status_code,
            **echoed_fields,
        )
def _probe_openai_compatible(payload: ModelConnectivityTestRequest) -> int:
    """Send a minimal request to an OpenAI-style endpoint.

    Posts to ``embeddings`` when the requested capability is ``"embedding"``,
    otherwise to ``chat/completions`` with a one-token completion. The API key,
    if any, is sent as a bearer token.

    Returns:
        The HTTP status code of the response.

    Raises:
        ConnectivityCheckError: Propagated from the endpoint normalisation or
            from the request helper.
    """
    base_url = _normalize_endpoint(payload.endpoint)
    request_headers = _build_headers(api_key=payload.api_key, use_bearer=True)
    if payload.capability == "embedding":
        target_url = _ensure_path(base_url, "embeddings")
        request_body = {"model": payload.model, "input": "connectivity test"}
    else:
        target_url = _ensure_path(base_url, "chat/completions")
        request_body = {
            "model": payload.model,
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        }
    response = _send_json_request(
        "POST",
        target_url,
        headers=request_headers,
        payload=request_body,
    )
    # Only the status code is of interest; the parsed body is discarded.
    return response[0]
def _probe_ollama(payload: ModelConnectivityTestRequest) -> int:
    """Send a minimal request to an Ollama endpoint.

    Posts to ``api/embed`` when the requested capability is ``"embedding"``,
    otherwise to ``api/chat`` with streaming disabled.

    Returns:
        The HTTP status code of the response.

    Raises:
        ConnectivityCheckError: Propagated from the endpoint normalisation or
            from the request helper.
    """
    base_url = _normalize_endpoint(payload.endpoint)
    # NOTE(review): with use_bearer=False (and use_api_key left at its default)
    # the API key is never added to the headers - confirm this is intentional.
    request_headers = _build_headers(api_key=payload.api_key, use_bearer=False)
    if payload.capability == "embedding":
        target_url = _ensure_path(base_url, "api/embed")
        request_body = {"model": payload.model, "input": "connectivity test"}
    else:
        target_url = _ensure_path(base_url, "api/chat")
        request_body = {
            "model": payload.model,
            "messages": [{"role": "user", "content": "ping"}],
            "stream": False,
        }
    response = _send_json_request(
        "POST",
        target_url,
        headers=request_headers,
        payload=request_body,
    )
    # Only the status code is of interest; the parsed body is discarded.
    return response[0]
def _probe_azure_openai(payload: ModelConnectivityTestRequest) -> int:
    """Send a minimal request to an Azure OpenAI deployment.

    The deployment base URL is derived from the endpoint and model, the
    operation (``embeddings`` or ``chat/completions``) is chosen from the
    requested capability, and ``AZURE_API_VERSION`` is sent as the
    ``api-version`` query parameter. The API key, if any, goes into the
    ``api-key`` header.

    Returns:
        The HTTP status code of the response.

    Raises:
        ConnectivityCheckError: Propagated from the URL construction or from
            the request helper.
    """
    deployment_base = _build_azure_deployment_base(payload.endpoint, payload.model)
    request_headers = _build_headers(
        api_key=payload.api_key,
        use_bearer=False,
        use_api_key=True,
    )
    if payload.capability == "embedding":
        operation = "embeddings"
        request_body = {"input": "connectivity test"}
    else:
        operation = "chat/completions"
        request_body = {
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        }
    target_url = f"{deployment_base}/{operation}?api-version={AZURE_API_VERSION}"
    response = _send_json_request(
        "POST",
        target_url,
        headers=request_headers,
        payload=request_body,
    )
    # Only the status code is of interest; the parsed body is discarded.
    return response[0]
def _build_azure_deployment_base(endpoint: str, model: str) -> str:
    """Return the ``<resource>/openai/deployments/<name>`` base URL.

    The caller appends the operation path and the ``api-version`` query
    string, so the value returned here must stop at the deployment name.

    Accepted endpoint shapes:

    * a URL already containing ``/openai/deployments/<name>`` - the deployment
      name from the URL is kept, and anything after it (for example a pasted
      ``/chat/completions``) is dropped;
    * a URL ending in ``/openai/deployments`` - the model is appended;
    * a URL containing ``/openai/v1`` - everything from that marker on is
      replaced by the deployments path;
    * a URL ending in ``/openai`` - the deployments path is appended;
    * anything else is treated as the resource root.

    Args:
        endpoint: The user-supplied Azure endpoint.
        model: The deployment name, used when the endpoint does not name one.

    Returns:
        The deployment base URL without a trailing slash or query string.

    Raises:
        ConnectivityCheckError: If the endpoint is blank (raised by
            ``_normalize_endpoint``).
    """
    normalized_endpoint = _normalize_endpoint(endpoint)
    # Drop a query string or fragment pasted along with the URL; leaving it in
    # place would put the operation path appended by the caller *after* the
    # query and produce an unusable URL.
    for separator in ("?", "#"):
        normalized_endpoint = normalized_endpoint.split(separator, maxsplit=1)[0]
    normalized_endpoint = normalized_endpoint.rstrip("/")
    quoted_model = quote(model, safe="")

    deployments_path = "/openai/deployments"
    resource_root, has_deployment, remainder = normalized_endpoint.partition(
        f"{deployments_path}/"
    )
    if has_deployment:
        # Keep only the deployment segment; trailing operation paths are
        # re-added by the caller.
        deployment = remainder.split("/", maxsplit=1)[0] or quoted_model
        return f"{resource_root}{deployments_path}/{deployment}"
    if normalized_endpoint.endswith(deployments_path):
        return f"{normalized_endpoint}/{quoted_model}"
    if "/openai/v1" in normalized_endpoint:
        resource_root = normalized_endpoint.split("/openai/v1", maxsplit=1)[0]
        return f"{resource_root}{deployments_path}/{quoted_model}"
    if normalized_endpoint.endswith("/openai"):
        return f"{normalized_endpoint}/deployments/{quoted_model}"
    return f"{normalized_endpoint}{deployments_path}/{quoted_model}"
def _build_headers(
api_key: str | None,
*,
use_bearer: bool,
use_api_key: bool = False,
) -> dict[str, str]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
if api_key:
if use_api_key:
headers["api-key"] = api_key
elif use_bearer:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def _normalize_endpoint(endpoint: str) -> str:
normalized = endpoint.strip()
if not normalized:
raise ConnectivityCheckError("接口地址不能为空。", status_code=HTTPStatus.BAD_REQUEST)
return normalized.rstrip("/")
def _ensure_path(endpoint: str, suffix: str) -> str:
suffix = suffix.lstrip("/")
if endpoint.endswith(suffix):
return endpoint
return f"{endpoint}/{suffix}"
def _send_json_request(
    method: str,
    url: str,
    *,
    headers: dict[str, str],
    payload: dict[str, Any],
) -> tuple[int, Any]:
    """Send ``payload`` as JSON to ``url`` and return the status and parsed body.

    Every transport-level failure is translated into ``ConnectivityCheckError``
    so callers only have to handle a single exception type.

    Args:
        method: HTTP method to use.
        url: Absolute ``http://`` or ``https://`` URL.
        headers: Request headers.
        payload: JSON-serialisable request body.

    Returns:
        A ``(status_code, body)`` tuple where ``body`` is the result of
        ``_parse_json_body`` on the response text.

    Raises:
        ConnectivityCheckError: If the URL scheme is not HTTP(S), the server
            answers with an HTTP error status, the connection fails or times
            out, or the URL is otherwise malformed.
    """
    # Local import: http.client is not among the module-level imports.
    from http.client import HTTPException

    # ``urlopen`` also understands schemes such as file:// and ftp://; the
    # endpoint is user supplied, so only plain HTTP(S) is allowed through.
    if not url.lower().startswith(("http://", "https://")):
        raise ConnectivityCheckError(
            "接口地址必须以 http:// 或 https:// 开头。",
            status_code=HTTPStatus.BAD_REQUEST,
        )

    data = json.dumps(payload).encode("utf-8")
    try:
        # Request() itself raises ValueError for malformed URLs, so it has to
        # live inside the try block as well.
        request = Request(url=url, data=data, headers=headers, method=method)
        with urlopen(request, timeout=DEFAULT_TIMEOUT_SECONDS) as response:
            # Decode leniently: a non-UTF-8 success body must not turn a
            # reachable endpoint into an unhandled UnicodeDecodeError.
            body = response.read().decode("utf-8", errors="replace")
            return response.status, _parse_json_body(body)
    except HTTPError as exc:
        body = exc.read().decode("utf-8", errors="ignore")
        message = _extract_error_message(_parse_json_body(body)) or f"模型接口返回 {exc.code}"
        raise ConnectivityCheckError(message, status_code=exc.code) from exc
    except URLError as exc:
        reason = getattr(exc, "reason", exc)
        # Timeouts while connecting arrive wrapped in URLError; report them
        # with the dedicated timeout message.
        if isinstance(reason, TimeoutError):
            raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
        raise ConnectivityCheckError(f"无法连接到模型接口:{reason}") from exc
    except TimeoutError as exc:
        raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
    except (OSError, HTTPException, ValueError) as exc:
        # Covers failures urllib does not wrap in URLError: resets while
        # reading the response, malformed status lines, invalid ports or
        # hosts, and similar.
        raise ConnectivityCheckError(f"无法连接到模型接口:{exc}") from exc
def _parse_json_body(body: str) -> Any:
if not body:
return None
try:
return json.loads(body)
except json.JSONDecodeError:
return {"message": body}
def _extract_error_message(payload: Any) -> str | None:
if payload is None:
return None
if isinstance(payload, dict):
if isinstance(payload.get("detail"), str):
return payload["detail"]
if isinstance(payload.get("message"), str):
return payload["message"]
error_payload = payload.get("error")
if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str):
return error_payload["message"]
if isinstance(payload, str):
return payload
return None