test(flywheel): 补 few-shot 飞轮单测并沉淀开发文档
- embedding_provider:GLM/Ollama 分支、维度缓存、HTTP 错误降级 - few_shot_ingestion:confirmed/false_positive 入库、ignored 跳过、幂等去重、 create_feedback hook 触发、feature flag、吞异常 - few_shot_retrieval:去重、token 预算、超长截断;prompt 注入合并 examples + 向后兼容 - 容器内新增测试 20 passed;回归测试 35 passed(RAG/risk_observations/rule_generation) - 沉淀 document/development/2026-07-03/feature/ai-data-flywheel 概念文档与 TODO, 飞轮 1 已勾选证据,飞轮 2-6 待后续迭代
This commit is contained in:
104
server/tests/test_embedding_provider.py
Normal file
104
server/tests/test_embedding_provider.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.embedding_provider import EmbeddingProvider, _runtime_model_config_from_dict
|
||||
from app.services.knowledge_rag_runtime import KnowledgeRagError, RuntimeModelConfig
|
||||
|
||||
|
||||
def _config(provider: str = "GLM") -> RuntimeModelConfig:
    """Build an embedding-slot ``RuntimeModelConfig`` for the given provider name."""
    settings = {
        "slot": "embedding",
        "provider": provider,
        "model": "Embedding-3",
        "endpoint": "https://open.bigmodel.cn/api/paas/v4/",
        "api_key": "k",
        "capability": "embedding",
    }
    return RuntimeModelConfig(**settings)
|
||||
|
||||
|
||||
def test_runtime_model_config_from_dict_maps_fields() -> None:
    """The camelCase ``apiKey`` entry must surface as ``api_key``; ``model`` is carried over."""
    raw = {
        "slot": "embedding",
        "provider": "GLM",
        "model": "Embedding-3",
        "endpoint": "https://e",
        "apiKey": "secret",
        "capability": "embedding",
    }
    parsed = _runtime_model_config_from_dict(raw)
    assert parsed.api_key == "secret"
    assert parsed.model == "Embedding-3"
|
||||
|
||||
|
||||
def test_embed_empty_texts_returns_empty() -> None:
    """Embedding an empty list of texts yields an empty list."""
    result = EmbeddingProvider(_config()).embed([])
    assert result == []
|
||||
|
||||
|
||||
def test_embed_returns_vectors_and_caches_dimension() -> None:
    """``embed`` returns the backend vectors; a repeated ``dimension`` call issues no new request."""
    embedder = EmbeddingProvider(_config())
    target = "app.services.embedding_provider._request_embeddings_public"
    with patch(target, return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) as backend:
        result = embedder.embed(["a", "b"])
        assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        assert embedder.dimension() == 3
        baseline_calls = backend.call_count

        # A second dimension() lookup must not trigger another backend request.
        assert embedder.dimension() == 3
        assert backend.call_count == baseline_calls
|
||||
|
||||
|
||||
def test_dimension_raises_on_invalid_vectors() -> None:
    """``dimension`` raises ``KnowledgeRagError`` when the backend returns no vectors."""
    embedder = EmbeddingProvider(_config())
    target = "app.services.embedding_provider._request_embeddings_public"
    with patch(target, return_value=[]), pytest.raises(KnowledgeRagError):
        embedder.dimension()
|
||||
|
||||
|
||||
def test_request_embeddings_public_glm_branch() -> None:
    """With a GLM config the request URL ends in ``/embeddings`` and ``data[].embedding`` is parsed."""
    from app.services.embedding_provider import _request_embeddings_public

    glm_config = _config("GLM")
    fake_response = (200, {"data": [{"embedding": [0.1, 0.2]}]})
    with patch(
        "app.services.embedding_provider._send_json_request",
        return_value=fake_response,
    ) as sender:
        result = _request_embeddings_public(glm_config, ["x"])

    assert result == [[0.1, 0.2]]
    # The URL is read from the second positional argument of the send call.
    assert sender.call_args.args[1].endswith("/embeddings")
|
||||
|
||||
|
||||
def test_request_embeddings_public_ollama_branch() -> None:
    """With an Ollama config the request URL ends in ``/api/embed`` and ``embeddings`` is parsed."""
    from app.services.embedding_provider import _request_embeddings_public

    ollama_config = _config("Ollama")
    fake_response = (200, {"embeddings": [[0.5, 0.6]]})
    with patch(
        "app.services.embedding_provider._send_json_request",
        return_value=fake_response,
    ) as sender:
        result = _request_embeddings_public(ollama_config, ["x"])

    assert result == [[0.5, 0.6]]
    # The URL is read from the second positional argument of the send call.
    assert sender.call_args.args[1].endswith("/api/embed")
