Files
X-Financial/server/src/app/services/runtime_chat.py

253 lines
7.6 KiB
Python
Raw Normal View History

from __future__ import annotations
from http import HTTPStatus
from typing import Any
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.services.model_connectivity import (
AZURE_API_VERSION,
ConnectivityCheckError,
_build_azure_deployment_base,
_build_headers,
_ensure_path,
_normalize_endpoint,
_send_json_request,
)
from app.services.settings import SettingsService
# Module-level logger used by the runtime chat service below.
logger = get_logger("app.services.runtime_chat")
class RuntimeChatService:
    """Send chat-completion requests to the configured runtime model slots.

    Slots are tried in priority order; the first slot that yields a
    non-empty reply wins. Azure OpenAI and Ollama get dedicated request
    shapes, every other provider is called through the OpenAI-compatible
    ``chat/completions`` route.
    """

    def __init__(self, db: Session) -> None:
        """Keep the DB session and build the settings service on top of it."""
        self.db = db
        self.settings_service = SettingsService(db)

    def complete(
        self,
        messages: list[dict[str, str]],
        *,
        slot_priority: tuple[str, ...] = ("main", "backup"),
        max_tokens: int = 500,
        temperature: float = 0.2,
    ) -> str | None:
        """Return the first non-empty reply from the slots in ``slot_priority``.

        Args:
            messages: Chat messages forwarded unchanged to the provider.
            slot_priority: Slot names to try, in order.
            max_tokens: Upper bound on generated tokens sent to the provider.
            temperature: Sampling temperature sent to the provider.

        Returns:
            The stripped reply text, or ``None`` when every slot is
            unusable, fails, or answers with empty text.
        """
        for slot in slot_priority:
            config = self._load_chat_slot(slot)
            if config is None:
                continue
            try:
                response_text = self._request_chat_completion(
                    config,
                    messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
            except Exception as exc:
                # Deliberate best-effort boundary: any failure on one slot
                # must not prevent falling back to the next one.
                logger.warning(
                    "Runtime chat request failed slot=%s provider=%s: %s",
                    slot,
                    config["provider"],
                    exc,
                )
                continue
            # Strip before testing so a whitespace-only reply falls through
            # to the next slot instead of being returned as "".
            cleaned = (response_text or "").strip()
            if cleaned:
                return cleaned
        return None

    def _load_chat_slot(self, slot: str) -> dict[str, str] | None:
        """Load and validate the model configuration for ``slot``.

        Returns:
            A dict with ``slot``, ``provider``, ``endpoint``, ``model`` and
            ``apiKey`` (all stripped strings), or ``None`` when the settings
            service raises ``ValueError``, the slot's capability is not
            ``"chat"``, a required field is empty, or a non-Ollama provider
            has no API key.
        """
        try:
            config = self.settings_service.get_runtime_model_config(slot)
        except ValueError:
            return None
        if config["capability"] != "chat":
            return None

        provider = str(config["provider"] or "").strip()
        endpoint = str(config["endpoint"] or "").strip()
        model = str(config["model"] or "").strip()
        api_key = str(config["apiKey"] or "").strip()

        if not provider or not endpoint or not model:
            return None
        # Ollama is the only provider allowed to run without an API key.
        if provider != "Ollama" and not api_key:
            logger.info("Skip runtime chat slot=%s because api key is empty", slot)
            return None
        return {
            "slot": slot,
            "provider": provider,
            "endpoint": endpoint,
            "model": model,
            "apiKey": api_key,
        }

    def _request_chat_completion(
        self,
        config: dict[str, str],
        messages: list[dict[str, str]],
        *,
        max_tokens: int,
        temperature: float,
    ) -> str:
        """Dispatch the request to the provider-specific implementation.

        ``"Azure OpenAI"`` and ``"Ollama"`` are matched exactly; any other
        provider name is treated as OpenAI-compatible.
        """
        provider = config["provider"]
        endpoint = config["endpoint"]
        model = config["model"]
        api_key = config["apiKey"]
        if provider == "Azure OpenAI":
            return self._request_azure_openai(
                endpoint=endpoint,
                model=model,
                api_key=api_key,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        if provider == "Ollama":
            return self._request_ollama(
                endpoint=endpoint,
                model=model,
                api_key=api_key,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        return self._request_openai_compatible(
            endpoint=endpoint,
            model=model,
            api_key=api_key,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    def _request_openai_compatible(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """POST to ``<endpoint>/chat/completions`` with bearer authentication.

        Raises:
            ConnectivityCheckError: When the response status is 400 or above.
        """
        url = _ensure_path(_normalize_endpoint(endpoint), "chat/completions")
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=True),
            payload={
                "model": model,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"模型接口返回异常状态 {status_code}",
                status_code=status_code,
            )
        return self._extract_openai_text(payload)

    def _request_ollama(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """POST a non-streaming chat request to ``<endpoint>/api/chat``.

        ``max_tokens`` is sent as the ``num_predict`` option.

        Raises:
            ConnectivityCheckError: When the response status is 400 or above.
        """
        url = _ensure_path(_normalize_endpoint(endpoint), "api/chat")
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=False),
            payload={
                "model": model,
                "messages": messages,
                "stream": False,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature,
                },
            },
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"Ollama 返回异常状态 {status_code}",
                status_code=status_code,
            )
        return self._extract_ollama_text(payload)

    def _request_azure_openai(
        self,
        *,
        endpoint: str,
        model: str,
        api_key: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """POST to the Azure deployment's ``chat/completions`` route.

        The deployment base URL is built from ``endpoint`` and ``model``, so
        the model name is not repeated in the request body.

        Raises:
            ConnectivityCheckError: When the response status is 400 or above.
        """
        deployment_base = _build_azure_deployment_base(endpoint, model)
        url = f"{deployment_base}/chat/completions?api-version={AZURE_API_VERSION}"
        status_code, payload = _send_json_request(
            "POST",
            url,
            headers=_build_headers(api_key=api_key, use_bearer=False, use_api_key=True),
            payload={
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise ConnectivityCheckError(
                f"Azure OpenAI 返回异常状态 {status_code}",
                status_code=status_code,
            )
        return self._extract_openai_text(payload)

    @staticmethod
    def _extract_ollama_text(payload: Any) -> str:
        """Return the stripped ``message.content`` of an Ollama chat response.

        Any unexpected shape (non-dict payload, missing or null ``message``,
        missing or null ``content``) yields ``""`` rather than raising or
        returning the literal string ``"None"``.
        """
        if not isinstance(payload, dict):
            return ""
        message = payload.get("message")
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        if content is None:
            return ""
        return str(content).strip()

    @staticmethod
    def _extract_openai_text(payload: Any) -> str:
        """Return the stripped text of the first choice in an OpenAI-style payload.

        Handles a plain string ``message.content``, a list of content parts
        (only ``type == "text"`` parts are kept, joined with newlines), and
        the legacy top-level ``text`` field. Any other shape yields ``""``.
        """
        if not isinstance(payload, dict):
            return ""
        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return ""

        message = first_choice.get("message")
        if isinstance(message, dict):
            content = message.get("content", "")
            if isinstance(content, str):
                return content.strip()
            if isinstance(content, list):
                parts = [
                    str(item.get("text", ""))
                    for item in content
                    if isinstance(item, dict) and item.get("type") == "text"
                ]
                return "\n".join(part.strip() for part in parts if part.strip()).strip()

        # Fall back to the completion-style ``text`` field.
        text = first_choice.get("text")
        if isinstance(text, str):
            return text.strip()
        return ""