Files
X-Financial/server/tests/test_risk_rule_golden_evaluator.py
caoxiaozhu c7ba7bb453 feat(flywheel): golden case 管理 API 与评测单测
- 新增 GoldenCaseCreate/Read、GoldenEvalRequest/Read schema
- 新增 3 个端点:创建 golden case、按规则列表、手动触发 golden 评测
  (不入门禁,供运营试跑)
- 单测 15 passed:单条 hit/severity 比对、集合 accuracy/precision/recall
  聚合、空集降级、100% 通过/失败拦截、feature flag、异常降级
- 回归 test_agent_asset_service 27 passed(1 个预存失败与本改动无关)
2026-07-03 14:38:43 +08:00

263 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from collections.abc import Generator
from datetime import datetime
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.db.base import Base
from app.models.agent_asset import AgentAsset, AgentAssetTestRun
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.models.golden_case import GoldenCase
from app.services.risk_rule_golden_evaluator import (
GoldenEvalReport,
RiskRuleGoldenEvaluator,
_aggregate,
_run_single_case,
)
def _build_session() -> Session:
    """Return a new ORM session bound to a private in-memory SQLite database.

    ``StaticPool`` with ``check_same_thread=False`` reuses one underlying
    connection, so the in-memory database (and the tables created from
    ``Base.metadata``) stays visible for the lifetime of the engine.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    return sessionmaker(bind=memory_engine, autocommit=False, autoflush=False)()
def _keyword_manifest() -> dict:
"""一个简单的 keyword_match_v1 manifestreason 含"虚假"则命中。"""
return {
"rule_code": "risk.test.keyword",
"template_key": "keyword_match_v1",
"inputs": {
"fields": [
{"key": "claim.reason", "label": "事由", "type": "text", "source": "claim"},
]
},
"params": {
"keywords": ["虚假"],
"field_keys": ["claim.reason"],
"search_fields": ["claim.reason"],
},
"outcomes": {"fail": {"severity": "high", "risk_score": 80}},
}
def _golden_case(
    case_key: str,
    *,
    reason: str,
    expected_hit: bool,
    rule_code: str = "risk.test.keyword",
) -> GoldenCase:
    """Build an active, not-yet-persisted GoldenCase for ``rule_code``.

    The case's input values contain only ``claim.reason``; its display name
    is derived from ``case_key`` as ``case-<case_key>``.
    """
    columns = {
        "case_key": case_key,
        "rule_code": rule_code,
        "name": f"case-{case_key}",
        "values_json": {"claim.reason": reason},
        "expected_hit": expected_hit,
        "status": "active",
    }
    return GoldenCase(**columns)
def test_run_single_case_hit_matches() -> None:
    """A reason containing the keyword hits, passes, and reports severity "high"."""
    claim_values = {"claim.reason": "虚假发票报销"}
    outcome = _run_single_case(
        _keyword_manifest(),
        values=claim_values,
        expected_hit=True,
        expected_severity="high",
    )
    assert outcome.actual_hit is True
    assert outcome.passed is True
    assert outcome.actual_severity == "high"
def test_run_single_case_no_hit_matches() -> None:
    """A keyword-free reason does not hit, which matches an expected miss."""
    claim_values = {"claim.reason": "正常差旅报销"}
    outcome = _run_single_case(
        _keyword_manifest(),
        values=claim_values,
        expected_hit=False,
        expected_severity="",
    )
    assert outcome.actual_hit is False
    assert outcome.passed is True
def test_run_single_case_mismatch_fails() -> None:
    """A case expecting a miss fails when the rule actually hits."""
    claim_values = {"claim.reason": "虚假发票"}
    outcome = _run_single_case(
        _keyword_manifest(),
        values=claim_values,
        expected_hit=False,  # expected no hit, but the keyword is present so it hits
        expected_severity="",
    )
    assert outcome.actual_hit is True
    assert outcome.passed is False
def test_run_single_case_severity_mismatch_fails() -> None:
    """A correct hit with the wrong expected severity must still fail.

    The hit and severity assertions pin down that the failure comes from the
    severity comparison alone. Without them, the test would also pass if the
    hit itself mismatched.
    """
    result = _run_single_case(
        _keyword_manifest(),
        values={"claim.reason": "虚假发票"},
        expected_hit=True,
        expected_severity="critical",  # the rule actually reports "high"
    )
    assert result.actual_hit is True
    assert result.actual_severity == "high"
    assert result.passed is False
def test_aggregate_empty_returns_passed() -> None:
    """Aggregating no results degrades to a passing report with accuracy 0.0."""
    empty_report = _aggregate([])
    assert empty_report.total == 0
    assert empty_report.all_passed is True
    assert empty_report.accuracy == 0.0
def test_aggregate_all_passed() -> None:
    """Two passing results aggregate to a fully passing report with accuracy 1.0."""
    from app.services.risk_rule_golden_evaluator import GoldenCaseResult

    result_a = GoldenCaseResult("1", "a", True, True, "high", "high", True)
    result_b = GoldenCaseResult("2", "b", False, False, "", "none", True)
    summary = _aggregate([result_a, result_b])
    assert summary.total == 2
    assert summary.passed_count == 2
    assert summary.accuracy == 1.0
    assert summary.all_passed is True
def test_aggregate_with_failure() -> None:
    """One passing and one failing result give accuracy 0.5 and a failed report."""
    from app.services.risk_rule_golden_evaluator import GoldenCaseResult

    passing = GoldenCaseResult("1", "a", True, True, "high", "high", True)
    # Labelled a false positive (FP) by the original author.
    # NOTE(review): whether this row is an FP or a false negative depends on
    # GoldenCaseResult's positional order of the two hit flags, which is not
    # visible here. If fields 3/4 are (expected_hit, actual_hit), this row is an
    # FN, and precision would be 1.0 with recall 0.5. Confirm against the
    # dataclass and _aggregate before trusting the precision assertion below.
    failing = GoldenCaseResult("2", "b", True, False, "high", "none", False)
    summary = _aggregate([passing, failing])
    assert summary.passed_count == 1
    assert summary.failed_count == 1
    assert summary.accuracy == 0.5
    assert summary.all_passed is False
    assert summary.precision == 0.5  # 1 / (1 + 1)
def test_evaluate_for_rule_empty_returns_passed() -> None:
    """With no golden cases stored for the rule, evaluation passes vacuously."""
    evaluator = RiskRuleGoldenEvaluator()
    with _build_session() as db:
        report = evaluator.evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword")
        assert report.total == 0
        assert report.all_passed is True
def test_evaluate_for_rule_all_pass() -> None:
    """Both stored golden cases agree with the rule, so accuracy is 1.0."""
    with _build_session() as db:
        db.add_all(
            [
                _golden_case("g1", reason="虚假发票", expected_hit=True),
                _golden_case("g2", reason="正常报销", expected_hit=False),
            ]
        )
        db.commit()
        report = RiskRuleGoldenEvaluator().evaluate_for_rule(
            db, _keyword_manifest(), "risk.test.keyword"
        )
        assert report.total == 2
        assert report.all_passed is True
        assert report.accuracy == 1.0
def test_evaluate_for_rule_with_failure() -> None:
    """Both stored golden cases disagree with the rule, so both are counted as failures."""
    with _build_session() as db:
        db.add_all(
            [
                # expected no hit, but the keyword makes the rule hit
                _golden_case("g1", reason="虚假发票", expected_hit=False),
                # expected a hit, but the reason has no keyword
                _golden_case("g2", reason="正常报销", expected_hit=True),
            ]
        )
        db.commit()
        report = RiskRuleGoldenEvaluator().evaluate_for_rule(
            db, _keyword_manifest(), "risk.test.keyword"
        )
        assert report.total == 2
        assert report.all_passed is False
        assert report.failed_count == 2
def _asset(asset_id: str, code: str) -> AgentAsset:
    """Build a transient rule-type AgentAsset in "review" status at working version "v1".

    ``code`` doubles as the asset's display name.
    """
    columns = {
        "id": asset_id,
        "code": code,
        "name": code,
        "asset_type": "rule",
        "domain": "expense",
        "owner": "tester",
        "status": "review",
        "working_version": "v1",
    }
    return AgentAsset(**columns)
def test_require_pass_passes_when_all_green() -> None:
    """All golden cases pass, so require_pass returns normally and records a passing 'golden' run."""
    with _build_session() as db:
        asset = _asset("a1", "R1")
        db.add_all([asset, _golden_case("g1", reason="虚假", expected_hit=True)])
        db.commit()
        report = RiskRuleGoldenEvaluator().require_pass(
            db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
        )
        assert report.all_passed is True
        # Exactly one test_type='golden' run must exist for the asset (.one() enforces this).
        golden_run = (
            db.query(AgentAssetTestRun)
            .filter_by(asset_id="a1", test_type="golden")
            .one()
        )
        assert golden_run.passed is True
def test_require_pass_raises_on_failure() -> None:
    """A failing golden case makes require_pass raise PermissionError, and the failed run is still recorded."""
    with _build_session() as db:
        asset = _asset("a2", "R2")
        # This golden case will fail: the keyword is present but no hit is expected.
        db.add_all([asset, _golden_case("g1", reason="虚假", expected_hit=False)])
        db.commit()
        with pytest.raises(PermissionError):
            RiskRuleGoldenEvaluator().require_pass(
                db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
            )
        failed_run = (
            db.query(AgentAssetTestRun)
            .filter_by(asset_id="a2", test_type="golden")
            .one()
        )
        assert failed_run.passed is False
def test_require_pass_empty_golden_set_passes() -> None:
    """With no golden cases for the rule, the gate lets the asset through."""
    with _build_session() as db:
        asset = _asset("a3", "R3")
        db.add(asset)
        db.commit()
        outcome = RiskRuleGoldenEvaluator().require_pass(
            db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
        )
        assert outcome.total == 0
        assert outcome.all_passed is True
def test_require_pass_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    """With GOLDEN_SET_GATE_ENABLED=false, require_pass skips evaluation instead of raising."""
    monkeypatch.setenv("GOLDEN_SET_GATE_ENABLED", "false")
    with _build_session() as db:
        asset = _asset("a4", "R4")
        # This case would fail if the gate were enabled.
        db.add_all([asset, _golden_case("g1", reason="虚假", expected_hit=False)])
        db.commit()
        # Gate disabled: the call must return without raising.
        outcome = RiskRuleGoldenEvaluator().require_pass(
            db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
        )
        assert outcome.total == 0
def test_require_pass_swallows_evaluator_exception() -> None:
    """An exception inside evaluate_for_rule is swallowed; require_pass degrades to an empty passing report."""
    evaluator = RiskRuleGoldenEvaluator()
    with _build_session() as db:
        asset = _asset("a5", "R5")
        db.add(asset)
        db.commit()
        boom = RuntimeError("boom")
        with patch.object(evaluator, "evaluate_for_rule", side_effect=boom):
            outcome = evaluator.require_pass(
                db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester"
            )
        assert outcome.total == 0
        assert outcome.all_passed is True  # degraded pass-through