"""Unit tests for ``app.services.embedding_provider``.

Outbound request helpers are patched out: the provider-level tests mock
``_request_embeddings_public`` and the request-level tests mock
``_send_json_request``, so responses come from canned fixtures rather than
a real embedding endpoint.
"""

from __future__ import annotations

from unittest.mock import patch

import pytest

from app.services.embedding_provider import (
    EmbeddingProvider,
    _request_embeddings_public,
    _runtime_model_config_from_dict,
)
from app.services.knowledge_rag_runtime import KnowledgeRagError, RuntimeModelConfig

# Dotted paths of the module attributes patched below; kept in one place so a
# rename inside the provider module needs only one update here.
_REQUEST_EMBEDDINGS = "app.services.embedding_provider._request_embeddings_public"
_SEND_JSON_REQUEST = "app.services.embedding_provider._send_json_request"


def _config(provider: str = "GLM") -> RuntimeModelConfig:
    """Build an embedding-slot ``RuntimeModelConfig`` for *provider*.

    Only ``provider`` varies between tests; every other field is a fixed
    placeholder (the API key is a dummy value).
    """
    return RuntimeModelConfig(
        slot="embedding",
        provider=provider,
        model="Embedding-3",
        endpoint="https://open.bigmodel.cn/api/paas/v4/",
        api_key="k",
        capability="embedding",
    )


def test_runtime_model_config_from_dict_maps_fields() -> None:
    """Every key of the raw dict (including camelCase ``apiKey``) is mapped."""
    cfg = _runtime_model_config_from_dict(
        {
            "slot": "embedding",
            "provider": "GLM",
            "model": "Embedding-3",
            "endpoint": "https://e",
            "apiKey": "secret",
            "capability": "embedding",
        }
    )
    assert cfg.slot == "embedding"
    assert cfg.provider == "GLM"
    assert cfg.model == "Embedding-3"
    assert cfg.endpoint == "https://e"
    assert cfg.api_key == "secret"
    assert cfg.capability == "embedding"


def test_embed_empty_texts_returns_empty() -> None:
    """Embedding no texts yields ``[]`` without sending any HTTP request."""
    provider = EmbeddingProvider(_config())
    # Patch the low-level sender so that a regression which stops
    # short-circuiting empty input fails here instead of hitting the network.
    with patch(_SEND_JSON_REQUEST) as mock_send:
        assert provider.embed([]) == []
    mock_send.assert_not_called()


def test_embed_returns_vectors_and_caches_dimension() -> None:
    """``embed`` returns the backend vectors and ``dimension`` is not re-requested."""
    provider = EmbeddingProvider(_config())
    with patch(
        _REQUEST_EMBEDDINGS,
        return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    ) as mock_req:
        vectors = provider.embed(["a", "b"])
        assert vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        # dimension() is checked while the patch is still active so that,
        # even if it is not derived from the embed() call, it never reaches
        # the real backend.
        assert provider.dimension() == 3
        calls_after_first_dimension = mock_req.call_count
        # The second dimension() call must not issue another request.
        assert provider.dimension() == 3
        assert mock_req.call_count == calls_after_first_dimension


def test_dimension_raises_on_invalid_vectors() -> None:
    """An empty backend response makes ``dimension`` raise ``KnowledgeRagError``."""
    provider = EmbeddingProvider(_config())
    with patch(_REQUEST_EMBEDDINGS, return_value=[]):
        with pytest.raises(KnowledgeRagError):
            provider.dimension()


def test_request_embeddings_public_glm_branch() -> None:
    """GLM: vectors come from ``data[*].embedding``; URL ends in ``/embeddings``."""
    with patch(
        _SEND_JSON_REQUEST,
        return_value=(200, {"data": [{"embedding": [0.1, 0.2]}]}),
    ) as mock_send:
        vectors = _request_embeddings_public(_config("GLM"), ["x"])
    assert vectors == [[0.1, 0.2]]
    # Assumes the request URL is the second positional argument.
    called_url = mock_send.call_args.args[1]
    assert called_url.endswith("/embeddings")


def test_request_embeddings_public_ollama_branch() -> None:
    """Ollama: vectors come from ``embeddings``; URL ends in ``/api/embed``."""
    with patch(
        _SEND_JSON_REQUEST,
        return_value=(200, {"embeddings": [[0.5, 0.6]]}),
    ) as mock_send:
        vectors = _request_embeddings_public(_config("Ollama"), ["x"])
    assert vectors == [[0.5, 0.6]]
    # Assumes the request URL is the second positional argument.
    called_url = mock_send.call_args.args[1]
    assert called_url.endswith("/api/embed")


def test_request_embeddings_public_raises_on_http_error() -> None:
    """A non-2xx status from the sender surfaces as ``KnowledgeRagError``."""
    with patch(
        _SEND_JSON_REQUEST,
        return_value=(500, {"message": "boom"}),
    ):
        with pytest.raises(KnowledgeRagError):
            _request_embeddings_public(_config("GLM"), ["x"])