Files
X-Financial/server/tests/test_embedding_provider.py

105 lines
3.4 KiB
Python
Raw Normal View History

from __future__ import annotations
from unittest.mock import patch
import pytest
from app.services.embedding_provider import EmbeddingProvider, _runtime_model_config_from_dict
from app.services.knowledge_rag_runtime import KnowledgeRagError, RuntimeModelConfig
def _config(provider: str = "GLM") -> RuntimeModelConfig:
    """Build the embedding-slot ``RuntimeModelConfig`` shared by these tests.

    Args:
        provider: Provider name stored on the config; every other field is a
            fixed placeholder value.

    Returns:
        A ``RuntimeModelConfig`` built from the placeholder fields below.
    """
    fields = {
        "slot": "embedding",
        "provider": provider,
        "model": "Embedding-3",
        "endpoint": "https://open.bigmodel.cn/api/paas/v4/",
        "api_key": "k",
        "capability": "embedding",
    }
    return RuntimeModelConfig(**fields)
def test_runtime_model_config_from_dict_maps_fields() -> None:
    """The dict helper exposes ``apiKey`` as ``api_key`` and carries ``model`` over."""
    payload = {
        "slot": "embedding",
        "provider": "GLM",
        "model": "Embedding-3",
        "endpoint": "https://e",
        "apiKey": "secret",
        "capability": "embedding",
    }

    parsed = _runtime_model_config_from_dict(payload)

    # The camelCase input key must surface as the snake_case attribute.
    assert parsed.api_key == "secret"
    assert parsed.model == "Embedding-3"
def test_embed_empty_texts_returns_empty() -> None:
    """Embedding an empty list of texts yields an empty list."""
    result = EmbeddingProvider(_config()).embed([])
    assert result == []
def test_embed_returns_vectors_and_caches_dimension() -> None:
    """``embed`` returns the stubbed vectors and ``dimension`` does not re-request.

    The request helper is patched to return two 3-element vectors. After the
    first ``dimension()`` call, a second call must report the same value
    without increasing the mock's call count.
    """
    target = "app.services.embedding_provider._request_embeddings_public"
    provider = EmbeddingProvider(_config())

    with patch(
        target,
        return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    ) as request_mock:
        embedded = provider.embed(["a", "b"])
        assert embedded == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        assert provider.dimension() == 3
        baseline_calls = request_mock.call_count

        # A second dimension() call must not trigger another request.
        assert provider.dimension() == 3
        assert request_mock.call_count == baseline_calls
def test_dimension_raises_on_invalid_vectors() -> None:
    """``dimension`` raises ``KnowledgeRagError`` when the request helper returns ``[]``."""
    provider = EmbeddingProvider(_config())
    target = "app.services.embedding_provider._request_embeddings_public"

    # Patch and expectation share one ``with`` statement; the patch is entered first.
    with patch(target, return_value=[]), pytest.raises(KnowledgeRagError):
        provider.dimension()
def test_request_embeddings_public_glm_branch() -> None:
    """With provider ``"GLM"``, vectors come from ``data[*].embedding``.

    ``_send_json_request`` is stubbed with a ``(200, {...})`` response; the
    second positional argument it was called with must end in ``/embeddings``.
    """
    from app.services.embedding_provider import _request_embeddings_public

    cfg = _config("GLM")
    stub_response = (200, {"data": [{"embedding": [0.1, 0.2]}]})

    with patch(
        "app.services.embedding_provider._send_json_request",
        return_value=stub_response,
    ) as send_mock:
        result = _request_embeddings_public(cfg, ["x"])

    # The mock keeps its recorded call after the patch is undone.
    assert result == [[0.1, 0.2]]
    requested_url = send_mock.call_args.args[1]
    assert requested_url.endswith("/embeddings")
def test_request_embeddings_public_ollama_branch() -> None:
    """With provider ``"Ollama"``, vectors come from the ``embeddings`` key.

    ``_send_json_request`` is stubbed with a ``(200, {...})`` response; the
    second positional argument it was called with must end in ``/api/embed``.
    """
    from app.services.embedding_provider import _request_embeddings_public

    cfg = _config("Ollama")
    stub_response = (200, {"embeddings": [[0.5, 0.6]]})

    with patch(
        "app.services.embedding_provider._send_json_request",
        return_value=stub_response,
    ) as send_mock:
        result = _request_embeddings_public(cfg, ["x"])

    # The mock keeps its recorded call after the patch is undone.
    assert result == [[0.5, 0.6]]
    requested_url = send_mock.call_args.args[1]
    assert requested_url.endswith("/api/embed")
def test_request_embeddings_public_raises_on_http_error() -> None:
    """A stubbed ``(500, {"message": "boom"})`` response raises ``KnowledgeRagError``."""
    from app.services.embedding_provider import _request_embeddings_public

    cfg = _config("GLM")
    failing_response = (500, {"message": "boom"})

    # Patch and expectation share one ``with`` statement; the patch is entered first.
    with patch(
        "app.services.embedding_provider._send_json_request",
        return_value=failing_response,
    ), pytest.raises(KnowledgeRagError):
        _request_embeddings_public(cfg, ["x"])