|
||||
|
||||
|
||||
def test_request_embeddings_public_raises_on_http_error() -> None:
    """A 500 status from the transport is surfaced as ``KnowledgeRagError``."""
    from app.services.embedding_provider import _request_embeddings_public

    glm_config = _config("GLM")
    failing_response = (500, {"message": "boom"})
    target = "app.services.embedding_provider._send_json_request"
    with patch(target, return_value=failing_response), pytest.raises(KnowledgeRagError):
        _request_embeddings_public(glm_config, ["x"])
|
||||
214
server/tests/test_few_shot_ingestion.py
Normal file
214
server/tests/test_few_shot_ingestion.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.db.base import Base
|
||||
from app.models.employee import Employee
|
||||
from app.models.few_shot_sample import FewShotSample
|
||||
from app.models.financial_record import ExpenseClaim
|
||||
from app.models.risk_observation import RiskObservation
|
||||
from app.schemas.risk_observation import RiskObservationFeedbackCreate
|
||||
from app.services.few_shot_ingestion import FewShotIngestionService
|
||||
from app.services.risk_observations import RiskObservationService
|
||||
|
||||
|
||||
def _build_session() -> Session:
    """Return a session bound to a fresh in-memory SQLite database with all tables created."""
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    make_session = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return make_session()
|
||||
|
||||
|
||||
def _observation(db: Session, key: str = "risk:c1:dup") -> RiskObservation:
    """Seed an employee, an expense claim and a duplicate-invoice observation.

    The employee and claim are flushed first, then the observation is added,
    committed and refreshed before being returned.
    """
    employee = Employee(id="emp-1", employee_no="E1", name="员工", email="e@e.com", grade="P6")
    db.add(employee)

    claim = ExpenseClaim(
        id="c1",
        claim_no="BX-001",
        employee_id="emp-1",
        employee_name="员工",
        department_name="风控部",
        expense_type="travel",
        reason="客户拜访",
        location="上海",
        amount=Decimal("1000"),
        currency="CNY",
        occurred_at=datetime(2026, 1, 1),
        submitted_at=datetime(2026, 1, 1),
        status="submitted",
        approval_stage="manager_review",
        risk_flags_json=[],
    )
    db.add(claim)
    db.flush()

    observation = RiskObservation(
        observation_key=key,
        subject_type="expense_claim",
        subject_key="claim:c1",
        claim_id="c1",
        claim_no="BX-001",
        risk_type="duplicate_invoice",
        risk_signal="duplicate_invoice",
        title="重复发票",
        description="同一发票出现在多张报销单",
        risk_score=86,
        risk_level="high",
        confidence_score=0.8,
        source="financial_risk_graph",
        algorithm_version="v1",
        ontology_json={"domain": "expense", "scenario": "reimbursement"},
    )
    db.add(observation)
    db.commit()
    db.refresh(observation)
    return observation
|
||||
|
||||
|
||||
def test_ingest_confirmed_persists_sample_and_calls_store() -> None:
    """A confirmed observation produces a sample and exactly one store upsert."""
    with _build_session() as db:
        observation = _observation(db)
        observation.feedback_status = "confirmed"
        ingestion = FewShotIngestionService(db)

        vector_store = MagicMock()
        vector_store.upsert.return_value = "vec-1"
        feedback = MagicMock(feedback_type="confirm", comment="确认重复发票", actor="audit")

        with patch.object(ingestion, "_store", return_value=vector_store):
            sample = ingestion.ingest_observation_feedback(observation, feedback)

        assert sample is not None
        assert sample.label == "confirmed"
        assert sample.sample_key == f"obs:{observation.id}"
        assert "重复发票" in sample.case_text
        assert "确认重复发票" in sample.conclusion_text
        assert sample.vector_id == "vec-1"
        vector_store.upsert.assert_called_once()
|
||||
|
||||
|
||||
def test_ingest_false_positive_also_persisted() -> None:
    """A false-positive observation is ingested too, even when the store returns no vector id."""
    with _build_session() as db:
        observation = _observation(db, key="risk:c2:fp")
        observation.feedback_status = "false_positive"
        db.commit()

        ingestion = FewShotIngestionService(db)
        vector_store = MagicMock(upsert=MagicMock(return_value=None))
        feedback = MagicMock(feedback_type="false_positive", comment="", actor="audit")

        with patch.object(ingestion, "_store", return_value=vector_store):
            sample = ingestion.ingest_observation_feedback(observation, feedback)

        assert sample is not None
        assert sample.label == "false_positive"
        assert "误报" in sample.conclusion_text
