feat(flywheel): golden case 管理 API 与评测单测
- 新增 GoldenCaseCreate/Read、GoldenEvalRequest/Read schema
- 新增 3 个端点:创建 golden case、按规则列表、手动触发 golden 评测 (不入门禁,供运营试跑)
- 单测 15 passed:单条 hit/severity 比对、集合 accuracy/precision/recall 聚合、空集降级、100% 通过/失败拦截、feature flag、异常降级
- 回归 test_agent_asset_service 27 passed(1 个预存失败与本改动无关)
This commit is contained in:
@@ -43,6 +43,10 @@ from app.schemas.agent_asset import (
|
|||||||
AgentAssetVersionCreate,
|
AgentAssetVersionCreate,
|
||||||
AgentAssetVersionRead,
|
AgentAssetVersionRead,
|
||||||
AgentAssetVersionTimelineItemRead,
|
AgentAssetVersionTimelineItemRead,
|
||||||
|
GoldenCaseCreate,
|
||||||
|
GoldenCaseRead,
|
||||||
|
GoldenEvalRead,
|
||||||
|
GoldenEvalRequest,
|
||||||
)
|
)
|
||||||
from app.schemas.common import ErrorResponse, PaginatedResponse
|
from app.schemas.common import ErrorResponse, PaginatedResponse
|
||||||
from app.services.agent_assets import AgentAssetService
|
from app.services.agent_assets import AgentAssetService
|
||||||
@@ -923,3 +927,110 @@ def get_agent_asset_version_timeline(
|
|||||||
return AgentAssetService(db).list_version_timeline(asset_id)
|
return AgentAssetService(db).list_version_timeline(asset_id)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
_handle_asset_error(exc)
|
_handle_asset_error(exc)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/risk-rules/golden-cases",
    response_model=GoldenCaseRead,
    status_code=status.HTTP_201_CREATED,
    summary="创建 golden set 黄金用例",
    description="为指定规则(或通用场景)创建一条回归用例,发布前作为门禁集执行。",
)
def create_golden_case(
    body: GoldenCaseCreate,
    _: RuleEditorUser,
    db: DbSession,
) -> GoldenCaseRead:
    # Store a new golden case built from the request body and return it as a
    # GoldenCaseRead. A request whose case_key is already present is rejected
    # with HTTP 409 before anything is written.
    from sqlalchemy import select

    from app.models.golden_case import GoldenCase

    duplicate_lookup = select(GoldenCase).where(GoldenCase.case_key == body.case_key)
    if db.scalar(duplicate_lookup) is not None:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="case_key 已存在")

    # Rows created through this endpoint always start out active and are
    # tagged as manually entered.
    attributes = {
        "case_key": body.case_key,
        "rule_code": body.rule_code,
        "scene": body.scene,
        "name": body.name,
        "values_json": body.values,
        "expected_hit": body.expected_hit,
        "expected_severity": body.expected_severity,
        "note": body.note,
        "status": "active",
        "source": "manual",
    }
    record = GoldenCase(**attributes)
    db.add(record)
    db.commit()
    # Reload the row before converting it to the response schema.
    db.refresh(record)
    return _golden_case_read(record)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/risk-rules/{rule_code}/golden-cases",
    response_model=list[GoldenCaseRead],
    summary="列出规则的 golden 用例",
)
def list_golden_cases(
    rule_code: str,
    _: CurrentUser,
    db: DbSession,
) -> list[GoldenCaseRead]:
    # Return every golden case whose rule_code equals the path parameter,
    # ordered by created_at. Documented with comments rather than a docstring
    # because this route sets no explicit description.
    from sqlalchemy import select

    from app.models.golden_case import GoldenCase

    query = select(GoldenCase).where(GoldenCase.rule_code == rule_code)
    query = query.order_by(GoldenCase.created_at)
    rows = db.scalars(query).all()
    return list(map(_golden_case_read, rows))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/{asset_id}/golden-eval",
    response_model=GoldenEvalRead,
    summary="手动触发 golden set 评测(不入门禁)",
    description="在当前规则版本上跑 golden 用例集,返回指标。门禁由 publish 时自动触发。",
)
def run_golden_eval(
    asset_id: str,
    body: GoldenEvalRequest,
    _: RuleReviewerUser,
    db: DbSession,
) -> GoldenEvalRead:
    # Run the golden cases for the rule behind ``asset_id`` and return the
    # evaluator's report as a GoldenEvalRead.
    #
    # Steps: load the asset, read the manifest file named in
    # config_json["rule_document"]["file_name"] from the risk-rules library,
    # take the manifest's rule_code, then hand both to the evaluator.
    #
    # Every failure (missing asset, missing file name, missing rule_code, or
    # anything raised by the services) is routed through _handle_asset_error.
    #
    # NOTE(review): body.version is accepted but not forwarded. The evaluator
    # call takes no version argument, so the evaluation runs against whatever
    # manifest the library returns. Confirm whether version selection is
    # still planned.
    from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
    from app.services.risk_rule_golden_evaluator import RiskRuleGoldenEvaluator

    try:
        # One service instance serves both the asset lookup and the library read.
        service = AgentAssetService(db)
        asset = service.get_asset(asset_id)
        if asset is None:
            raise LookupError("Asset not found")

        # Tolerate a missing or malformed config: anything that is not a dict
        # is treated as empty so the file-name check below reports the problem.
        config = asset.config_json if isinstance(asset.config_json, dict) else {}
        rule_document = config.get("rule_document")
        if not isinstance(rule_document, dict):
            rule_document = {}

        file_name = str(rule_document.get("file_name") or "").strip()
        if not file_name:
            raise ValueError("该规则没有可执行的 manifest 文件。")

        manifest = service.rule_library_manager.read_rule_library_json(
            library=RISK_RULES_LIBRARY,
            file_name=file_name,
        )
        rule_code = str(manifest.get("rule_code") or "").strip()
        if not rule_code:
            raise ValueError("manifest 缺少 rule_code。")

        report = RiskRuleGoldenEvaluator().evaluate_for_rule(db, manifest, rule_code)
        return GoldenEvalRead(**report.to_dict())
    except Exception as exc:
        # Presumably maps the exception to an HTTP error and raises; verify
        # against _handle_asset_error, since nothing is returned afterwards.
        _handle_asset_error(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _golden_case_read(case) -> GoldenCaseRead:
    """Convert a golden-case record into its ``GoldenCaseRead`` schema.

    Falsy ``scene`` and ``name`` become empty strings, falsy ``values_json``
    becomes an empty dict, and ``expected_hit`` is coerced to ``bool``. All
    other attributes are copied through unchanged.
    """
    payload = {
        "id": case.id,
        "case_key": case.case_key,
        "rule_code": case.rule_code,
        "scene": case.scene or "",
        "name": case.name or "",
        "values": case.values_json or {},
        "expected_hit": bool(case.expected_hit),
        "expected_severity": case.expected_severity,
        "note": case.note,
        "status": case.status,
        "source": case.source,
    }
    return GoldenCaseRead(**payload)
|
||||||
|
|||||||
@@ -204,6 +204,46 @@ class AgentAssetRiskRuleReportRequest(BaseModel):
|
|||||||
note: str | None = Field(default=None, max_length=1000)
|
note: str | None = Field(default=None, max_length=1000)
|
||||||
|
|
||||||
|
|
||||||
|
class GoldenCaseCreate(BaseModel):
    # Request payload for creating a golden case. Only case_key is required;
    # every other field has a default. Documented with comments because a
    # class docstring would be emitted into the generated schema.
    case_key: str = Field(..., max_length=160)
    rule_code: str | None = Field(default=None, max_length=120)
    scene: str = Field(default="", max_length=50)
    name: str = Field(default="", max_length=120)
    # Free-form mapping of input values for the case.
    values: dict[str, Any] = Field(default_factory=dict)
    expected_hit: bool = Field(default=True)
    expected_severity: str | None = Field(default=None, max_length=20)
    note: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class GoldenCaseRead(BaseModel):
    # Response representation of a stored golden case. id and case_key are
    # required; the remaining fields fall back to the defaults below.
    id: str
    case_key: str
    rule_code: str | None = Field(default=None)
    scene: str = Field(default="")
    name: str = Field(default="")
    values: dict[str, Any] = Field(default_factory=dict)
    expected_hit: bool = Field(default=True)
    expected_severity: str | None = Field(default=None)
    note: str | None = Field(default=None)
    status: str = Field(default="active")
    source: str = Field(default="manual")
|
||||||
|
|
||||||
|
|
||||||
|
class GoldenEvalRequest(BaseModel):
    # Request payload for a manual golden evaluation.
    # version is an optional label of at most 30 characters.
    version: str | None = Field(
        default=None,
        max_length=30,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class GoldenEvalRead(BaseModel):
    # Response payload carrying the aggregate figures of a golden evaluation
    # plus one free-form dict per evaluated case in results.
    total: int = Field(default=0)
    passed_count: int = Field(default=0)
    failed_count: int = Field(default=0)
    accuracy: float = Field(default=0.0)
    precision: float = Field(default=0.0)
    recall: float = Field(default=0.0)
    all_passed: bool = Field(default=True)
    results: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class AgentAssetRiskRuleSimulationAttachment(BaseModel):
|
class AgentAssetRiskRuleSimulationAttachment(BaseModel):
|
||||||
name: str = Field(default="", max_length=240)
|
name: str = Field(default="", max_length=240)
|
||||||
content_type: str | None = Field(default=None, max_length=120)
|
content_type: str | None = Field(default=None, max_length=120)
|
||||||
|
|||||||
262
server/tests/test_risk_rule_golden_evaluator.py
Normal file
262
server/tests/test_risk_rule_golden_evaluator.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
from datetime import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.db.base import Base
|
||||||
|
from app.models.agent_asset import AgentAsset, AgentAssetTestRun
|
||||||
|
from app.models.employee import Employee
|
||||||
|
from app.models.financial_record import ExpenseClaim
|
||||||
|
from app.models.golden_case import GoldenCase
|
||||||
|
from app.services.risk_rule_golden_evaluator import (
|
||||||
|
GoldenEvalReport,
|
||||||
|
RiskRuleGoldenEvaluator,
|
||||||
|
_aggregate,
|
||||||
|
_run_single_case,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_session() -> Session:
    """Return a new session bound to a private in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created first. The engine
    uses ``StaticPool`` with ``check_same_thread`` disabled.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    make_session = sessionmaker(bind=memory_engine, autocommit=False, autoflush=False)
    return make_session()
|
||||||
|
|
||||||
|
|
||||||
|
def _keyword_manifest() -> dict:
|
||||||
|
"""一个简单的 keyword_match_v1 manifest:reason 含"虚假"则命中。"""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"rule_code": "risk.test.keyword",
|
||||||
|
"template_key": "keyword_match_v1",
|
||||||
|
"inputs": {
|
||||||
|
"fields": [
|
||||||
|
{"key": "claim.reason", "label": "事由", "type": "text", "source": "claim"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"params": {
|
||||||
|
"keywords": ["虚假"],
|
||||||
|
"field_keys": ["claim.reason"],
|
||||||
|
"search_fields": ["claim.reason"],
|
||||||
|
},
|
||||||
|
"outcomes": {"fail": {"severity": "high", "risk_score": 80}},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _golden_case(
    case_key: str,
    *,
    reason: str,
    expected_hit: bool,
    rule_code: str = "risk.test.keyword",
) -> GoldenCase:
    """Build an unsaved, active ``GoldenCase`` whose only input is ``claim.reason``.

    The name is derived from ``case_key`` as ``case-<case_key>``.
    """
    display_name = f"case-{case_key}"
    input_values = {"claim.reason": reason}
    return GoldenCase(
        case_key=case_key,
        rule_code=rule_code,
        name=display_name,
        values_json=input_values,
        expected_hit=expected_hit,
        status="active",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_single_case_hit_matches() -> None:
    """A reason containing the keyword hits and passes with severity ``high``."""
    manifest = _keyword_manifest()
    inputs = {"claim.reason": "虚假发票报销"}

    outcome = _run_single_case(
        manifest,
        values=inputs,
        expected_hit=True,
        expected_severity="high",
    )

    assert outcome.actual_hit is True
    assert outcome.passed is True
    assert outcome.actual_severity == "high"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_single_case_no_hit_matches() -> None:
    """A reason without the keyword does not hit, and expecting no hit passes."""
    manifest = _keyword_manifest()
    inputs = {"claim.reason": "正常差旅报销"}

    outcome = _run_single_case(
        manifest,
        values=inputs,
        expected_hit=False,
        expected_severity="",
    )

    assert outcome.actual_hit is False
    assert outcome.passed is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_single_case_mismatch_fails() -> None:
    """A case expecting no hit fails when the rule actually hits."""
    manifest = _keyword_manifest()
    inputs = {"claim.reason": "虚假发票"}

    # Expectation says "no hit", but the keyword is present so the rule hits.
    outcome = _run_single_case(
        manifest,
        values=inputs,
        expected_hit=False,
        expected_severity="",
    )

    assert outcome.actual_hit is True
    assert outcome.passed is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_single_case_severity_mismatch_fails() -> None:
    """A hit with the wrong expected severity is reported as failed."""
    manifest = _keyword_manifest()
    inputs = {"claim.reason": "虚假发票"}

    # The manifest's fail outcome is "high", so expecting "critical" must fail.
    outcome = _run_single_case(
        manifest,
        values=inputs,
        expected_hit=True,
        expected_severity="critical",
    )

    assert outcome.passed is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_empty_returns_passed() -> None:
    """Aggregating no results gives total 0, accuracy 0.0 and ``all_passed``."""
    no_results: list = []

    summary = _aggregate(no_results)

    assert summary.total == 0
    assert summary.all_passed is True
    assert summary.accuracy == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_all_passed() -> None:
    """Two passing results aggregate to accuracy 1.0 with ``all_passed`` set."""
    from app.services.risk_rule_golden_evaluator import GoldenCaseResult

    first = GoldenCaseResult("1", "a", True, True, "high", "high", True)
    second = GoldenCaseResult("2", "b", False, False, "", "none", True)

    summary = _aggregate([first, second])

    assert summary.total == 2
    assert summary.passed_count == 2
    assert summary.accuracy == 1.0
    assert summary.all_passed is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_with_failure() -> None:
    """One passing and one failing result halve accuracy and precision."""
    from app.services.risk_rule_golden_evaluator import GoldenCaseResult

    passing = GoldenCaseResult("1", "a", True, True, "high", "high", True)
    # Marked as a false positive by the original author's note.
    failing = GoldenCaseResult("2", "b", True, False, "high", "none", False)

    summary = _aggregate([passing, failing])

    assert summary.passed_count == 1
    assert summary.failed_count == 1
    assert summary.accuracy == 0.5
    assert summary.all_passed is False
    assert summary.precision == 0.5  # 1 / (1 + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_for_rule_empty_returns_passed() -> None:
    """With no golden cases stored, the report is empty and counts as passed."""
    manifest = _keyword_manifest()
    with _build_session() as session:
        summary = RiskRuleGoldenEvaluator().evaluate_for_rule(
            session, manifest, "risk.test.keyword"
        )
        assert summary.total == 0
        assert summary.all_passed is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_for_rule_all_pass() -> None:
    """Two stored cases whose expectations match the rule both pass."""
    manifest = _keyword_manifest()
    with _build_session() as session:
        session.add_all(
            [
                _golden_case("g1", reason="虚假发票", expected_hit=True),
                _golden_case("g2", reason="正常报销", expected_hit=False),
            ]
        )
        session.commit()

        summary = RiskRuleGoldenEvaluator().evaluate_for_rule(
            session, manifest, "risk.test.keyword"
        )

        assert summary.total == 2
        assert summary.all_passed is True
        assert summary.accuracy == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_for_rule_with_failure() -> None:
    """Two stored cases with inverted expectations are both counted as failed."""
    manifest = _keyword_manifest()
    with _build_session() as session:
        session.add_all(
            [
                # Expects no hit, but the rule hits.
                _golden_case("g1", reason="虚假发票", expected_hit=False),
                # Expects a hit, but the rule does not hit.
                _golden_case("g2", reason="正常报销", expected_hit=True),
            ]
        )
        session.commit()

        summary = RiskRuleGoldenEvaluator().evaluate_for_rule(
            session, manifest, "risk.test.keyword"
        )

        assert summary.total == 2
        assert summary.all_passed is False
        assert summary.failed_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def _asset(asset_id: str, code: str) -> AgentAsset:
    """Build an unsaved rule asset in ``review`` status at working version ``v1``.

    ``code`` is used for both the asset code and its name.
    """
    fixed_fields = {
        "asset_type": "rule",
        "domain": "expense",
        "owner": "tester",
        "status": "review",
        "working_version": "v1",
    }
    return AgentAsset(id=asset_id, code=code, name=code, **fixed_fields)
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_pass_passes_when_all_green() -> None:
    """A fully passing golden set passes the gate and records a passing run."""
    manifest = _keyword_manifest()
    with _build_session() as session:
        rule_asset = _asset("a1", "R1")
        session.add_all(
            [rule_asset, _golden_case("g1", reason="虚假", expected_hit=True)]
        )
        session.commit()

        summary = RiskRuleGoldenEvaluator().require_pass(
            session, rule_asset, "v1", manifest, "risk.test.keyword", actor="tester"
        )
        assert summary.all_passed is True

        # Exactly one test run with test_type='golden' should have been written.
        golden_runs = session.query(AgentAssetTestRun).filter_by(
            asset_id="a1", test_type="golden"
        )
        recorded = golden_runs.one()
        assert recorded.passed is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_pass_raises_on_failure() -> None:
    """A failing golden set raises ``PermissionError`` and records a failed run."""
    manifest = _keyword_manifest()
    with _build_session() as session:
        rule_asset = _asset("a2", "R2")
        # This case expects no hit although the keyword is present, so it fails.
        failing_case = _golden_case("g1", reason="虚假", expected_hit=False)
        session.add_all([rule_asset, failing_case])
        session.commit()

        with pytest.raises(PermissionError):
            RiskRuleGoldenEvaluator().require_pass(
                session, rule_asset, "v1", manifest, "risk.test.keyword", actor="tester"
            )

        golden_runs = session.query(AgentAssetTestRun).filter_by(
            asset_id="a2", test_type="golden"
        )
        recorded = golden_runs.one()
        assert recorded.passed is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_pass_empty_golden_set_passes() -> None:
    """With no golden cases stored, the gate returns an empty passing report."""
    manifest = _keyword_manifest()
    with _build_session() as session:
        rule_asset = _asset("a3", "R3")
        session.add(rule_asset)
        session.commit()

        summary = RiskRuleGoldenEvaluator().require_pass(
            session, rule_asset, "v1", manifest, "risk.test.keyword", actor="tester"
        )

        assert summary.total == 0
        assert summary.all_passed is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_pass_respects_feature_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``GOLDEN_SET_GATE_ENABLED=false`` a failing case does not block."""
    monkeypatch.setenv("GOLDEN_SET_GATE_ENABLED", "false")
    manifest = _keyword_manifest()
    with _build_session() as session:
        rule_asset = _asset("a4", "R4")
        # This case would fail if the gate were enabled.
        would_fail = _golden_case("g1", reason="虚假", expected_hit=False)
        session.add_all([rule_asset, would_fail])
        session.commit()

        # Gate is switched off: the call must return without raising.
        summary = RiskRuleGoldenEvaluator().require_pass(
            session, rule_asset, "v1", manifest, "risk.test.keyword", actor="tester"
        )

        assert summary.total == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_pass_swallows_evaluator_exception() -> None:
    """An exception raised during evaluation degrades to an empty passing report."""
    manifest = _keyword_manifest()
    with _build_session() as session:
        rule_asset = _asset("a5", "R5")
        session.add(rule_asset)
        session.commit()

        gate = RiskRuleGoldenEvaluator()
        exploding = patch.object(gate, "evaluate_for_rule", side_effect=RuntimeError("boom"))
        with exploding:
            summary = gate.require_pass(
                session, rule_asset, "v1", manifest, "risk.test.keyword", actor="tester"
            )

        assert summary.total == 0
        assert summary.all_passed is True  # degraded: let the publish through
|
||||||
Reference in New Issue
Block a user