from __future__ import annotations import json from datetime import datetime, timezone from http import HTTPStatus from typing import Any from urllib.error import HTTPError, URLError from urllib.parse import quote from urllib.request import Request, urlopen from app.schemas.settings import ModelConnectivityTestRead, ModelConnectivityTestRequest AZURE_API_VERSION = "2024-10-21" DEFAULT_TIMEOUT_SECONDS = 12 class ConnectivityCheckError(Exception): def __init__(self, message: str, status_code: int | None = None) -> None: super().__init__(message) self.status_code = status_code def probe_model_connectivity(payload: ModelConnectivityTestRequest) -> ModelConnectivityTestRead: checked_at = datetime.now(timezone.utc) try: if payload.provider == "Azure OpenAI": status_code = _probe_azure_openai(payload) elif payload.provider == "Ollama": status_code = _probe_ollama(payload) else: status_code = _probe_openai_compatible(payload) detail = f"{payload.provider} 已连接,模型 {payload.model} 可正常访问。" return ModelConnectivityTestRead( ok=True, provider=payload.provider, model=payload.model, endpoint=payload.endpoint, capability=payload.capability, detail=detail, status_code=status_code, checked_at=checked_at, ) except ConnectivityCheckError as exc: return ModelConnectivityTestRead( ok=False, provider=payload.provider, model=payload.model, endpoint=payload.endpoint, capability=payload.capability, detail=str(exc), status_code=exc.status_code, checked_at=checked_at, ) def _probe_openai_compatible(payload: ModelConnectivityTestRequest) -> int: normalized_endpoint = _normalize_endpoint(payload.endpoint) headers = _build_headers(api_key=payload.api_key, use_bearer=True) if payload.capability == "reranker" and payload.provider == "Ali": url, body = _build_ali_reranker_request(payload.model, normalized_endpoint) elif payload.capability == "embedding": url = _ensure_path(normalized_endpoint, "embeddings") body = {"model": payload.model, "input": "connectivity test"} elif payload.capability == "reranker": url = _ensure_path(normalized_endpoint, "rerank") body = { "model": payload.model, "query": "connectivity test", "documents": ["sample document"], } else: url = _ensure_path(normalized_endpoint, "chat/completions") body = { "model": payload.model, "messages": [{"role": "user", "content": "ping"}], "max_tokens": 1, } status_code, _ = _send_json_request("POST", url, headers=headers, payload=body) return status_code def _build_ali_reranker_request(model: str, endpoint: str) -> tuple[str, dict[str, Any]]: normalized_model = str(model or "").strip() if normalized_model == "qwen3-rerank": return ( "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", { "model": normalized_model, "query": "connectivity test", "documents": ["sample document"], "top_n": 1, }, ) return ( "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank", { "model": normalized_model, "input": { "query": "connectivity test", "documents": ["sample document"], }, "parameters": { "return_documents": False, "top_n": 1, }, }, ) def _probe_ollama(payload: ModelConnectivityTestRequest) -> int: normalized_endpoint = _normalize_endpoint(payload.endpoint) headers = _build_headers(api_key=payload.api_key, use_bearer=False) if payload.capability == "embedding": url = _ensure_path(normalized_endpoint, "api/embed") body = {"model": payload.model, "input": "connectivity test"} elif payload.capability == "reranker": raise ConnectivityCheckError("Ollama 暂不支持 reranker 连通性探测。", status_code=HTTPStatus.BAD_REQUEST) else: url = _ensure_path(normalized_endpoint, "api/chat") body = { "model": payload.model, "messages": [{"role": "user", "content": "ping"}], "stream": False, } status_code, _ = _send_json_request("POST", url, headers=headers, payload=body) return status_code def _probe_azure_openai(payload: ModelConnectivityTestRequest) -> int: deployment_base = _build_azure_deployment_base(payload.endpoint, payload.model) headers = _build_headers(api_key=payload.api_key, use_bearer=False, use_api_key=True) if payload.capability == "embedding": url = f"{deployment_base}/embeddings?api-version={AZURE_API_VERSION}" body = {"input": "connectivity test"} elif payload.capability == "reranker": url = f"{deployment_base}/rerank?api-version={AZURE_API_VERSION}" body = { "query": "connectivity test", "documents": ["sample document"], } else: url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}" body = { "messages": [{"role": "user", "content": "ping"}], "max_tokens": 1, } status_code, _ = _send_json_request("POST", url, headers=headers, payload=body) return status_code def _build_azure_deployment_base(endpoint: str, model: str) -> str: normalized_endpoint = _normalize_endpoint(endpoint) quoted_model = quote(model, safe="") if "/openai/deployments/" in normalized_endpoint: return normalized_endpoint if "/openai/v1" in normalized_endpoint: resource_root = normalized_endpoint.split("/openai/v1", maxsplit=1)[0] return f"{resource_root}/openai/deployments/{quoted_model}" if normalized_endpoint.endswith("/openai"): return f"{normalized_endpoint}/deployments/{quoted_model}" return f"{normalized_endpoint}/openai/deployments/{quoted_model}" def _build_headers( api_key: str | None, *, use_bearer: bool, use_api_key: bool = False, ) -> dict[str, str]: headers = { "Content-Type": "application/json", "Accept": "application/json", } if api_key: if use_api_key: headers["api-key"] = api_key elif use_bearer: headers["Authorization"] = f"Bearer {api_key}" return headers def _normalize_endpoint(endpoint: str) -> str: normalized = endpoint.strip() if not normalized: raise ConnectivityCheckError("接口地址不能为空。", status_code=HTTPStatus.BAD_REQUEST) return normalized.rstrip("/") def _ensure_path(endpoint: str, suffix: str) -> str: suffix = suffix.lstrip("/") if endpoint.endswith(suffix): return endpoint return f"{endpoint}/{suffix}" def _send_json_request( method: str, url: str, *, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, ) -> tuple[int, Any]: data = json.dumps(payload).encode("utf-8") request = Request(url=url, data=data, headers=headers, method=method) try: with urlopen(request, timeout=timeout_seconds) as response: body = response.read().decode("utf-8") if response.length != 0 else "" return response.status, _parse_json_body(body) except HTTPError as exc: body = exc.read().decode("utf-8", errors="ignore") message = _extract_error_message(_parse_json_body(body)) or f"模型接口返回 {exc.code}。" raise ConnectivityCheckError(message, status_code=exc.code) from exc except URLError as exc: reason = getattr(exc, "reason", exc) raise ConnectivityCheckError(f"无法连接到模型接口:{reason}") from exc except TimeoutError as exc: raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc def _parse_json_body(body: str) -> Any: if not body: return None try: return json.loads(body) except json.JSONDecodeError: return {"message": body} def _extract_error_message(payload: Any) -> str | None: if payload is None: return None if isinstance(payload, dict): if isinstance(payload.get("detail"), str): return payload["detail"] if isinstance(payload.get("message"), str): return payload["message"] error_payload = payload.get("error") if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str): return error_payload["message"] if isinstance(payload, str): return payload return None