|
||||
|
||||
|
||||
def test_ingest_ignored_label_returns_none() -> None:
    """An observation whose feedback status is ``ignored`` is not ingested."""
    with _build_session() as db:
        observation = _observation(db)
        observation.feedback_status = "ignored"
        ingestion = FewShotIngestionService(db)
        outcome = ingestion.ingest_observation_feedback(observation, MagicMock())
        assert outcome is None
|
||||
|
||||
|
||||
def test_ingest_is_idempotent_on_duplicate_sample_key() -> None:
    """Re-ingesting the same observation updates its sample row rather than adding one.

    The observation is ingested as ``confirmed`` and then again as
    ``false_positive``. Both calls must return the same row, exactly one row
    may exist for the observation's sample key, and the row must carry the
    latest label.
    """
    from sqlalchemy import func, select

    with _build_session() as db:
        obs = _observation(db)
        service = FewShotIngestionService(db)
        store = MagicMock()
        store.upsert.side_effect = ["vec-1", "vec-2"]

        with patch.object(service, "_store", return_value=store):
            obs.feedback_status = "confirmed"
            first = service.ingest_observation_feedback(
                obs, MagicMock(feedback_type="confirm", comment="第一次", actor="a")
            )
            # Simulate the verdict later being revised to a false positive.
            obs.feedback_status = "false_positive"
            second = service.ingest_observation_feedback(
                obs, MagicMock(feedback_type="false_positive", comment="改判", actor="a")
            )

        assert first is not None and second is not None
        assert first.id == second.id  # the same row was updated in place

        # Count rows for this sample key: an existence check alone would still
        # pass if a duplicate row had been inserted.
        row_count = db.scalar(
            select(func.count())
            .select_from(FewShotSample)
            .where(FewShotSample.sample_key == f"obs:{obs.id}")
        )
        assert row_count == 1
        assert second.label == "false_positive"
|
||||
|
||||
|
||||
def test_create_feedback_hook_triggers_ingestion() -> None:
    """Creating a ``confirm`` feedback invokes the ingestion hook exactly once."""
    with _build_session() as db:
        observations = RiskObservationService(db)
        observation = _observation(db)
        recorded: list = []

        def _record_call(seen_observation, seen_feedback):
            # Capture only the identifying fields so the assertions stay simple.
            recorded.append((seen_observation.id, seen_feedback.feedback_type))
            return None

        hook = "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
        request = RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit")
        with patch(hook, side_effect=_record_call):
            observations.create_feedback(observation.observation_key, request)

        assert len(recorded) == 1
        assert recorded[0][1] == "confirm"
|
||||
|
||||
|
||||
def test_create_feedback_hook_skipped_for_comment_feedback() -> None:
    """A plain ``comment`` feedback must not invoke the ingestion hook."""
    hook = "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    with _build_session() as db:
        observations = RiskObservationService(db)
        observation = _observation(db)
        request = RiskObservationFeedbackCreate(
            feedback_type="comment", action="note", actor="audit"
        )
        with patch(hook) as ingest_spy:
            observations.create_feedback(observation.observation_key, request)
        ingest_spy.assert_not_called()
|
||||
|
||||
|
||||
def test_create_feedback_hook_swallows_ingestion_failure() -> None:
    """An exception raised by the ingestion hook must not escape ``create_feedback``."""
    hook = "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    with _build_session() as db:
        observations = RiskObservationService(db)
        observation = _observation(db)
        request = RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit")
        with patch(hook, side_effect=RuntimeError("boom")):
            # The call is expected to complete without raising.
            created = observations.create_feedback(observation.observation_key, request)
        assert created.feedback_type == "confirm"
