Files
X-Financial/server/tests/test_few_shot_ingestion.py
caoxiaozhu 52d57c3be7 test(flywheel): 补 few-shot 飞轮单测并沉淀开发文档
- embedding_provider:GLM/Ollama 分支、维度缓存、HTTP 错误降级
- few_shot_ingestion:confirmed/false_positive 入库、ignored 跳过、幂等去重、
  create_feedback hook 触发、feature flag、吞异常
- few_shot_retrieval:去重、token 预算、超长截断;prompt 注入合并 examples + 向后兼容
- 容器内新增测试 20 passed;回归测试 35 passed(RAG/risk_observations/rule_generation)
- 沉淀 document/development/2026-07-03/feature/ai-data-flywheel 概念文档与 TODO,
  飞轮 1 已勾选证据,飞轮 2-6 待后续迭代
2026-07-03 13:56:21 +08:00

215 lines
7.9 KiB
Python

from __future__ import annotations
from collections.abc import Generator
from datetime import datetime
from decimal import Decimal
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.db.base import Base
from app.models.employee import Employee
from app.models.few_shot_sample import FewShotSample
from app.models.financial_record import ExpenseClaim
from app.models.risk_observation import RiskObservation
from app.schemas.risk_observation import RiskObservationFeedbackCreate
from app.services.few_shot_ingestion import FewShotIngestionService
from app.services.risk_observations import RiskObservationService
def _build_session() -> Session:
    """Return a new ORM session bound to a private in-memory SQLite database.

    ``StaticPool`` hands out a single shared connection and
    ``check_same_thread=False`` lets SQLite accept it from any thread, so the
    in-memory database survives for as long as the engine does. Every table
    registered on ``Base.metadata`` is created before the session is returned.
    """
    in_memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=in_memory_engine)
    make_session = sessionmaker(bind=in_memory_engine, autocommit=False, autoflush=False)
    return make_session()
def _observation(db: Session, key: str = "risk:c1:dup") -> RiskObservation:
    """Seed one employee, one expense claim and one risk observation.

    Args:
        db: Session to write into. The observation is committed and refreshed.
        key: Value stored as ``RiskObservation.observation_key``.

    Returns:
        The committed, refreshed ``RiskObservation``. It is a ``high`` level
        ``duplicate_invoice`` signal attached to claim ``c1`` / ``BX-001``.

    NOTE(review): the employee id ``emp-1`` and the claim id ``c1`` are
    hard-coded. A second call on the same session would presumably collide on
    those ids, so call this at most once per session.
    """
    seeded_at = datetime(2026, 1, 1)
    employee = Employee(id="emp-1", employee_no="E1", name="员工", email="e@e.com", grade="P6")
    claim = ExpenseClaim(
        id="c1",
        claim_no="BX-001",
        employee_id="emp-1",
        employee_name="员工",
        department_name="风控部",
        expense_type="travel",
        reason="客户拜访",
        location="上海",
        amount=Decimal("1000"),
        currency="CNY",
        occurred_at=seeded_at,
        submitted_at=seeded_at,
        status="submitted",
        approval_stage="manager_review",
        risk_flags_json=[],
    )
    db.add_all([employee, claim])
    # Emit the employee/claim INSERTs before adding the observation that
    # references claim "c1".
    db.flush()

    observation = RiskObservation(
        observation_key=key,
        subject_type="expense_claim",
        subject_key="claim:c1",
        claim_id="c1",
        claim_no="BX-001",
        risk_type="duplicate_invoice",
        risk_signal="duplicate_invoice",
        title="重复发票",
        description="同一发票出现在多张报销单",
        risk_score=86,
        risk_level="high",
        confidence_score=0.8,
        source="financial_risk_graph",
        algorithm_version="v1",
        ontology_json={"domain": "expense", "scenario": "reimbursement"},
    )
    db.add(observation)
    db.commit()
    db.refresh(observation)
    return observation
def test_ingest_confirmed_persists_sample_and_calls_store() -> None:
    """Check that a ``confirmed`` verdict yields a persisted sample and one upsert.

    The service's ``_store`` is patched to return a stub whose ``upsert``
    returns ``"vec-1"``. That value must be recorded as the sample's
    ``vector_id``.
    """
    store_stub = MagicMock()
    store_stub.upsert.return_value = "vec-1"
    feedback = MagicMock(feedback_type="confirm", comment="确认重复发票", actor="audit")

    with _build_session() as db:
        observation = _observation(db)
        observation.feedback_status = "confirmed"
        ingestion = FewShotIngestionService(db)

        with patch.object(ingestion, "_store", return_value=store_stub):
            sample = ingestion.ingest_observation_feedback(observation, feedback)

        assert sample is not None
        assert sample.label == "confirmed"
        assert sample.sample_key == f"obs:{observation.id}"
        # The observation title must be reflected in the case text.
        assert "重复发票" in sample.case_text
        # The reviewer's comment must be reflected in the conclusion.
        assert "确认重复发票" in sample.conclusion_text
        assert sample.vector_id == "vec-1"
        store_stub.upsert.assert_called_once()
