215 lines
7.9 KiB
Python
215 lines
7.9 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from collections.abc import Generator
|
||
|
|
from datetime import datetime
|
||
|
|
from decimal import Decimal
|
||
|
|
from types import SimpleNamespace
|
||
|
|
from unittest.mock import MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from sqlalchemy import create_engine
|
||
|
|
from sqlalchemy.orm import Session, sessionmaker
|
||
|
|
from sqlalchemy.pool import StaticPool
|
||
|
|
|
||
|
|
from app.db.base import Base
|
||
|
|
from app.models.employee import Employee
|
||
|
|
from app.models.few_shot_sample import FewShotSample
|
||
|
|
from app.models.financial_record import ExpenseClaim
|
||
|
|
from app.models.risk_observation import RiskObservation
|
||
|
|
from app.schemas.risk_observation import RiskObservationFeedbackCreate
|
||
|
|
from app.services.few_shot_ingestion import FewShotIngestionService
|
||
|
|
from app.services.risk_observations import RiskObservationService
|
||
|
|
|
||
|
|
|
||
|
|
def _build_session() -> Session:
    """Return a new Session on a fresh in-memory SQLite database with all tables created.

    ``StaticPool`` hands out the same single DBAPI connection on every checkout, which
    keeps the in-memory database (it only lives as long as its connection) visible to
    every use of the engine; ``check_same_thread=False`` lets that connection be used
    from threads other than the one that created it.
    """
    in_memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=in_memory_engine)
    session_factory = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=in_memory_engine,
    )
    return session_factory()
|
||
|
|
|
||
|
|
|
||
|
|
def _observation(db: Session, key: str = "risk:c1:dup") -> RiskObservation:
    """Seed one employee, one expense claim and one RiskObservation; return the observation.

    Args:
        db: Session to insert into; the observation is committed and refreshed.
        key: ``observation_key`` for the new observation.

    The employee ("emp-1") and claim ("c1") use fixed identifiers, so a second call on
    the same session would try to insert those rows again.
    """
    employee = Employee(
        id="emp-1",
        employee_no="E1",
        name="员工",
        email="e@e.com",
        grade="P6",
    )
    claim = ExpenseClaim(
        id="c1",
        claim_no="BX-001",
        employee_id="emp-1",
        employee_name="员工",
        department_name="风控部",
        expense_type="travel",
        reason="客户拜访",
        location="上海",
        amount=Decimal("1000"),
        currency="CNY",
        occurred_at=datetime(2026, 1, 1),
        submitted_at=datetime(2026, 1, 1),
        status="submitted",
        approval_stage="manager_review",
        risk_flags_json=[],
    )
    db.add_all([employee, claim])
    # Write the parent rows before the observation that points at claim "c1" is added
    # (presumably to satisfy FK ordering — confirm against the model definitions).
    db.flush()

    observation = RiskObservation(
        observation_key=key,
        subject_type="expense_claim",
        subject_key="claim:c1",
        claim_id="c1",
        claim_no="BX-001",
        risk_type="duplicate_invoice",
        risk_signal="duplicate_invoice",
        title="重复发票",
        description="同一发票出现在多张报销单",
        risk_score=86,
        risk_level="high",
        confidence_score=0.8,
        source="financial_risk_graph",
        algorithm_version="v1",
        ontology_json={"domain": "expense", "scenario": "reimbursement"},
    )
    db.add(observation)
    db.commit()
    db.refresh(observation)
    return observation
|
||
|
|
|
||
|
|
|
||
|
|
def test_ingest_confirmed_persists_sample_and_calls_store() -> None:
    """A confirmed observation yields a FewShotSample and one upsert on the patched store."""
    with _build_session() as session:
        observation = _observation(session)
        observation.feedback_status = "confirmed"
        ingestion = FewShotIngestionService(session)

        # Stand-in for whatever ``_store()`` returns; its upsert result becomes vector_id.
        vector_store = MagicMock()
        vector_store.upsert.return_value = "vec-1"
        feedback = MagicMock(feedback_type="confirm", comment="确认重复发票", actor="audit")

        with patch.object(ingestion, "_store", return_value=vector_store):
            sample = ingestion.ingest_observation_feedback(observation, feedback)

        assert sample is not None
        assert sample.label == "confirmed"
        assert sample.sample_key == f"obs:{observation.id}"
        assert "重复发票" in sample.case_text
        assert "确认重复发票" in sample.conclusion_text
        assert sample.vector_id == "vec-1"
        vector_store.upsert.assert_called_once()
|
||
|
|
|
||
|
|
|
||
|
|
def test_ingest_false_positive_also_persisted() -> None:
    """A committed false-positive verdict is ingested too, even when upsert returns no id."""
    with _build_session() as session:
        observation = _observation(session, key="risk:c2:fp")
        observation.feedback_status = "false_positive"
        session.commit()

        ingestion = FewShotIngestionService(session)
        null_store = MagicMock(upsert=MagicMock(return_value=None))
        feedback = MagicMock(feedback_type="false_positive", comment="", actor="audit")

        with patch.object(ingestion, "_store", return_value=null_store):
            sample = ingestion.ingest_observation_feedback(observation, feedback)

        assert sample is not None
        assert sample.label == "false_positive"
        assert "误报" in sample.conclusion_text