|
||||
|
||||
|
||||
def test_create_feedback_hook_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``FEW_SHOT_INJECTION_ENABLED=false`` the ingestion hook is never invoked."""
    monkeypatch.setenv("FEW_SHOT_INJECTION_ENABLED", "false")
    hook = "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    with _build_session() as db:
        observations = RiskObservationService(db)
        observation = _observation(db)
        request = RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit")
        with patch(hook) as ingest_spy:
            observations.create_feedback(observation.observation_key, request)
        ingest_spy.assert_not_called()
|
||||
119
server/tests/test_few_shot_retrieval_and_prompt.py
Normal file
119
server/tests/test_few_shot_retrieval_and_prompt.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.few_shot_retrieval import FewShotRetriever
|
||||
from app.services.few_shot_store import FewShotStore
|
||||
from app.services.risk_rule_generation_prompt import build_risk_rule_compiler_messages
|
||||
|
||||
|
||||
def _hit(score: float, label: str, conclusion: str, risk_type: str = "duplicate_invoice") -> dict:
|
||||
return {
|
||||
"sample_id": "s1",
|
||||
"score": score,
|
||||
"label": label,
|
||||
"domain": "expense",
|
||||
"risk_type": risk_type,
|
||||
"conclusion_text": conclusion,
|
||||
"payload_json": {
|
||||
"risk_signal": risk_type,
|
||||
"risk_level": "high",
|
||||
"ontology": {"scenario": "reimbursement"},
|
||||
"feedback_comment": "",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_retrieve_returns_injection_blocks_with_token_budget() -> None:
    """Three hits, two sharing a conclusion, collapse to two blocks with the best hit first."""
    fake_store = MagicMock(spec=FewShotStore)
    fake_store.search.return_value = [
        _hit(0.9, "confirmed", "确认重复发票需拦截"),
        _hit(0.8, "false_positive", "此情形属于正常拆单不拦截"),
        # Same conclusion text as the first hit; expected to be dropped.
        _hit(0.7, "confirmed", "确认重复发票需拦截"),
    ]

    blocks = FewShotRetriever(fake_store).retrieve_for_risk_rule_generation(
        domain="expense", natural_language="同一发票重复报销"
    )

    assert len(blocks) == 2
    top, runner_up = blocks[0], blocks[1]
    assert top["score"] == 0.9
    assert top["label"] == "confirmed"
    assert top["source"] == "historical_confirmed"
    assert runner_up["label"] == "false_positive"

    # Deduplication: no conclusion text may appear twice in the result.
    seen_conclusions = [block["conclusion"] for block in blocks]
    assert len(set(seen_conclusions)) == len(seen_conclusions)
|
||||
|
||||
|
||||
def test_retrieve_empty_case_text_returns_empty() -> None:
    """An empty natural-language query returns no blocks and never calls ``store.search``."""
    fake_store = MagicMock(spec=FewShotStore)
    blocks = FewShotRetriever(fake_store).retrieve_for_risk_rule_generation(natural_language="")
    assert blocks == []
    fake_store.search.assert_not_called()
|
||||
|
||||
|
||||
def test_retrieve_truncates_overlong_conclusion() -> None:
    """A very long conclusion is cut down to at most ``SINGLE_SAMPLE_MAX_CHARS`` characters."""
    from app.services.few_shot_retrieval import SINGLE_SAMPLE_MAX_CHARS

    oversized = "长结论" * 500
    fake_store = MagicMock(spec=FewShotStore)
    fake_store.search.return_value = [_hit(0.9, "confirmed", oversized)]

    blocks = FewShotRetriever(fake_store).retrieve_for_risk_rule_generation(natural_language="x")

    assert len(blocks) == 1
    # The single returned conclusion must respect the per-sample length cap.
    assert len(blocks[0]["conclusion"]) <= SINGLE_SAMPLE_MAX_CHARS
|
||||
|
||||
|
||||
def test_build_prompt_merges_few_shot_into_examples() -> None:
    """A supplied few-shot sample appears first in ``examples``, alongside built-in examples."""
    historical_sample = {
        "source": "historical_confirmed",
        "label": "confirmed",
        "domain": "expense",
        "risk_type": "duplicate_invoice",
        "conclusion": "确认重复发票",
        "context": {"risk_signal": "duplicate_invoice"},
    }
    invoice_field = {
        "key": "attachment.invoice_no",
        "label": "发票号",
        "type": "string",
        "source": "attachment",
    }

    messages = build_risk_rule_compiler_messages(
        domain="expense",
        domain_label="报销",
        business_stage="reimbursement",
        business_stage_label="报销",
        expense_category=None,
        expense_category_label="",
        natural_language="重复发票规则",
        available_fields=[invoice_field],
        few_shot_samples=[historical_sample],
    )

    assert len(messages) == 2
    examples = json.loads(messages[1]["content"])["examples"]

    # The single historical sample is expected at the front of the list.
    lead = examples[0]
    assert lead["source"] == "historical_confirmed"
    assert lead["conclusion"] == "确认重复发票"

    # Built-in examples (recognised here by their "user_rule" key) must still be present.
    assert any("user_rule" in example for example in examples)
|
||||
|
||||
|
||||
def test_build_prompt_without_few_shot_is_backward_compatible() -> None:
    """Without ``few_shot_samples`` no example carries the ``historical_confirmed`` source."""
    messages = build_risk_rule_compiler_messages(
        domain="expense",
        domain_label="报销",
        business_stage="reimbursement",
        business_stage_label="报销",
        expense_category=None,
        expense_category_label="",
        natural_language="重复发票规则",
        available_fields=[],
    )

    examples = json.loads(messages[1]["content"])["examples"]

    # No entry may be tagged as a historical sample when none were passed in.
    sources = [example.get("source") for example in examples]
    assert "historical_confirmed" not in sources
|
||||
Reference in New Issue
Block a user