Files
X-Financial/server/src/app/services/model_connectivity.py
caoxiaozhu 68f663f2f4 feat: 重构知识库系统,移除Hermes集成,增强RAG和同步功能
主要变更:
- 移除Hermes智能体及相关回调服务
- 新增知识库RAG、同步、调度、规范化和索引任务服务
- 重构orchestrator服务,增强运行时聊天功能
- 更新前端聊天、政策制度、设置等页面样式和逻辑
- 更新expense_claims和document_intelligence服务
- 删除llm_wiki相关服务和测试文件
- 更新docker-compose配置和启动脚本
2026-05-17 08:38:41 +00:00

264 lines
8.8 KiB
Python

from __future__ import annotations
import json
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import quote
from urllib.request import Request, urlopen
from app.schemas.settings import ModelConnectivityTestRead, ModelConnectivityTestRequest
# Value sent as the ``api-version`` query parameter on every Azure OpenAI probe URL.
AZURE_API_VERSION = "2024-10-21"
# Default ``timeout_seconds`` for _send_json_request, which passes it to urlopen.
DEFAULT_TIMEOUT_SECONDS = 12
class ConnectivityCheckError(Exception):
    """Signal that a model connectivity probe failed.

    Attributes are documented inline; the exception text is the message
    passed to the constructor.
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        """Store the failure ``message`` and an optional HTTP ``status_code``."""
        # Status associated with the failure, or None when no status is known.
        self.status_code = status_code
        super().__init__(message)
def probe_model_connectivity(payload: ModelConnectivityTestRequest) -> ModelConnectivityTestRead:
    """Probe the model endpoint described by ``payload`` and report the outcome.

    The provider name selects the probe: ``"Azure OpenAI"`` and ``"Ollama"``
    have dedicated probes, every other provider goes through the
    OpenAI-compatible probe.

    Returns:
        A ``ModelConnectivityTestRead`` with ``ok=True`` and the probe's status
        code on success, or ``ok=False`` carrying the error text and status
        code of the ``ConnectivityCheckError`` on failure. Exceptions of any
        other type are not handled here.
    """
    checked_at = datetime.now(timezone.utc)

    if payload.provider == "Azure OpenAI":
        run_probe = _probe_azure_openai
    elif payload.provider == "Ollama":
        run_probe = _probe_ollama
    else:
        run_probe = _probe_openai_compatible

    # Request fields echoed back unchanged in both the success and failure result.
    echoed = {
        "provider": payload.provider,
        "model": payload.model,
        "endpoint": payload.endpoint,
        "capability": payload.capability,
        "checked_at": checked_at,
    }

    try:
        status_code = run_probe(payload)
    except ConnectivityCheckError as exc:
        return ModelConnectivityTestRead(
            ok=False,
            detail=str(exc),
            status_code=exc.status_code,
            **echoed,
        )

    return ModelConnectivityTestRead(
        ok=True,
        detail=f"{payload.provider} 已连接,模型 {payload.model} 可正常访问。",
        status_code=status_code,
        **echoed,
    )
def _probe_openai_compatible(payload: ModelConnectivityTestRequest) -> int:
    """Send a minimal request to an OpenAI-style endpoint and return the HTTP status.

    The request shape depends on ``payload.capability``: ``"embedding"`` posts
    to ``embeddings``, ``"reranker"`` posts to ``rerank`` (or to the request
    built by ``_build_ali_reranker_request`` when the provider is ``"Ali"``),
    and anything else posts a one-token chat completion.

    Raises:
        ConnectivityCheckError: propagated from endpoint normalization or from
            sending the request.
    """
    base_url = _normalize_endpoint(payload.endpoint)
    request_headers = _build_headers(api_key=payload.api_key, use_bearer=True)
    capability = payload.capability

    if capability == "reranker":
        if payload.provider == "Ali":
            url, body = _build_ali_reranker_request(payload.model, base_url)
        else:
            url = _ensure_path(base_url, "rerank")
            body = {
                "model": payload.model,
                "query": "connectivity test",
                "documents": ["sample document"],
            }
    elif capability == "embedding":
        url = _ensure_path(base_url, "embeddings")
        body = {"model": payload.model, "input": "connectivity test"}
    else:
        url = _ensure_path(base_url, "chat/completions")
        body = {
            "model": payload.model,
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        }

    # Only the status matters for a connectivity check; the parsed body is dropped.
    response_status, _ = _send_json_request("POST", url, headers=request_headers, payload=body)
    return response_status
def _build_ali_reranker_request(model: str, endpoint: str) -> tuple[str, dict[str, Any]]:
normalized_model = str(model or "").strip()
if normalized_model == "qwen3-rerank":
return (
"https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
{
"model": normalized_model,
"query": "connectivity test",
"documents": ["sample document"],
"top_n": 1,
},
)
return (
"https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
{
"model": normalized_model,
"input": {
"query": "connectivity test",
"documents": ["sample document"],
},
"parameters": {
"return_documents": False,
"top_n": 1,
},
},
)
def _probe_ollama(payload: ModelConnectivityTestRequest) -> int:
    """Send a minimal request to an Ollama endpoint and return the HTTP status.

    ``"embedding"`` posts to ``api/embed``; any capability other than
    ``"reranker"`` posts a non-streaming chat message to ``api/chat``.
    No API key header is attached (``_build_headers`` is called with
    ``use_bearer=False``).

    Raises:
        ConnectivityCheckError: for the ``"reranker"`` capability (status 400),
            or propagated from endpoint normalization / sending the request.
    """
    # Normalize first so an empty endpoint is reported before the reranker rejection.
    base_url = _normalize_endpoint(payload.endpoint)
    request_headers = _build_headers(api_key=payload.api_key, use_bearer=False)

    if payload.capability == "reranker":
        raise ConnectivityCheckError("Ollama 暂不支持 reranker 连通性探测。", status_code=HTTPStatus.BAD_REQUEST)

    if payload.capability == "embedding":
        url = _ensure_path(base_url, "api/embed")
        body = {"model": payload.model, "input": "connectivity test"}
    else:
        url = _ensure_path(base_url, "api/chat")
        body = {
            "model": payload.model,
            "messages": [{"role": "user", "content": "ping"}],
            "stream": False,
        }

    response_status, _ = _send_json_request("POST", url, headers=request_headers, payload=body)
    return response_status
def _probe_azure_openai(payload: ModelConnectivityTestRequest) -> int:
    """Send a minimal request to an Azure OpenAI deployment and return the HTTP status.

    The URL is the deployment base followed by the operation path and the
    ``api-version`` query parameter. The API key, when present, is sent in the
    ``api-key`` header. Request bodies carry no ``model`` field.

    Raises:
        ConnectivityCheckError: propagated from building the deployment base or
            from sending the request.
    """
    deployment_base = _build_azure_deployment_base(payload.endpoint, payload.model)
    request_headers = _build_headers(api_key=payload.api_key, use_bearer=False, use_api_key=True)

    # Pick the operation path and body first; the URL is assembled once below.
    if payload.capability == "embedding":
        operation = "embeddings"
        body = {"input": "connectivity test"}
    elif payload.capability == "reranker":
        operation = "rerank"
        body = {
            "query": "connectivity test",
            "documents": ["sample document"],
        }
    else:
        operation = "chat/completions"
        body = {
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        }

    url = f"{deployment_base}/{operation}?api-version={AZURE_API_VERSION}"
    response_status, _ = _send_json_request("POST", url, headers=request_headers, payload=body)
    return response_status
def _build_azure_deployment_base(endpoint: str, model: str) -> str:
    """Return the ``.../openai/deployments/<deployment>`` base URL for Azure OpenAI.

    Accepted endpoint shapes:

    * a URL that already names a deployment (``.../openai/deployments/<name>``),
      optionally followed by an operation path and/or query string; everything
      after the deployment segment is dropped so the caller can append its own
      operation path and ``api-version``;
    * a URL ending in ``/openai/deployments``;
    * a URL containing ``/openai/v1`` (the part before it is used as the root);
    * a URL ending in ``/openai``;
    * a bare resource root.

    In every shape that does not name a deployment, ``model`` (percent-encoded
    with no safe characters) is used as the deployment name.

    Raises:
        ConnectivityCheckError: propagated from ``_normalize_endpoint``.
    """
    normalized_endpoint = _normalize_endpoint(endpoint)
    quoted_model = quote(model, safe="")
    marker = "/openai/deployments/"

    if marker in normalized_endpoint:
        resource_root, _, remainder = normalized_endpoint.partition(marker)
        # Keep only the deployment segment: a pasted full target URI would
        # otherwise have the operation path and query string appended twice.
        deployment = remainder.split("?", 1)[0].split("/", 1)[0]
        return f"{resource_root}{marker}{deployment or quoted_model}"
    if normalized_endpoint.endswith("/openai/deployments"):
        # The trailing slash was stripped by normalization, so the marker
        # check above cannot match this shape.
        return f"{normalized_endpoint}/{quoted_model}"
    if "/openai/v1" in normalized_endpoint:
        resource_root = normalized_endpoint.split("/openai/v1", maxsplit=1)[0]
        return f"{resource_root}/openai/deployments/{quoted_model}"
    if normalized_endpoint.endswith("/openai"):
        return f"{normalized_endpoint}/deployments/{quoted_model}"
    return f"{normalized_endpoint}/openai/deployments/{quoted_model}"
def _build_headers(
api_key: str | None,
*,
use_bearer: bool,
use_api_key: bool = False,
) -> dict[str, str]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
if api_key:
if use_api_key:
headers["api-key"] = api_key
elif use_bearer:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def _normalize_endpoint(endpoint: str) -> str:
normalized = endpoint.strip()
if not normalized:
raise ConnectivityCheckError("接口地址不能为空。", status_code=HTTPStatus.BAD_REQUEST)
return normalized.rstrip("/")
def _ensure_path(endpoint: str, suffix: str) -> str:
suffix = suffix.lstrip("/")
if endpoint.endswith(suffix):
return endpoint
return f"{endpoint}/{suffix}"
def _send_json_request(
    method: str,
    url: str,
    *,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
) -> tuple[int, Any]:
    """Send ``payload`` as a JSON body to ``url`` and return ``(status, parsed_body)``.

    Args:
        method: HTTP method passed to ``Request``.
        url: Target URL; must use the ``http`` or ``https`` scheme.
        headers: Request headers.
        payload: Object serialized with ``json.dumps`` as the request body.
        timeout_seconds: Timeout handed to ``urlopen``.

    Returns:
        The response status and the body as parsed by ``_parse_json_body``.

    Raises:
        ConnectivityCheckError: for a non-HTTP(S) or malformed URL, an HTTP
            error response (carrying its status code and, when available, the
            message extracted from the error body), a connection failure, or
            a timeout. Transport-level failures are always reported through
            this type so the caller's single ``except`` covers them.
    """
    # Function-scope import keeps this block self-contained.
    from http.client import HTTPException

    # urlopen also handles file:// and ftp:// URLs. The endpoint comes from
    # user-editable settings, so restrict probes to plain HTTP(S) targets.
    if not url.lower().startswith(("http://", "https://")):
        raise ConnectivityCheckError(
            "接口地址无效,必须以 http:// 或 https:// 开头。",
            status_code=HTTPStatus.BAD_REQUEST,
        )

    data = json.dumps(payload).encode("utf-8")
    try:
        # Request() itself raises ValueError on a malformed URL, so it lives
        # inside the try block.
        request = Request(url=url, data=data, headers=headers, method=method)
        with urlopen(request, timeout=timeout_seconds) as response:
            # Lenient decoding: a non-UTF-8 success body must not turn a
            # reachable endpoint into an unhandled UnicodeDecodeError.
            body = response.read().decode("utf-8", errors="ignore")
            return response.status, _parse_json_body(body)
    except HTTPError as exc:
        body = exc.read().decode("utf-8", errors="ignore")
        message = _extract_error_message(_parse_json_body(body)) or f"模型接口返回 {exc.code}"
        raise ConnectivityCheckError(message, status_code=exc.code) from exc
    except URLError as exc:
        reason = getattr(exc, "reason", exc)
        # Connect-phase timeouts arrive wrapped in URLError; report them with
        # the same message as read-phase timeouts.
        if isinstance(reason, TimeoutError):
            raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
        raise ConnectivityCheckError(f"无法连接到模型接口:{reason}") from exc
    except TimeoutError as exc:
        raise ConnectivityCheckError("模型接口连接超时,请检查地址或网络。") from exc
    except (OSError, HTTPException, ValueError) as exc:
        # Failures urlopen does not wrap in URLError: socket errors while
        # reading the response, protocol errors such as InvalidURL or
        # IncompleteRead, and malformed URLs rejected by Request().
        raise ConnectivityCheckError(f"无法连接到模型接口:{exc}") from exc
def _parse_json_body(body: str) -> Any:
if not body:
return None
try:
return json.loads(body)
except json.JSONDecodeError:
return {"message": body}
def _extract_error_message(payload: Any) -> str | None:
if payload is None:
return None
if isinstance(payload, dict):
if isinstance(payload.get("detail"), str):
return payload["detail"]
if isinstance(payload.get("message"), str):
return payload["message"]
error_payload = payload.get("error")
if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str):
return error_payload["message"]
if isinstance(payload, str):
return payload
return None