|
||
|
|
|
||
|
|
|
||
|
|
def test_ingest_ignored_label_returns_none() -> None:
    """Ingestion returns None for an observation whose feedback status is "ignored"."""
    with _build_session() as session:
        observation = _observation(session)
        observation.feedback_status = "ignored"
        ingestion = FewShotIngestionService(session)
        outcome = ingestion.ingest_observation_feedback(observation, MagicMock())
        assert outcome is None
|
||
|
|
|
||
|
|
|
||
|
|
def test_ingest_is_idempotent_on_duplicate_sample_key() -> None:
    """Re-ingesting feedback for one observation updates a single FewShotSample row.

    The observation is first ingested as "confirmed", then re-judged as a false
    positive and ingested again. Both calls must return the same row (same primary
    key), the label must reflect the latest verdict, and exactly one row may exist
    for the observation's ``sample_key``.
    """
    from sqlalchemy import func, select

    with _build_session() as db:
        obs = _observation(db)
        service = FewShotIngestionService(db)
        store = MagicMock()
        # Supplies a distinct vector id for each of up to two upsert calls.
        store.upsert.side_effect = ["vec-1", "vec-2"]
        with patch.object(service, "_store", return_value=store):
            obs.feedback_status = "confirmed"
            first = service.ingest_observation_feedback(
                obs, MagicMock(feedback_type="confirm", comment="第一次", actor="a")
            )
            # Simulate the observation later being re-judged as a false positive.
            obs.feedback_status = "false_positive"
            second = service.ingest_observation_feedback(
                obs, MagicMock(feedback_type="false_positive", comment="改判", actor="a")
            )
        assert first is not None and second is not None
        assert first.id == second.id  # the same row was updated, not a new one inserted
        # Count rows explicitly: ``db.scalar(select(FewShotSample)...)`` only yields the
        # first match and would still pass if duplicate rows had been inserted.
        row_count = db.scalar(
            select(func.count())
            .select_from(FewShotSample)
            .where(FewShotSample.sample_key == f"obs:{obs.id}")
        )
        assert row_count == 1
        assert second.label == "false_positive"
|
||
|
|
|
||
|
|
|
||
|
|
def test_create_feedback_hook_triggers_ingestion() -> None:
    """Creating "confirm" feedback forwards that observation to few-shot ingestion once.

    ``ingest_observation_feedback`` is replaced on the class by a spy that records the
    observation id and feedback type it receives.
    """
    with _build_session() as db:
        service = RiskObservationService(db)
        obs = _observation(db)
        ingest_calls: list[tuple[object, str]] = []

        def _spy_ingest(o, f):
            # patch() installs a MagicMock on the class; a MagicMock is not a descriptor,
            # so no ``self`` is bound and the spy receives (observation, feedback).
            ingest_calls.append((o.id, f.feedback_type))
            return None

        with patch(
            "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback",
            side_effect=_spy_ingest,
        ):
            service.create_feedback(
                obs.observation_key,
                RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
            )
        # Pin both the forwarded observation and the feedback type, not just the count.
        assert ingest_calls == [(obs.id, "confirm")]
|
||
|
|
|
||
|
|
|
||
|
|
def test_create_feedback_hook_skipped_for_comment_feedback() -> None:
    """Comment-type feedback must not invoke few-shot ingestion."""
    ingest_target = (
        "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    )
    with _build_session() as session:
        risk_service = RiskObservationService(session)
        observation = _observation(session)
        with patch(ingest_target) as ingest_mock:
            observation_key = observation.observation_key
            comment_feedback = RiskObservationFeedbackCreate(
                feedback_type="comment", action="note", actor="audit"
            )
            risk_service.create_feedback(observation_key, comment_feedback)
        ingest_mock.assert_not_called()
|
||
|
|
|
||
|
|
|
||
|
|
def test_create_feedback_hook_swallows_ingestion_failure() -> None:
    """A RuntimeError raised by ingestion must not escape create_feedback."""
    ingest_target = (
        "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    )
    with _build_session() as session:
        risk_service = RiskObservationService(session)
        observation = _observation(session)
        with patch(ingest_target, side_effect=RuntimeError("boom")):
            # The call below is expected to complete without raising.
            observation_key = observation.observation_key
            confirm_feedback = RiskObservationFeedbackCreate(
                feedback_type="confirm", actor="audit"
            )
            created = risk_service.create_feedback(observation_key, confirm_feedback)
        assert created.feedback_type == "confirm"
|
||
|
|
|
||
|
|
|
||
|
|
def test_create_feedback_hook_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    """With FEW_SHOT_INJECTION_ENABLED set to "false", ingestion is never invoked.

    NOTE(review): this relies on the service reading the flag from the environment at
    call time (not from settings cached at import) — confirm in the service code.
    """
    monkeypatch.setenv("FEW_SHOT_INJECTION_ENABLED", "false")
    ingest_target = (
        "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
    )
    with _build_session() as session:
        risk_service = RiskObservationService(session)
        observation = _observation(session)
        with patch(ingest_target) as ingest_mock:
            observation_key = observation.observation_key
            confirm_feedback = RiskObservationFeedbackCreate(
                feedback_type="confirm", actor="audit"
            )
            risk_service.create_feedback(observation_key, confirm_feedback)
        ingest_mock.assert_not_called()
|