"""Tests for few-shot sample ingestion driven by risk-observation feedback.

Covers ``FewShotIngestionService.ingest_observation_feedback`` directly, and
the hook inside ``RiskObservationService.create_feedback`` that triggers it.
"""

from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager
from datetime import datetime
from decimal import Decimal
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.db.base import Base
from app.models.employee import Employee
from app.models.few_shot_sample import FewShotSample
from app.models.financial_record import ExpenseClaim
from app.models.risk_observation import RiskObservation
from app.schemas.risk_observation import RiskObservationFeedbackCreate
from app.services.few_shot_ingestion import FewShotIngestionService
from app.services.risk_observations import RiskObservationService


@contextmanager
def _build_session() -> Generator[Session, None, None]:
    """Yield a session bound to a fresh in-memory SQLite database.

    The schema is created from ``Base.metadata``. On exit the session is
    closed and the engine disposed, so each test's database is released
    deterministically instead of lingering until garbage collection.
    """
    # StaticPool + check_same_thread=False keep one shared connection, so
    # every use of this engine sees the same in-memory database.
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        Base.metadata.create_all(bind=engine)
        factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
        with factory() as session:
            yield session
    finally:
        engine.dispose()


def _observation(db: Session, key: str = "risk:c1:dup") -> RiskObservation:
    """Insert an employee, expense claim ``c1`` and one risk observation.

    The employee and claim use fixed primary keys, so call this at most once
    per session. Commits and returns the refreshed observation.

    Args:
        db: Session to write into.
        key: ``observation_key`` for the created observation.
    """
    db.add(Employee(id="emp-1", employee_no="E1", name="员工", email="e@e.com", grade="P6"))
    db.add(
        ExpenseClaim(
            id="c1",
            claim_no="BX-001",
            employee_id="emp-1",
            employee_name="员工",
            department_name="风控部",
            expense_type="travel",
            reason="客户拜访",
            location="上海",
            amount=Decimal("1000"),
            currency="CNY",
            occurred_at=datetime(2026, 1, 1),
            submitted_at=datetime(2026, 1, 1),
            status="submitted",
            approval_stage="manager_review",
            risk_flags_json=[],
        )
    )
    # Flush so the employee and claim rows exist before the observation
    # that references claim "c1" is added.
    db.flush()
    obs = RiskObservation(
        observation_key=key,
        subject_type="expense_claim",
        subject_key="claim:c1",
        claim_id="c1",
        claim_no="BX-001",
        risk_type="duplicate_invoice",
        risk_signal="duplicate_invoice",
        title="重复发票",
        description="同一发票出现在多张报销单",
        risk_score=86,
        risk_level="high",
        confidence_score=0.8,
        source="financial_risk_graph",
        algorithm_version="v1",
        ontology_json={"domain": "expense", "scenario": "reimbursement"},
    )
    db.add(obs)
    db.commit()
    db.refresh(obs)
    return obs


def test_ingest_confirmed_persists_sample_and_calls_store() -> None:
    """A confirmed observation becomes a sample and is pushed to the store."""
    with _build_session() as db:
        obs = _observation(db)
        obs.feedback_status = "confirmed"
        service = FewShotIngestionService(db)
        fake_store = MagicMock()
        fake_store.upsert.return_value = "vec-1"
        with patch.object(service, "_store", return_value=fake_store):
            sample = service.ingest_observation_feedback(
                obs,
                MagicMock(feedback_type="confirm", comment="确认重复发票", actor="audit"),
            )
        assert sample is not None
        assert sample.label == "confirmed"
        assert sample.sample_key == f"obs:{obs.id}"
        assert "重复发票" in sample.case_text
        assert "确认重复发票" in sample.conclusion_text
        assert sample.vector_id == "vec-1"
        fake_store.upsert.assert_called_once()


def test_ingest_false_positive_also_persisted() -> None:
    """False-positive feedback is persisted too, labelled as such."""
    with _build_session() as db:
        obs = _observation(db, key="risk:c2:fp")
        obs.feedback_status = "false_positive"
        db.commit()
        service = FewShotIngestionService(db)
        # The store returns no vector id here; the sample must still persist.
        with patch.object(
            service, "_store", return_value=MagicMock(upsert=MagicMock(return_value=None))
        ):
            sample = service.ingest_observation_feedback(
                obs,
                MagicMock(feedback_type="false_positive", comment="", actor="audit"),
            )
        assert sample is not None
        assert sample.label == "false_positive"
        assert "误报" in sample.conclusion_text


