Files
X-Financial/server/tests/test_embedding_provider.py
caoxiaozhu 52d57c3be7 test(flywheel): 补 few-shot 飞轮单测并沉淀开发文档
- embedding_provider:GLM/Ollama 分支、维度缓存、HTTP 错误降级
- few_shot_ingestion:confirmed/false_positive 入库、ignored 跳过、幂等去重、
  create_feedback hook 触发、feature flag、吞异常
- few_shot_retrieval:去重、token 预算、超长截断;prompt 注入合并 examples + 向后兼容
- 容器内新增测试 20 passed;回归测试 35 passed(RAG/risk_observations/rule_generation)
- 沉淀 document/development/2026-07-03/feature/ai-data-flywheel 概念文档与 TODO,
  飞轮 1 已勾选证据,飞轮 2-6 待后续迭代
2026-07-03 13:56:21 +08:00

105 lines
3.4 KiB
Python

from __future__ import annotations
from unittest.mock import patch
import pytest
from app.services.embedding_provider import EmbeddingProvider, _runtime_model_config_from_dict
from app.services.knowledge_rag_runtime import KnowledgeRagError, RuntimeModelConfig
def _config(provider: str = "GLM") -> RuntimeModelConfig:
    """Build the embedding-slot ``RuntimeModelConfig`` shared by these tests.

    Args:
        provider: Provider name stored on the config; defaults to ``"GLM"``.

    Returns:
        A config whose remaining fields are fixed test values.
    """
    settings = {
        "slot": "embedding",
        "provider": provider,
        "model": "Embedding-3",
        "endpoint": "https://open.bigmodel.cn/api/paas/v4/",
        "api_key": "k",
        "capability": "embedding",
    }
    return RuntimeModelConfig(**settings)
def test_runtime_model_config_from_dict_maps_fields() -> None:
    """Check that the dict payload's ``apiKey`` and ``model`` land on the config."""
    payload = {
        "slot": "embedding",
        "provider": "GLM",
        "model": "Embedding-3",
        "endpoint": "https://e",
        "apiKey": "secret",
        "capability": "embedding",
    }

    built = _runtime_model_config_from_dict(payload)

    # The camelCase "apiKey" key is expected to surface as the api_key attribute.
    assert built.api_key == "secret"
    assert built.model == "Embedding-3"
def test_embed_empty_texts_returns_empty() -> None:
    """Check that embedding an empty list of texts yields an empty list."""
    embedder = EmbeddingProvider(_config())
    result = embedder.embed([])
    assert result == []
def test_embed_returns_vectors_and_caches_dimension() -> None:
    """Check ``embed`` output and that a repeated ``dimension`` call adds no request.

    The embeddings request function is patched to return two 3-element
    vectors; the test then verifies the returned vectors, the reported
    dimension, and that the mock's call count is unchanged by a second
    ``dimension()`` call.
    """
    embedder = EmbeddingProvider(_config())
    target = "app.services.embedding_provider._request_embeddings_public"
    stubbed_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

    with patch(target, return_value=stubbed_vectors) as request_mock:
        embedded = embedder.embed(["a", "b"])
        # Compare against a fresh literal rather than the stub object itself.
        assert embedded == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        first_dimension = embedder.dimension()
        assert first_dimension == 3
        baseline_calls = request_mock.call_count

        # A second dimension() lookup must not trigger another request.
        second_dimension = embedder.dimension()
        assert second_dimension == 3
        assert request_mock.call_count == baseline_calls
def test_dimension_raises_on_invalid_vectors() -> None:
    """Check ``dimension`` raises ``KnowledgeRagError`` when no vectors come back."""
    embedder = EmbeddingProvider(_config())
    target = "app.services.embedding_provider._request_embeddings_public"

    # The patched request returns an empty list, so there is no vector to measure.
    with patch(target, return_value=[]), pytest.raises(KnowledgeRagError):
        embedder.dimension()
def test_request_embeddings_public_glm_branch() -> None:
    """Check the GLM path of ``_request_embeddings_public``.

    With the JSON sender patched to return ``{"data": [{"embedding": ...}]}``,
    the function should return those vectors, and the second positional
    argument passed to the sender (the URL) should end with ``/embeddings``.
    """
    from app.services.embedding_provider import _request_embeddings_public

    glm_config = _config("GLM")
    fake_response = (200, {"data": [{"embedding": [0.1, 0.2]}]})
    target = "app.services.embedding_provider._send_json_request"

    with patch(target, return_value=fake_response) as send_mock:
        result = _request_embeddings_public(glm_config, ["x"])

    assert result == [[0.1, 0.2]]
    positional_args = send_mock.call_args.args
    requested_url = positional_args[1]
    assert requested_url.endswith("/embeddings")
def test_request_embeddings_public_ollama_branch() -> None:
    """Check the Ollama path of ``_request_embeddings_public``.

    With the JSON sender patched to return ``{"embeddings": [...]}``, the
    function should return those vectors, and the second positional argument
    passed to the sender (the URL) should end with ``/api/embed``.
    """
    from app.services.embedding_provider import _request_embeddings_public

    ollama_config = _config("Ollama")
    fake_response = (200, {"embeddings": [[0.5, 0.6]]})
    target = "app.services.embedding_provider._send_json_request"

    with patch(target, return_value=fake_response) as send_mock:
        result = _request_embeddings_public(ollama_config, ["x"])

    assert result == [[0.5, 0.6]]
    positional_args = send_mock.call_args.args
    requested_url = positional_args[1]
    assert requested_url.endswith("/api/embed")
def test_request_embeddings_public_raises_on_http_error() -> None:
    """Check a 500 response from the JSON sender raises ``KnowledgeRagError``."""
    from app.services.embedding_provider import _request_embeddings_public

    glm_config = _config("GLM")
    failure_response = (500, {"message": "boom"})
    target = "app.services.embedding_provider._send_json_request"

    with patch(target, return_value=failure_response), pytest.raises(KnowledgeRagError):
        _request_embeddings_public(glm_config, ["x"])