refactor(server): split oversized backend services
This commit is contained in:
672
server/src/app/services/knowledge_rag_runtime.py
Normal file
672
server/src/app/services/knowledge_rag_runtime.py
Normal file
@@ -0,0 +1,672 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.model_connectivity import AZURE_API_VERSION
|
||||
|
||||
# Module logger; note the name omits the "_runtime" suffix of this file.
logger = get_logger("app.services.knowledge_rag")

# Query mode used when the LIGHTRAG_QUERY_MODE environment variable is unset or blank.
DEFAULT_LIGHTRAG_QUERY_MODE = "naive"

# HTTP timeouts in seconds: the LLM value is applied to chat and rerank calls,
# the embedding value to embedding calls.
DEFAULT_LLM_TIMEOUT_SECONDS = 180
DEFAULT_EMBEDDING_TIMEOUT_SECONDS = 120
|
||||
|
||||
class KnowledgeRagError(RuntimeError):
    """Error raised by the knowledge RAG runtime and its HTTP helpers."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RuntimeModelConfig:
|
||||
slot: str
|
||||
provider: str
|
||||
model: str
|
||||
endpoint: str
|
||||
api_key: str
|
||||
capability: str
|
||||
|
||||
|
||||
class _LightRagRuntime:
    """Blocking wrapper around one LightRAG instance and its model endpoints.

    Construction probes the embedding model for its vector size, builds the
    LightRAG object (JSON KV / doc-status storage, NetworkX graph storage,
    Qdrant vector storage) and initializes its storages.  Chat, embedding and
    rerank traffic goes through the module-level urllib helper; LightRAG is
    handed coroutines that run those blocking calls in worker threads.
    """

    def __init__(
        self,
        *,
        working_dir: Path,
        workspace: str,
        qdrant_url: str,
        qdrant_api_key: str,
        primary_chat: RuntimeModelConfig,
        backup_chat: RuntimeModelConfig | None,
        embedding: RuntimeModelConfig,
        reranker: RuntimeModelConfig | None,
    ) -> None:
        """Store the configuration, then build and initialize LightRAG.

        Args:
            working_dir: Directory handed to LightRAG; created if missing.
            workspace: LightRAG workspace name (also a sub-directory of
                ``working_dir`` when looking for the graph file).
            qdrant_url: Exported as ``QDRANT_URL`` when non-empty.
            qdrant_api_key: Exported as ``QDRANT_API_KEY`` when non-empty.
            primary_chat: Chat model tried first for every completion.
            backup_chat: Optional chat model tried when the primary fails.
            embedding: Embedding model configuration.
            reranker: Optional rerank model; ``None`` disables reranking.

        Raises:
            KnowledgeRagError: LightRAG is not installed or the embedding
                probe fails.
        """
        self.working_dir = working_dir
        self.workspace = workspace
        self.qdrant_url = qdrant_url
        self.qdrant_api_key = qdrant_api_key
        self.primary_chat = primary_chat
        self.backup_chat = backup_chat
        self.embedding = embedding
        self.reranker = reranker
        # Assigned before the build so the attribute exists whatever runs next.
        self._graph_has_content_cache: bool | None = None
        self._rag = self._build_rag()
        self._initialize()

    @property
    def rag(self):
        """The underlying LightRAG object."""
        return self._rag

    @staticmethod
    def _run_until_complete(awaitable: Any) -> Any:
        """Run *awaitable* to completion on the loop LightRAG's helper returns."""
        from lightrag.utils import always_get_an_event_loop

        return always_get_an_event_loop().run_until_complete(awaitable)

    def _build_rag(self):
        """Create the LightRAG instance wired to this runtime's model callables.

        Raises:
            KnowledgeRagError: The ``lightrag`` package cannot be imported or
                the embedding dimension probe fails.
        """
        try:
            from lightrag import LightRAG
            from lightrag.utils import EmbeddingFunc
        except ImportError as exc:  # pragma: no cover - exercised in runtime env
            raise KnowledgeRagError(
                "LightRAG 依赖未安装,请先在 server 环境执行依赖安装。"
            ) from exc

        self.working_dir.mkdir(parents=True, exist_ok=True)

        # Process-wide side effect.  NOTE(review): presumably LightRAG's Qdrant
        # storage reads these variables — confirm against the LightRAG version.
        if self.qdrant_url:
            os.environ["QDRANT_URL"] = self.qdrant_url
        if self.qdrant_api_key:
            os.environ["QDRANT_API_KEY"] = self.qdrant_api_key

        # One real embedding request so the vector size never has to be configured.
        embedding_dim = self._probe_embedding_dimension(self.embedding)
        logger.info(
            "Initialize LightRAG runtime workspace=%s qdrant=%s embedding_model=%s dim=%s",
            self.workspace,
            self.qdrant_url,
            self.embedding.model,
            embedding_dim,
        )

        async def embedding_func(texts: list[str]) -> Any:
            return await asyncio.to_thread(self._embed_sync, texts)

        async def llm_model_func(
            prompt: str,
            system_prompt: str | None = None,
            history_messages: list[dict[str, Any]] | None = None,
            keyword_extraction: bool = False,
            **kwargs: Any,
        ) -> str:
            return await asyncio.to_thread(
                self._complete_sync,
                prompt,
                system_prompt,
                history_messages or [],
                keyword_extraction,
                kwargs,
            )

        async def rerank_model_func(
            query: str,
            documents: list[str],
            top_n: int | None = None,
            **_kwargs: Any,
        ) -> list[dict[str, Any]]:
            return await asyncio.to_thread(
                self._rerank_sync,
                query,
                documents,
                top_n,
            )

        return LightRAG(
            working_dir=str(self.working_dir),
            workspace=self.workspace,
            kv_storage="JsonKVStorage",
            graph_storage="NetworkXStorage",
            vector_storage="QdrantVectorDBStorage",
            doc_status_storage="JsonDocStatusStorage",
            llm_model_name=self.primary_chat.model,
            llm_model_func=llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dim,
                func=embedding_func,
                max_token_size=8192,
                model_name=self.embedding.model,
                supports_asymmetric=False,
            ),
            rerank_model_func=rerank_model_func if self.reranker is not None else None,
            enable_llm_cache=False,
            enable_llm_cache_for_entity_extract=False,
        )

    def _initialize(self) -> None:
        """Initialize LightRAG's storages, blocking until done."""
        self._run_until_complete(self._rag.initialize_storages())

    def finalize(self) -> None:
        """Finalize LightRAG's storages, blocking until done."""
        self._run_until_complete(self._rag.finalize_storages())

    def query_data(
        self,
        query: str,
        *,
        conversation_history: list[dict[str, str]] | None = None,
    ) -> dict[str, Any]:
        """Run a context-only LightRAG query and return its result.

        The mode comes from ``LIGHTRAG_QUERY_MODE`` (default
        ``DEFAULT_LIGHTRAG_QUERY_MODE``).  A non-naive mode is downgraded to
        ``"naive"`` when the graph file holds no nodes or edges, and a failing
        non-naive query is retried once in naive mode.

        Raises:
            Exception: Whatever LightRAG raises when the naive query fails.
        """
        from lightrag import QueryParam

        configured_mode = (
            os.environ.get("LIGHTRAG_QUERY_MODE", DEFAULT_LIGHTRAG_QUERY_MODE).strip()
            or DEFAULT_LIGHTRAG_QUERY_MODE
        )
        # The graph file is only inspected when a non-naive mode was requested.
        if configured_mode == "naive" or self._graph_has_content():
            mode = configured_mode
        else:
            mode = "naive"

        def build_param(selected_mode: str) -> Any:
            return QueryParam(
                mode=selected_mode,
                top_k=8,
                chunk_top_k=10,
                only_need_context=True,
                response_type="Multiple Paragraphs",
                conversation_history=conversation_history or [],
                include_references=True,
            )

        started_at = perf_counter()
        try:
            result = self._rag.query_data(query, build_param(mode))
        except Exception:
            if mode == "naive":
                raise
            # Keep the traceback: it is the only record of why the mode failed.
            logger.warning(
                "LightRAG query mode=%s failed, retry with naive mode",
                mode,
                exc_info=True,
            )
            mode = "naive"
            result = self._rag.query_data(query, build_param(mode))
        logger.info(
            "LightRAG query completed mode=%s elapsed=%.2fs",
            mode,
            perf_counter() - started_at,
        )
        return result

    def _graph_has_content(self) -> bool:
        """Report whether the workspace graph file contains a node or an edge.

        The answer is cached until a document is inserted or deleted.  A
        missing or unreadable file counts as an empty graph.
        """
        if self._graph_has_content_cache is not None:
            return self._graph_has_content_cache

        graph_path = self.working_dir / self.workspace / "graph_chunk_entity_relation.graphml"
        try:
            # Scan line by line so a large graph is not loaded whole and the
            # scan stops at the first element found.
            with graph_path.open(encoding="utf-8") as handle:
                has_content = any("<node " in line or "<edge " in line for line in handle)
        except OSError:
            has_content = False

        self._graph_has_content_cache = has_content
        return has_content

    def insert_documents(
        self,
        *,
        texts: list[str],
        document_ids: list[str],
        file_paths: list[str],
    ) -> str:
        """Insert documents through LightRAG's blocking ``insert`` and return its result."""
        try:
            return self._rag.insert(texts, ids=document_ids, file_paths=file_paths)
        finally:
            # The graph may have gained content; re-check on the next query.
            self._graph_has_content_cache = None

    def get_document_statuses(self, document_ids: list[str]) -> dict[str, Any]:
        """Return what LightRAG's ``aget_docs_by_ids`` yields for *document_ids*."""
        return self._run_until_complete(self._rag.aget_docs_by_ids(document_ids))

    def delete_document(self, document_id: str) -> None:
        """Delete one document from LightRAG.

        Raises:
            KnowledgeRagError: The reported status is neither ``"success"``
                nor ``"not_found"``.
        """
        try:
            result = self._run_until_complete(self._rag.adelete_by_doc_id(document_id))
        finally:
            # The graph may have lost content; re-check on the next query.
            self._graph_has_content_cache = None

        status = str(getattr(result, "status", "") or "")
        if status not in {"success", "not_found"}:
            raise KnowledgeRagError(str(getattr(result, "message", "") or "LightRAG 删除文档失败。"))

    def _probe_embedding_dimension(self, config: RuntimeModelConfig) -> int:
        """Embed a fixed probe string and return the length of the vector.

        Raises:
            KnowledgeRagError: No usable vector came back.
        """
        vectors = self._request_embeddings(config, ["dimension probe"])
        if not vectors or not isinstance(vectors[0], list):
            raise KnowledgeRagError("无法从 embedding 模型返回结果中解析向量维度。")
        dimension = len(vectors[0])
        if dimension <= 0:
            raise KnowledgeRagError("embedding 模型返回了无效的向量维度。")
        return dimension

    def _embed_sync(self, texts: list[str]) -> Any:
        """Embed *texts* with the configured model and return a float numpy array."""
        import numpy as np

        vectors = self._request_embeddings(self.embedding, texts)
        return np.array(vectors, dtype=float)

    def _rerank_sync(
        self,
        query: str,
        documents: list[str],
        top_n: int | None,
    ) -> list[dict[str, Any]]:
        """Rerank *documents* for *query*; returns ``[]`` when no reranker is set.

        Raises:
            KnowledgeRagError: The rerank endpoint answered with status >= 400.
        """
        if self.reranker is None:
            return []

        status_code, body = self._request_rerank(
            self.reranker,
            query=query,
            documents=documents,
            top_n=top_n,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise KnowledgeRagError(f"reranker 模型返回异常状态码 {status_code}。")
        return _extract_rerank_results(body, provider=self.reranker.provider)

    def _complete_sync(
        self,
        prompt: str,
        system_prompt: str | None,
        history_messages: list[dict[str, Any]],
        keyword_extraction: bool,
        kwargs: dict[str, Any],
    ) -> str:
        """Request a chat completion, trying the primary model then the backup.

        ``kwargs`` may carry ``max_tokens`` (default 1200) and ``temperature``
        (default 0.1); every other key is ignored.

        Raises:
            KnowledgeRagError: Every configured chat model failed.
        """
        del keyword_extraction  # part of the callable signature; unused here

        requested_temperature = kwargs.get("temperature")
        last_error: Exception | None = None
        for config in (self.primary_chat, self.backup_chat):
            if config is None:
                continue
            try:
                return self._request_chat_completion(
                    config,
                    prompt=prompt,
                    system_prompt=system_prompt,
                    history_messages=history_messages,
                    max_tokens=int(kwargs.get("max_tokens") or 1200),
                    # An explicit 0 is a valid temperature; only a missing
                    # value falls back to the default.
                    temperature=(
                        0.1 if requested_temperature is None else float(requested_temperature)
                    ),
                )
            except Exception as exc:  # pragma: no cover - runtime fallback
                last_error = exc
                logger.warning(
                    "LightRAG LLM request failed slot=%s provider=%s model=%s: %s",
                    config.slot,
                    config.provider,
                    config.model,
                    exc,
                )

        raise KnowledgeRagError(f"LightRAG 调用知识模型失败:{last_error or '没有可用模型配置'}")

    def _request_chat_completion(
        self,
        config: RuntimeModelConfig,
        *,
        prompt: str,
        system_prompt: str | None,
        history_messages: list[dict[str, Any]],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """POST one chat request in the provider's wire format and return the text.

        Raises:
            KnowledgeRagError: Transport failure or a status code >= 400.
        """
        messages: list[dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.extend(history_messages)
        messages.append({"role": "user", "content": prompt})

        payload: dict[str, Any]
        if config.provider == "Azure OpenAI":
            # The deployment is part of the URL, so the payload carries no model.
            url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/chat/completions?api-version={AZURE_API_VERSION}"
            headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
            payload = {
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            }
        elif config.provider == "Ollama":
            url = _ensure_path(_normalize_endpoint(config.endpoint), "api/chat")
            # No credential header is sent to Ollama.
            headers = {"Content-Type": "application/json", "Accept": "application/json"}
            payload = {
                "model": config.model,
                "messages": messages,
                "stream": False,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature,
                },
            }
        else:
            url = _ensure_path(_normalize_endpoint(config.endpoint), "chat/completions")
            headers = _build_headers(config.api_key, use_bearer=True)
            payload = {
                "model": config.model,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            }

        status_code, body = _send_json_request(
            "POST",
            url,
            headers=headers,
            payload=payload,
            timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise KnowledgeRagError(f"知识模型返回异常状态码 {status_code}。")

        return _extract_chat_text(body, provider=config.provider)

    def _request_embeddings(self, config: RuntimeModelConfig, texts: list[str]) -> list[list[float]]:
        """POST *texts* to the provider's embedding endpoint and return the vectors.

        Raises:
            KnowledgeRagError: Transport failure, a status code >= 400, or an
                unrecognized response body.
        """
        payload: dict[str, Any]
        if config.provider == "Azure OpenAI":
            url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/embeddings?api-version={AZURE_API_VERSION}"
            headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
            payload = {"input": texts}
        elif config.provider == "Ollama":
            url = _ensure_path(_normalize_endpoint(config.endpoint), "api/embed")
            headers = {"Content-Type": "application/json", "Accept": "application/json"}
            payload = {"model": config.model, "input": texts}
        else:
            url = _ensure_path(_normalize_endpoint(config.endpoint), "embeddings")
            headers = _build_headers(config.api_key, use_bearer=True)
            payload = {"model": config.model, "input": texts}

        status_code, body = _send_json_request(
            "POST",
            url,
            headers=headers,
            payload=payload,
            timeout_seconds=DEFAULT_EMBEDDING_TIMEOUT_SECONDS,
        )
        if status_code >= HTTPStatus.BAD_REQUEST:
            raise KnowledgeRagError(f"embedding 模型返回异常状态码 {status_code}。")

        return _extract_embedding_vectors(body, provider=config.provider)

    def _request_rerank(
        self,
        config: RuntimeModelConfig,
        *,
        query: str,
        documents: list[str],
        top_n: int | None,
    ) -> tuple[int, Any]:
        """POST a rerank request and return ``(status_code, parsed_body)``.

        ``top_n`` is only sent when it is not ``None``.

        Raises:
            KnowledgeRagError: Transport failure reported by the HTTP helper.
        """
        payload: dict[str, Any]
        if config.provider == "Ali":
            # The Ali helper picks both the URL and the payload shape per model.
            url, payload = _build_ali_rerank_request(
                config.model,
                query=query,
                documents=documents,
                top_n=top_n,
            )
            headers = _build_headers(config.api_key, use_bearer=True)
        else:
            if config.provider == "Azure OpenAI":
                url = f"{_build_azure_deployment_base(config.endpoint, config.model)}/rerank?api-version={AZURE_API_VERSION}"
                headers = _build_headers(config.api_key, use_bearer=False, use_api_key=True)
                payload = {
                    "query": query,
                    "documents": documents,
                }
            else:
                url = _ensure_path(_normalize_endpoint(config.endpoint), "rerank")
                headers = _build_headers(config.api_key, use_bearer=True)
                payload = {
                    "model": config.model,
                    "query": query,
                    "documents": documents,
                }
            if top_n is not None:
                payload["top_n"] = top_n

        # Rerank calls share the LLM timeout.
        return _send_json_request(
            "POST",
            url,
            headers=headers,
            payload=payload,
            timeout_seconds=DEFAULT_LLM_TIMEOUT_SECONDS,
        )
|
||||
|
||||
def _normalize_endpoint(endpoint: str) -> str:
    """Trim whitespace and trailing slashes from *endpoint*.

    Raises:
        KnowledgeRagError: The endpoint is empty or whitespace only.
    """
    cleaned = str(endpoint or "").strip()
    if cleaned:
        return cleaned.rstrip("/")
    raise KnowledgeRagError("模型 endpoint 不能为空。")
|
||||
|
||||
|
||||
def _ensure_path(endpoint: str, suffix: str) -> str:
    """Return *endpoint* ending in the path *suffix*, appending it if needed.

    The suffix only counts as present when it starts on a path-segment
    boundary: ``https://h/qwen3-rerank`` does not already end in ``rerank``.

    Args:
        endpoint: Base URL, expected without a trailing slash.
        suffix: Path to guarantee; leading slashes are ignored.
    """
    suffix = suffix.lstrip("/")
    if not suffix:
        return endpoint
    if endpoint == suffix or endpoint.endswith(f"/{suffix}"):
        return endpoint
    return f"{endpoint}/{suffix}"
|
||||
|
||||
|
||||
def _build_azure_deployment_base(endpoint: str, model: str) -> str:
    """Return the ``.../openai/deployments/<model>`` base URL for an Azure endpoint.

    An endpoint that already contains ``/openai/deployments/`` is returned
    unchanged (after normalization).  Otherwise anything from ``/openai/v1``
    onwards is dropped and the URL-quoted model is appended as the deployment.

    Raises:
        KnowledgeRagError: The endpoint is empty.
    """
    base = _normalize_endpoint(endpoint)
    deployment = quote(model, safe="")

    if "/openai/deployments/" in base:
        return base

    resource_root, marker, _rest = base.partition("/openai/v1")
    if marker:
        return f"{resource_root}/openai/deployments/{deployment}"

    prefix = base if base.endswith("/openai") else f"{base}/openai"
    return f"{prefix}/deployments/{deployment}"
|
||||
|
||||
|
||||
def _build_headers(
    api_key: str,
    *,
    use_bearer: bool,
    use_api_key: bool = False,
) -> dict[str, str]:
    """Build JSON request headers, adding a credential when one is supplied.

    ``use_api_key`` wins over ``use_bearer``; a blank key adds no credential
    header at all.
    """
    token = str(api_key or "").strip()
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    if not token:
        return headers

    if use_api_key:
        headers["api-key"] = token
    elif use_bearer:
        headers["Authorization"] = f"Bearer {token}"
    return headers
|
||||
|
||||
|
||||
def _send_json_request(
    method: str,
    url: str,
    *,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_seconds: int,
) -> tuple[int, Any]:
    """Send *payload* as JSON and return ``(status_code, parsed_body)``.

    Args:
        method: HTTP method.
        url: Target URL.
        headers: Request headers.
        payload: Object serialized as the UTF-8 JSON request body.
        timeout_seconds: Timeout handed to ``urlopen``.

    Returns:
        The response status and the body parsed by ``_parse_json_body``.

    Raises:
        KnowledgeRagError: HTTP error status (message taken from the error
            body when available), connection failure, or timeout.
    """
    data = json.dumps(payload).encode("utf-8")
    request = Request(url=url, data=data, headers=headers, method=method)

    try:
        with urlopen(request, timeout=timeout_seconds) as response:  # noqa: S310
            # ``length`` is 0 only for a known-empty body; skip the read then.
            body = response.read().decode("utf-8") if response.length != 0 else ""
            return response.status, _parse_json_body(body)
    except HTTPError as exc:  # pragma: no cover - runtime path
        # Must precede URLError: HTTPError is a subclass of it.
        body = exc.read().decode("utf-8", errors="ignore")
        detail = _extract_error_message(_parse_json_body(body)) or f"接口返回 {exc.code}"
        raise KnowledgeRagError(detail) from exc
    except URLError as exc:  # pragma: no cover - runtime path
        reason = getattr(exc, "reason", exc)
        # A timeout while connecting arrives wrapped in URLError; report it
        # as a timeout, not as a generic connection failure.
        if isinstance(reason, TimeoutError):
            raise KnowledgeRagError("模型接口调用超时。") from exc
        raise KnowledgeRagError(f"无法连接模型接口:{reason}") from exc
    except TimeoutError as exc:  # pragma: no cover - runtime path
        raise KnowledgeRagError("模型接口调用超时。") from exc
|
||||
|
||||
|
||||
def _parse_json_body(body: str) -> Any:
    """Decode a response body.

    Returns ``None`` for an empty body, the decoded JSON value when the body
    parses, and ``{"message": body}`` when it does not.
    """
    if body:
        try:
            return json.loads(body)
        except json.JSONDecodeError:
            # Keep non-JSON text where the error-message lookup will find it.
            return {"message": body}
    return None
|
||||
|
||||
|
||||
def _extract_error_message(payload: Any) -> str | None:
    """Pull a human-readable message out of a parsed error body.

    For a dict the lookup order is ``detail``, ``message``,
    ``error.message``, then a plain-string ``error``.  A bare string payload
    is returned as is.  Anything else yields ``None``.
    """
    if isinstance(payload, str):
        return payload
    if not isinstance(payload, dict):
        return None

    for key in ("detail", "message"):
        value = payload.get(key)
        if isinstance(value, str):
            return value

    error_payload = payload.get("error")
    if isinstance(error_payload, dict) and isinstance(error_payload.get("message"), str):
        return error_payload["message"]
    # Some providers send the message directly as the "error" string.
    if isinstance(error_payload, str):
        return error_payload
    return None
|
||||
|
||||
|
||||
def _extract_chat_text(payload: Any, *, provider: str) -> str:
    """Return the stripped assistant text from a chat response, or ``""``.

    Ollama responses are read from ``message.content``.  Every other provider
    is read from ``choices[0]``: a string ``message.content``, else the
    ``text`` parts of a list-valued content joined by newlines, else the
    choice's own ``text`` field.
    """
    fields = payload if isinstance(payload, dict) else {}

    if provider == "Ollama":
        reply = fields.get("message")
        if not isinstance(reply, dict):
            return ""
        return str(reply.get("content") or "").strip()

    choices = fields.get("choices")
    if not (isinstance(choices, list) and choices and isinstance(choices[0], dict)):
        return ""
    choice = choices[0]

    reply = choice.get("message")
    content = reply.get("content") if isinstance(reply, dict) else None
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        fragments = (
            str(part.get("text") or "").strip()
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
        return "\n".join(fragment for fragment in fragments if fragment).strip()

    fallback = choice.get("text")
    return fallback.strip() if isinstance(fallback, str) else ""
|
||||
|
||||
|
||||
def _extract_embedding_vectors(payload: Any, *, provider: str) -> list[list[float]]:
    """Return the embedding vectors found in a provider response.

    Ollama responses are read from ``embeddings`` (list of vectors) or, failing
    that, ``embedding`` (a single vector).  Other providers are read from
    ``data[*].embedding``.  Entries that are not lists are skipped.

    Raises:
        KnowledgeRagError: The payload has no recognizable vectors.
    """
    if provider == "Ollama":
        fields = payload if isinstance(payload, dict) else {}
        batch = fields.get("embeddings")
        if isinstance(batch, list):
            return [[float(value) for value in row] for row in batch if isinstance(row, list)]
        single = fields.get("embedding")
        if isinstance(single, list):
            return [[float(value) for value in single]]
        raise KnowledgeRagError("Ollama embedding 返回格式无法识别。")

    if not isinstance(payload, dict):
        raise KnowledgeRagError("embedding 接口返回格式无效。")

    data = payload.get("data")
    if not (isinstance(data, list) and data):
        raise KnowledgeRagError("embedding 接口没有返回 data。")

    vectors = [
        [float(value) for value in entry["embedding"]]
        for entry in data
        if isinstance(entry, dict) and isinstance(entry.get("embedding"), list)
    ]
    if vectors:
        return vectors
    raise KnowledgeRagError("embedding 接口返回中未找到向量数据。")
|
||||
|
||||
|
||||
def _build_ali_rerank_request(
    model: str,
    *,
    query: str,
    documents: list[str],
    top_n: int | None,
) -> tuple[str, dict[str, Any]]:
    """Return the DashScope ``(url, payload)`` pair for a rerank call.

    ``qwen3-rerank`` uses the compatible-API URL with a flat payload; every
    other model uses the text-rerank service URL with nested ``input`` and
    ``parameters`` objects.  ``top_n`` is included only when not ``None``.
    """
    name = str(model or "").strip()
    limit: dict[str, Any] = {} if top_n is None else {"top_n": top_n}

    if name == "qwen3-rerank":
        flat_payload: dict[str, Any] = {
            "model": name,
            "query": query,
            "documents": documents,
            **limit,
        }
        return "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", flat_payload

    nested_payload: dict[str, Any] = {
        "model": name,
        "input": {
            "query": query,
            "documents": documents,
        },
        "parameters": {
            "return_documents": False,
            **limit,
        },
    }
    return (
        "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
        nested_payload,
    )
|
||||
|
||||
|
||||
def _extract_rerank_results(payload: Any, *, provider: str) -> list[dict[str, Any]]:
    """Normalize a rerank response to ``{"index", "relevance_score"}`` dicts.

    For provider ``"Ali"`` the results are taken from ``output.results`` when
    ``output`` is a dict; otherwise from the top-level ``results``.  Entries
    that are malformed or lack either field are skipped, and an unusable
    payload yields an empty list.
    """
    if not isinstance(payload, dict):
        return []

    container = payload
    if provider == "Ali":
        nested = payload.get("output")
        if isinstance(nested, dict):
            container = nested

    entries = container.get("results")
    if not isinstance(entries, list):
        return []

    ranked: list[dict[str, Any]] = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        try:
            position = int(entry["index"])
            score = float(entry["relevance_score"])
        except (KeyError, TypeError, ValueError):
            continue
        ranked.append({"index": position, "relevance_score": score})
    return ranked
|
||||
Reference in New Issue
Block a user