def test_ingest_ignored_label_returns_none() -> None:
    """Observations with status ``ignored`` are not ingested."""
    with _build_session() as db:
        obs = _observation(db)
        obs.feedback_status = "ignored"
        service = FewShotIngestionService(db)
        assert service.ingest_observation_feedback(obs, MagicMock()) is None


def test_ingest_is_idempotent_on_duplicate_sample_key() -> None:
    """Re-ingesting the same observation updates its single sample row."""
    with _build_session() as db:
        obs = _observation(db)
        service = FewShotIngestionService(db)
        store = MagicMock()
        store.upsert.side_effect = ["vec-1", "vec-2"]
        with patch.object(service, "_store", return_value=store):
            obs.feedback_status = "confirmed"
            first = service.ingest_observation_feedback(
                obs, MagicMock(feedback_type="confirm", comment="第一次", actor="a")
            )
            # Simulate the observation later being re-judged as a false positive.
            obs.feedback_status = "false_positive"
            second = service.ingest_observation_feedback(
                obs, MagicMock(feedback_type="false_positive", comment="改判", actor="a")
            )
        assert first is not None and second is not None
        assert first.id == second.id  # same row updated in place
        # Count the rows explicitly: scalar-selecting the entity would return
        # only the first match and hide duplicate rows for this sample_key.
        count = db.scalar(
            select(func.count())
            .select_from(FewShotSample)
            .where(FewShotSample.sample_key == f"obs:{obs.id}")
        )
        assert count == 1
        assert second.label == "false_positive"


def test_create_feedback_hook_triggers_ingestion() -> None:
    """Confirm feedback created via the service triggers one ingestion call."""
    with _build_session() as db:
        service = RiskObservationService(db)
        obs = _observation(db)
        ingest_calls: list[tuple[object, object]] = []

        # The patched class attribute is a plain MagicMock (not a descriptor),
        # so it is invoked without ``self``: the spy gets (observation, feedback).
        def _spy_ingest(o, f):
            ingest_calls.append((o.id, f.feedback_type))
            return None

        with patch(
            "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback",
            side_effect=_spy_ingest,
        ):
            service.create_feedback(
                obs.observation_key,
                RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
            )
        assert len(ingest_calls) == 1
        assert ingest_calls[0][1] == "confirm"


def test_create_feedback_hook_skipped_for_comment_feedback() -> None:
    """Plain comment feedback does not trigger ingestion."""
    with _build_session() as db:
        service = RiskObservationService(db)
        obs = _observation(db)
        with patch(
            "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
        ) as mock_ingest:
            service.create_feedback(
                obs.observation_key,
                RiskObservationFeedbackCreate(feedback_type="comment", action="note", actor="audit"),
            )
        mock_ingest.assert_not_called()


def test_create_feedback_hook_swallows_ingestion_failure() -> None:
    """An ingestion error must not prevent the feedback from being created."""
    with _build_session() as db:
        service = RiskObservationService(db)
        obs = _observation(db)
        with patch(
            "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback",
            side_effect=RuntimeError("boom"),
        ):
            # Must not raise.
            feedback = service.create_feedback(
                obs.observation_key,
                RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
            )
        assert feedback.feedback_type == "confirm"


def test_create_feedback_hook_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    """With FEW_SHOT_INJECTION_ENABLED=false the ingestion hook is skipped."""
    # NOTE(review): assumes the service reads this flag from the environment
    # at call time rather than caching it at import — confirm.
    monkeypatch.setenv("FEW_SHOT_INJECTION_ENABLED", "false")
    with _build_session() as db:
        service = RiskObservationService(db)
        obs = _observation(db)
        with patch(
            "app.services.few_shot_ingestion.FewShotIngestionService.ingest_observation_feedback"
        ) as mock_ingest:
            service.create_feedback(
                obs.observation_key,
                RiskObservationFeedbackCreate(feedback_type="confirm", actor="audit"),
            )
        mock_ingest.assert_not_called()