From c7ba7bb45356372832767fa04c506f6c8c9701ce Mon Sep 17 00:00:00 2001 From: caoxiaozhu Date: Fri, 3 Jul 2026 14:38:43 +0800 Subject: [PATCH] =?UTF-8?q?feat(flywheel):=20golden=20case=20=E7=AE=A1?= =?UTF-8?q?=E7=90=86=20API=20=E4=B8=8E=E8=AF=84=E6=B5=8B=E5=8D=95=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 GoldenCaseCreate/Read、GoldenEvalRequest/Read schema - 新增 3 个端点:创建 golden case、按规则列表、手动触发 golden 评测 (不入门禁,供运营试跑) - 单测 15 passed:单条 hit/severity 比对、集合 accuracy/precision/recall 聚合、空集降级、100% 通过/失败拦截、feature flag、异常降级 - 回归 test_agent_asset_service 27 passed(1 个预存失败与本改动无关) --- .../src/app/api/v1/endpoints/agent_assets.py | 111 ++++++++ server/src/app/schemas/agent_asset.py | 40 +++ .../tests/test_risk_rule_golden_evaluator.py | 262 ++++++++++++++++++ 3 files changed, 413 insertions(+) create mode 100644 server/tests/test_risk_rule_golden_evaluator.py diff --git a/server/src/app/api/v1/endpoints/agent_assets.py b/server/src/app/api/v1/endpoints/agent_assets.py index 723a3c6..78d284e 100644 --- a/server/src/app/api/v1/endpoints/agent_assets.py +++ b/server/src/app/api/v1/endpoints/agent_assets.py @@ -43,6 +43,10 @@ from app.schemas.agent_asset import ( AgentAssetVersionCreate, AgentAssetVersionRead, AgentAssetVersionTimelineItemRead, + GoldenCaseCreate, + GoldenCaseRead, + GoldenEvalRead, + GoldenEvalRequest, ) from app.schemas.common import ErrorResponse, PaginatedResponse from app.services.agent_assets import AgentAssetService @@ -923,3 +927,110 @@ def get_agent_asset_version_timeline( return AgentAssetService(db).list_version_timeline(asset_id) except Exception as exc: _handle_asset_error(exc) + + +@router.post( + "/risk-rules/golden-cases", + response_model=GoldenCaseRead, + status_code=status.HTTP_201_CREATED, + summary="创建 golden set 黄金用例", + description="为指定规则(或通用场景)创建一条回归用例,发布前作为门禁集执行。", +) +def create_golden_case( + body: GoldenCaseCreate, + _: RuleEditorUser, + db: DbSession, +) -> GoldenCaseRead: + from app.models.golden_case import GoldenCase + from sqlalchemy import select + + existing = db.scalar(select(GoldenCase).where(GoldenCase.case_key == body.case_key)) + if existing is not None: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="case_key 已存在") + case = GoldenCase( + case_key=body.case_key, + rule_code=body.rule_code, + scene=body.scene, + name=body.name, + values_json=body.values, + expected_hit=body.expected_hit, + expected_severity=body.expected_severity, + note=body.note, + status="active", + source="manual", + ) + db.add(case) + db.commit() + db.refresh(case) + return _golden_case_read(case) + + +@router.get( + "/risk-rules/{rule_code}/golden-cases", + response_model=list[GoldenCaseRead], + summary="列出规则的 golden 用例", +) +def list_golden_cases( + rule_code: str, + _: CurrentUser, + db: DbSession, +) -> list[GoldenCaseRead]: + from app.models.golden_case import GoldenCase + from sqlalchemy import select + + cases = db.scalars( + select(GoldenCase).where(GoldenCase.rule_code == rule_code).order_by(GoldenCase.created_at) + ).all() + return [_golden_case_read(case) for case in cases] + + +@router.post( + "/{asset_id}/golden-eval", + response_model=GoldenEvalRead, + summary="手动触发 golden set 评测(不入门禁)", + description="在当前规则版本上跑 golden 用例集,返回指标。门禁由 publish 时自动触发。", +) +def run_golden_eval( + asset_id: str, + body: GoldenEvalRequest, + _: RuleReviewerUser, + db: DbSession, +) -> GoldenEvalRead: + from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY + from app.services.risk_rule_golden_evaluator import RiskRuleGoldenEvaluator + + try: + asset = AgentAssetService(db).get_asset(asset_id) + if asset is None: + raise LookupError("Asset not found") + config = asset.config_json if isinstance(asset.config_json, dict) else {} + rule_document = config.get("rule_document") if isinstance(config.get("rule_document"), dict) else {} + file_name = str(rule_document.get("file_name") or "").strip() + if not file_name: + raise ValueError("该规则没有可执行的 manifest 文件。") + manager = AgentAssetService(db).rule_library_manager + manifest = manager.read_rule_library_json(library=RISK_RULES_LIBRARY, file_name=file_name) + rule_code = str(manifest.get("rule_code") or "").strip() + if not rule_code: + raise ValueError("manifest 缺少 rule_code。") + version = body.version or asset.working_version or "" + report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, manifest, rule_code) + return GoldenEvalRead(**report.to_dict()) + except Exception as exc: + _handle_asset_error(exc) + + +def _golden_case_read(case) -> GoldenCaseRead: + return GoldenCaseRead( + id=case.id, + case_key=case.case_key, + rule_code=case.rule_code, + scene=case.scene or "", + name=case.name or "", + values=case.values_json or {}, + expected_hit=bool(case.expected_hit), + expected_severity=case.expected_severity, + note=case.note, + status=case.status, + source=case.source, + ) diff --git a/server/src/app/schemas/agent_asset.py b/server/src/app/schemas/agent_asset.py index be6b335..0ba9cd6 100644 --- a/server/src/app/schemas/agent_asset.py +++ b/server/src/app/schemas/agent_asset.py @@ -204,6 +204,46 @@ class AgentAssetRiskRuleReportRequest(BaseModel): note: str | None = Field(default=None, max_length=1000) +class GoldenCaseCreate(BaseModel): + case_key: str = Field(..., max_length=160) + rule_code: str | None = Field(default=None, max_length=120) + scene: str = Field(default="", max_length=50) + name: str = Field(default="", max_length=120) + values: dict[str, Any] = Field(default_factory=dict) + expected_hit: bool = True + expected_severity: str | None = Field(default=None, max_length=20) + note: str | None = None + + +class GoldenCaseRead(BaseModel): + id: str + case_key: str + rule_code: str | None = None + scene: str = "" + name: str = "" + values: dict[str, Any] = Field(default_factory=dict) + expected_hit: bool = True + expected_severity: str | None = None + note: str | None = None + status: str = "active" + source: str = "manual" + + +class GoldenEvalRequest(BaseModel): + version: str | None = Field(default=None, max_length=30) + + +class GoldenEvalRead(BaseModel): + total: int = 0 + passed_count: int = 0 + failed_count: int = 0 + accuracy: float = 0.0 + precision: float = 0.0 + recall: float = 0.0 + all_passed: bool = True + results: list[dict[str, Any]] = Field(default_factory=list) + + class AgentAssetRiskRuleSimulationAttachment(BaseModel): name: str = Field(default="", max_length=240) content_type: str | None = Field(default=None, max_length=120) diff --git a/server/tests/test_risk_rule_golden_evaluator.py b/server/tests/test_risk_rule_golden_evaluator.py new file mode 100644 index 0000000..6acf1ff --- /dev/null +++ b/server/tests/test_risk_rule_golden_evaluator.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +from collections.abc import Generator +from datetime import datetime +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.db.base import Base +from app.models.agent_asset import AgentAsset, AgentAssetTestRun +from app.models.employee import Employee +from app.models.financial_record import ExpenseClaim +from app.models.golden_case import GoldenCase +from app.services.risk_rule_golden_evaluator import ( + GoldenEvalReport, + RiskRuleGoldenEvaluator, + _aggregate, + _run_single_case, +) + + +def _build_session() -> Session: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + factory = sessionmaker(bind=engine, autoflush=False, autocommit=False) + return factory() + + +def _keyword_manifest() -> dict: + """一个简单的 keyword_match_v1 manifest:reason 含"虚假"则命中。""" + + return { + "rule_code": "risk.test.keyword", + "template_key": "keyword_match_v1", + "inputs": { + "fields": [ + {"key": "claim.reason", "label": "事由", "type": "text", "source": "claim"}, + ] + }, + "params": { + "keywords": ["虚假"], + "field_keys": ["claim.reason"], + "search_fields": ["claim.reason"], + }, + "outcomes": {"fail": {"severity": "high", "risk_score": 80}}, + } + + +def _golden_case( + case_key: str, + *, + reason: str, + expected_hit: bool, + rule_code: str = "risk.test.keyword", +) -> GoldenCase: + return GoldenCase( + case_key=case_key, + rule_code=rule_code, + name=f"case-{case_key}", + values_json={"claim.reason": reason}, + expected_hit=expected_hit, + status="active", + ) + + +def test_run_single_case_hit_matches() -> None: + result = _run_single_case( + _keyword_manifest(), + values={"claim.reason": "虚假发票报销"}, + expected_hit=True, + expected_severity="high", + ) + assert result.actual_hit is True + assert result.passed is True + assert result.actual_severity == "high" + + +def test_run_single_case_no_hit_matches() -> None: + result = _run_single_case( + _keyword_manifest(), + values={"claim.reason": "正常差旅报销"}, + expected_hit=False, + expected_severity="", + ) + assert result.actual_hit is False + assert result.passed is True + + +def test_run_single_case_mismatch_fails() -> None: + result = _run_single_case( + _keyword_manifest(), + values={"claim.reason": "虚假发票"}, + expected_hit=False, # 期望不命中,但实际命中 + expected_severity="", + ) + assert result.actual_hit is True + assert result.passed is False + + +def test_run_single_case_severity_mismatch_fails() -> None: + result = _run_single_case( + _keyword_manifest(), + values={"claim.reason": "虚假发票"}, + expected_hit=True, + expected_severity="critical", # 实际是 high + ) + assert result.passed is False + + +def test_aggregate_empty_returns_passed() -> None: + report = _aggregate([]) + assert report.total == 0 + assert report.all_passed is True + assert report.accuracy == 0.0 + + +def test_aggregate_all_passed() -> None: + from app.services.risk_rule_golden_evaluator import GoldenCaseResult + + results = [ + GoldenCaseResult("1", "a", True, True, "high", "high", True), + GoldenCaseResult("2", "b", False, False, "", "none", True), + ] + report = _aggregate(results) + assert report.total == 2 + assert report.passed_count == 2 + assert report.accuracy == 1.0 + assert report.all_passed is True + + +def test_aggregate_with_failure() -> None: + from app.services.risk_rule_golden_evaluator import GoldenCaseResult + + results = [ + GoldenCaseResult("1", "a", True, True, "high", "high", True), + GoldenCaseResult("2", "b", True, False, "high", "none", False), # FP + ] + report = _aggregate(results) + assert report.passed_count == 1 + assert report.failed_count == 1 + assert report.accuracy == 0.5 + assert report.all_passed is False + assert report.precision == 0.5 # 1/(1+1) + + +def test_evaluate_for_rule_empty_returns_passed() -> None: + with _build_session() as db: + report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword") + assert report.total == 0 + assert report.all_passed is True + + +def test_evaluate_for_rule_all_pass() -> None: + with _build_session() as db: + db.add(_golden_case("g1", reason="虚假发票", expected_hit=True)) + db.add(_golden_case("g2", reason="正常报销", expected_hit=False)) + db.commit() + report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword") + assert report.total == 2 + assert report.all_passed is True + assert report.accuracy == 1.0 + + +def test_evaluate_for_rule_with_failure() -> None: + with _build_session() as db: + db.add(_golden_case("g1", reason="虚假发票", expected_hit=False)) # 期望不命中但实际命中 + db.add(_golden_case("g2", reason="正常报销", expected_hit=True)) # 期望命中但实际不命中 + db.commit() + report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, _keyword_manifest(), "risk.test.keyword") + assert report.total == 2 + assert report.all_passed is False + assert report.failed_count == 2 + + +def _asset(asset_id: str, code: str) -> AgentAsset: + return AgentAsset( + id=asset_id, + code=code, + name=code, + asset_type="rule", + domain="expense", + owner="tester", + status="review", + working_version="v1", + ) + + +def test_require_pass_passes_when_all_green() -> None: + with _build_session() as db: + asset = _asset("a1", "R1") + db.add(asset) + db.add(_golden_case("g1", reason="虚假", expected_hit=True)) + db.commit() + report = RiskRuleGoldenEvaluator().require_pass( + db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester" + ) + assert report.all_passed is True + # 应写一条 test_type='golden' 记录 + run = db.query(AgentAssetTestRun).filter_by(asset_id="a1", test_type="golden").one() + assert run.passed is True + + +def test_require_pass_raises_on_failure() -> None: + with _build_session() as db: + asset = _asset("a2", "R2") + db.add(asset) + db.add(_golden_case("g1", reason="虚假", expected_hit=False)) # 会失败 + db.commit() + with pytest.raises(PermissionError): + RiskRuleGoldenEvaluator().require_pass( + db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester" + ) + run = db.query(AgentAssetTestRun).filter_by(asset_id="a2", test_type="golden").one() + assert run.passed is False + + +def test_require_pass_empty_golden_set_passes() -> None: + with _build_session() as db: + asset = _asset("a3", "R3") + db.add(asset) + db.commit() + report = RiskRuleGoldenEvaluator().require_pass( + db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester" + ) + assert report.total == 0 + assert report.all_passed is True + + +def test_require_pass_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("GOLDEN_SET_GATE_ENABLED", "false") + with _build_session() as db: + asset = _asset("a4", "R4") + db.add(asset) + db.add(_golden_case("g1", reason="虚假", expected_hit=False)) # 本应失败 + db.commit() + # 门禁关闭,应放行不抛异常 + report = RiskRuleGoldenEvaluator().require_pass( + db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester" + ) + assert report.total == 0 + + +def test_require_pass_swallows_evaluator_exception() -> None: + with _build_session() as db: + asset = _asset("a5", "R5") + db.add(asset) + db.commit() + evaluator = RiskRuleGoldenEvaluator() + with patch.object(evaluator, "evaluate_for_rule", side_effect=RuntimeError("boom")): + report = evaluator.require_pass( + db, asset, "v1", _keyword_manifest(), "risk.test.keyword", actor="tester" + ) + assert report.total == 0 + assert report.all_passed is True # 降级放行