def test_ingest_false_positive_also_persisted() -> None:
    """Check that a ``false_positive`` verdict is persisted as a sample too.

    The stubbed vector store returns ``None`` from ``upsert``. The sample must
    still be created, and its conclusion must contain the marker ``误报``.
    """
    with _build_session() as db:
        observation = _observation(db, key="risk:c2:fp")
        observation.feedback_status = "false_positive"
        db.commit()

        null_store = MagicMock()
        null_store.upsert.return_value = None
        ingestion = FewShotIngestionService(db)
        verdict = MagicMock(feedback_type="false_positive", comment="", actor="audit")

        with patch.object(ingestion, "_store", return_value=null_store):
            sample = ingestion.ingest_observation_feedback(observation, verdict)

        assert sample is not None
        assert sample.label == "false_positive"
        assert "误报" in sample.conclusion_text
def test_ingest_ignored_label_returns_none() -> None:
    """Check that observations labelled ``ignored`` are skipped (ingestion returns ``None``)."""
    with _build_session() as db:
        observation = _observation(db)
        observation.feedback_status = "ignored"
        outcome = FewShotIngestionService(db).ingest_observation_feedback(observation, MagicMock())
        assert outcome is None
def test_ingest_is_idempotent_on_duplicate_sample_key() -> None:
    """Check that re-ingesting one observation updates its sample instead of adding a row.

    Steps:
    1. The observation is ingested as ``confirmed``.
    2. It is re-labelled ``false_positive`` and ingested again.

    Both calls share the sample key ``obs:{id}``. Afterwards exactly one
    ``FewShotSample`` row must exist for that key, and it must carry the
    latest label.
    """
    from sqlalchemy import func, select

    with _build_session() as db:
        obs = _observation(db)
        service = FewShotIngestionService(db)
        store = MagicMock()
        # One vector id per ingestion call.
        store.upsert.side_effect = ["vec-1", "vec-2"]
        with patch.object(service, "_store", return_value=store):
            obs.feedback_status = "confirmed"
            first = service.ingest_observation_feedback(
                obs, MagicMock(feedback_type="confirm", comment="第一次", actor="a")
            )
            # Simulate the observation later being re-judged as a false positive.
            obs.feedback_status = "false_positive"
            second = service.ingest_observation_feedback(
                obs, MagicMock(feedback_type="false_positive", comment="改判", actor="a")
            )

        assert first is not None and second is not None
        assert first.id == second.id  # the same row was updated in place
        # Count matching rows explicitly. Fetching a single row would still
        # pass if a duplicate had been inserted, so it cannot prove idempotency.
        row_count = db.scalar(
            select(func.count())
            .select_from(FewShotSample)
            .where(FewShotSample.sample_key == f"obs:{obs.id}")
        )
        assert row_count == 1
        assert second.label == "false_positive"
def test_create_feedback_hook_triggers_ingestion() -> None:
    """Check that ``create_feedback`` with a ``confirm`` verdict calls ingestion exactly once.

    The ingestion method is patched on the class with a recording side effect.
    The patched attribute is a plain mock, not a bound method, so the side
    effect receives only the (observation, feedback) arguments, without ``self``.
    """
    recorded: list = []

    def _record(observed, feedback):
        recorded.append((observed.id, feedback.feedback_type))
        return None

    ingest_target = (
        "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    )
    with _build_session() as db:
        risk_service = RiskObservationService(db)
        obs = _observation(db)
        with patch(ingest_target, side_effect=_record):
            risk_service.create_feedback(
                obs.observation_key,
                RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
            )

    assert len(recorded) == 1
    assert recorded[0][1] == "confirm"
def test_create_feedback_hook_skipped_for_comment_feedback() -> None:
    """Check that plain ``comment`` feedback does not trigger few-shot ingestion."""
    ingest_target = (
        "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    )
    with _build_session() as db:
        risk_service = RiskObservationService(db)
        obs = _observation(db)
        note = RiskObservationFeedbackCreate(feedback_type="comment", action="note", actor="audit")
        with patch(ingest_target) as ingest_mock:
            risk_service.create_feedback(obs.observation_key, note)
        ingest_mock.assert_not_called()
def test_create_feedback_hook_swallows_ingestion_failure() -> None:
    """Check that an exception raised during ingestion does not escape ``create_feedback``.

    Ingestion is patched to raise ``RuntimeError``. The feedback must still be
    returned with its original type.
    """
    exploding_ingest = patch(
        "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback",
        side_effect=RuntimeError("boom"),
    )
    with _build_session() as db:
        risk_service = RiskObservationService(db)
        obs = _observation(db)
        with exploding_ingest:
            # The hook is expected to swallow the failure, so no exception here.
            created = risk_service.create_feedback(
                obs.observation_key,
                RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
            )
        assert created.feedback_type == "confirm"
def test_create_feedback_hook_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ``FEW_SHOT_INJECTION_ENABLED=false`` stops the ingestion hook from running.

    The environment variable is set before the session and services are
    created. It is presumably read when ``create_feedback`` runs (TODO:
    confirm where the flag is read).
    """
    monkeypatch.setenv("FEW_SHOT_INJECTION_ENABLED", "false")
    ingest_target = (
        "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    )
    with _build_session() as db:
        risk_service = RiskObservationService(db)
        obs = _observation(db)
        confirm = RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit")
        with patch(ingest_target) as ingest_mock:
            risk_service.create_feedback(obs.observation_key, confirm)
        ingest_mock.assert_not_called()