105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from unittest.mock import patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from app.services.embedding_provider import EmbeddingProvider, _runtime_model_config_from_dict
|
||
|
|
from app.services.knowledge_rag_runtime import KnowledgeRagError, RuntimeModelConfig
|
||
|
|
|
||
|
|
|
||
|
|
def _config(provider: str = "GLM") -> RuntimeModelConfig:
    """Build an embedding-slot ``RuntimeModelConfig`` for tests.

    Only ``provider`` varies between tests; every other field is a fixed
    dummy value (the API key is a placeholder, not a real credential).
    """
    fields = {
        "slot": "embedding",
        "provider": provider,
        "model": "Embedding-3",
        "endpoint": "https://open.bigmodel.cn/api/paas/v4/",
        "api_key": "k",
        "capability": "embedding",
    }
    return RuntimeModelConfig(**fields)
|
||
|
|
|
||
|
|
|
||
|
|
def test_runtime_model_config_from_dict_maps_fields() -> None:
    """``_runtime_model_config_from_dict`` maps every payload key onto the config.

    The payload uses the camelCase key ``apiKey``, which must land on the
    ``api_key`` attribute. Every field is asserted so that a dropped or
    mis-mapped key is caught.
    """
    cfg = _runtime_model_config_from_dict(
        {
            "slot": "embedding",
            "provider": "GLM",
            "model": "Embedding-3",
            "endpoint": "https://e",
            "apiKey": "secret",
            "capability": "embedding",
        }
    )
    assert cfg.slot == "embedding"
    assert cfg.provider == "GLM"
    assert cfg.model == "Embedding-3"
    assert cfg.endpoint == "https://e"
    assert cfg.api_key == "secret"
    assert cfg.capability == "embedding"
|
||
|
|
|
||
|
|
|
||
|
|
def test_embed_empty_texts_returns_empty() -> None:
    """Embedding an empty batch returns ``[]`` without touching the transport.

    ``_send_json_request`` is patched so that a regression which stops
    short-circuiting on empty input fails this test instead of reaching a
    real endpoint.
    """
    provider = EmbeddingProvider(_config())
    with patch("app.services.embedding_provider._send_json_request") as mock_send:
        assert provider.embed([]) == []
    mock_send.assert_not_called()
|
||
|
|
|
||
|
|
|
||
|
|
def test_embed_returns_vectors_and_caches_dimension() -> None:
    """``embed`` passes vectors through and ``dimension`` is memoised.

    Once ``dimension()`` has been answered, asking again must be served
    from the cached value and must not trigger another embeddings request.
    """
    provider = EmbeddingProvider(_config())
    target = "app.services.embedding_provider._request_embeddings_public"
    with patch(target, return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) as mock_req:
        result = provider.embed(["a", "b"])
        assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        assert provider.dimension() == 3
        baseline_calls = mock_req.call_count

        # The second dimension() call must not issue another request.
        assert provider.dimension() == 3
        assert mock_req.call_count == baseline_calls
|
||
|
|
|
||
|
|
|
||
|
|
def test_dimension_raises_on_invalid_vectors() -> None:
    """``dimension`` raises ``KnowledgeRagError`` when the request yields no vectors."""
    provider = EmbeddingProvider(_config())
    target = "app.services.embedding_provider._request_embeddings_public"
    with patch(target, return_value=[]), pytest.raises(KnowledgeRagError):
        provider.dimension()
|
||
|
|
|
||
|
|
|
||
|
|
def test_request_embeddings_public_glm_branch() -> None:
    """GLM configs target an ``/embeddings`` URL and parse ``data[*].embedding``."""
    from app.services.embedding_provider import _request_embeddings_public

    glm_response = {"data": [{"embedding": [0.1, 0.2]}]}
    with patch(
        "app.services.embedding_provider._send_json_request",
        return_value=(200, glm_response),
    ) as mock_send:
        result = _request_embeddings_public(_config("GLM"), ["x"])

    assert result == [[0.1, 0.2]]
    # The second positional argument passed to the transport is the URL.
    request_url = mock_send.call_args.args[1]
    assert request_url.endswith("/embeddings")
|
||
|
|
|
||
|
|
|
||
|
|
def test_request_embeddings_public_ollama_branch() -> None:
    """Ollama configs target an ``/api/embed`` URL and parse the ``embeddings`` list."""
    from app.services.embedding_provider import _request_embeddings_public

    ollama_response = {"embeddings": [[0.5, 0.6]]}
    with patch(
        "app.services.embedding_provider._send_json_request",
        return_value=(200, ollama_response),
    ) as mock_send:
        result = _request_embeddings_public(_config("Ollama"), ["x"])

    assert result == [[0.5, 0.6]]
    # The second positional argument passed to the transport is the URL.
    request_url = mock_send.call_args.args[1]
    assert request_url.endswith("/api/embed")
|
||
|
|
|
||
|
|
|
||
|
|
def test_request_embeddings_public_raises_on_http_error() -> None:
    """A non-2xx transport status surfaces as ``KnowledgeRagError``."""
    from app.services.embedding_provider import _request_embeddings_public

    cfg = _config("GLM")
    failing_transport = patch(
        "app.services.embedding_provider._send_json_request",
        return_value=(500, {"message": "boom"}),
    )
    with failing_transport, pytest.raises(KnowledgeRagError):
        _request_embeddings_public(cfg, ["x"])
|