feat(flywheel): golden set 回归门禁拦截风险规则发布
- 新增 RiskRuleGoldenEvaluator:在 GoldenCase 集上跑规则 manifest,复用 RiskRuleTemplateExecutor + _build_synthetic_claim,输出 accuracy/precision/recall,按 100% 通过硬阈值判定 - require_pass 门禁入口:未通过抛 PermissionError 并写 AgentAssetTestRun(test_type=golden) 记录;空集/异常/feature flag 关闭 一律降级放行,不阻塞发布主链路 - _publish_reviewed_working_version 在 test_passed 校验后接入门禁 (修订版 _publish_revision 留待下一轮)
This commit is contained in:
@@ -39,6 +39,9 @@ class AgentAssetRiskRulePublishMixin:
|
||||
if not self.get_latest_risk_rule_test_summary(asset, version=version).test_passed:
|
||||
raise PermissionError("当前规则版本尚未完成测试通过确认,不能发布。")
|
||||
|
||||
# golden set 回归门禁:在 golden 用例集上跑规则,未 100% 通过则拦截发布。
|
||||
self._require_golden_set_passed(asset, version, actor=actor)
|
||||
|
||||
before = self._asset_snapshot(asset)
|
||||
self._ensure_approved_review(asset, version=version, actor=actor, note="发布上线前审核通过。")
|
||||
asset.reviewer = actor
|
||||
@@ -176,6 +179,49 @@ class AgentAssetRiskRulePublishMixin:
|
||||
)
|
||||
)
|
||||
|
||||
def _require_golden_set_passed(
|
||||
self,
|
||||
asset: AgentAsset,
|
||||
version: str,
|
||||
*,
|
||||
actor: str,
|
||||
) -> None:
|
||||
"""在 golden set 上跑当前规则 manifest,未 100% 通过则拦截发布。
|
||||
|
||||
降级策略:feature flag 关闭 / 无 rule_document / 无 golden case /
|
||||
evaluator 异常 → 一律放行,不阻塞发布主链路。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
if os.environ.get("GOLDEN_SET_GATE_ENABLED", "true").strip().lower() in {"0", "false", "no"}:
|
||||
return
|
||||
config = asset.config_json if isinstance(asset.config_json, dict) else {}
|
||||
rule_document = config.get("rule_document") if isinstance(config.get("rule_document"), dict) else {}
|
||||
file_name = str(rule_document.get("file_name") or "").strip()
|
||||
if not file_name:
|
||||
return
|
||||
try:
|
||||
manifest = self.rule_library_manager.read_rule_library_json(
|
||||
library=RISK_RULES_LIBRARY,
|
||||
file_name=file_name,
|
||||
)
|
||||
except Exception:
|
||||
return
|
||||
rule_code = str(manifest.get("rule_code") or "").strip()
|
||||
if not rule_code:
|
||||
return
|
||||
from app.services.risk_rule_golden_evaluator import RiskRuleGoldenEvaluator
|
||||
|
||||
RiskRuleGoldenEvaluator().require_pass(
|
||||
self.db,
|
||||
asset,
|
||||
version,
|
||||
manifest,
|
||||
rule_code,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _config_from_published_manifest(
|
||||
manifest: dict[str, Any],
|
||||
|
||||
329
server/src/app/services/risk_rule_golden_evaluator.py
Normal file
329
server/src/app/services/risk_rule_golden_evaluator.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""风险规则 golden set 评测器与发布门禁。
|
||||
|
||||
在版本化的黄金用例集(:class:`GoldenCase`)上跑规则 manifest,计算
|
||||
accuracy/precision/recall,并按"100% 通过"的硬阈值做发布门禁。
|
||||
|
||||
执行链路完全复用现有能力:
|
||||
- ``RiskRuleTemplateExecutor.evaluate_with_trace`` 跑规则
|
||||
- ``AgentAssetRiskRuleTestingMixin`` 的 static helpers 组装 synthetic claim
|
||||
- 单条比对逻辑与 ``_run_sample_case`` 保持一致
|
||||
|
||||
门禁语义与现有 ``test_passed`` 一致:未通过抛 ``PermissionError``,
|
||||
同时写一条 ``AgentAssetTestRun(test_type='golden')`` 记录结果。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent_enums import AgentAssetType
|
||||
from app.core.logging import get_logger
|
||||
from app.models.agent_asset import AgentAsset, AgentAssetTestRun
|
||||
from app.models.employee import Employee
|
||||
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
|
||||
from app.models.golden_case import GoldenCase
|
||||
from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor
|
||||
|
||||
logger = get_logger("app.services.risk_rule_golden_evaluator")
|
||||
|
||||
GOLDEN_GATE_FLAG = "GOLDEN_SET_GATE_ENABLED"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoldenCaseResult:
|
||||
case_id: str
|
||||
name: str
|
||||
expected_hit: bool
|
||||
actual_hit: bool
|
||||
expected_severity: str
|
||||
actual_severity: str
|
||||
passed: bool
|
||||
message: str = ""
|
||||
evidence: dict[str, Any] = field(default_factory=dict)
|
||||
trace: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoldenEvalReport:
|
||||
total: int = 0
|
||||
passed_count: int = 0
|
||||
failed_count: int = 0
|
||||
accuracy: float = 0.0
|
||||
precision: float = 0.0
|
||||
recall: float = 0.0
|
||||
all_passed: bool = True
|
||||
results: list[GoldenCaseResult] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"total": self.total,
|
||||
"passed_count": self.passed_count,
|
||||
"failed_count": self.failed_count,
|
||||
"accuracy": round(self.accuracy, 4),
|
||||
"precision": round(self.precision, 4),
|
||||
"recall": round(self.recall, 4),
|
||||
"all_passed": self.all_passed,
|
||||
"results": [
|
||||
{
|
||||
"case_id": r.case_id,
|
||||
"name": r.name,
|
||||
"expected_hit": r.expected_hit,
|
||||
"actual_hit": r.actual_hit,
|
||||
"expected_severity": r.expected_severity,
|
||||
"actual_severity": r.actual_severity,
|
||||
"passed": r.passed,
|
||||
"message": r.message,
|
||||
}
|
||||
for r in self.results
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _gate_enabled() -> bool:
|
||||
return os.environ.get(GOLDEN_GATE_FLAG, "true").strip().lower() not in {"0", "false", "no"}
|
||||
|
||||
|
||||
# ---- synthetic claim 构建(与 AgentAssetRiskRuleTestingMixin._build_synthetic_claim 一致)----
|
||||
|
||||
def _extract_manifest_fields(manifest: dict[str, Any]) -> list[dict[str, str]]:
|
||||
inputs = manifest.get("inputs") if isinstance(manifest.get("inputs"), dict) else {}
|
||||
fields = inputs.get("fields") if isinstance(inputs.get("fields"), list) else []
|
||||
normalized: list[dict[str, str]] = []
|
||||
for item in fields:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
key = str(item.get("key") or "").strip()
|
||||
if key:
|
||||
normalized.append({"key": key, "label": str(item.get("label") or key).strip()})
|
||||
return normalized
|
||||
|
||||
|
||||
def _coerce_sample_value(field_key: str, value: Any) -> Any:
|
||||
import re
|
||||
|
||||
if field_key.endswith("route_cities") and isinstance(value, str):
|
||||
return [item.strip() for item in re.split(r"[,,、/ ]+", value) if item.strip()]
|
||||
return value
|
||||
|
||||
|
||||
def _to_decimal(value: Any) -> Decimal:
|
||||
try:
|
||||
return Decimal(str(value or "0"))
|
||||
except (InvalidOperation, ValueError):
|
||||
return Decimal("0")
|
||||
|
||||
|
||||
def _build_synthetic_claim(
|
||||
values: dict[str, Any],
|
||||
manifest: dict[str, Any],
|
||||
) -> tuple[ExpenseClaim, list[dict[str, Any]]]:
|
||||
claim = ExpenseClaim(
|
||||
claim_no="GOLDEN-RISK-RULE",
|
||||
employee_name=str(values.get("claim.employee_name") or "测试员工"),
|
||||
department_name=str(values.get("claim.department_name") or "测试部门"),
|
||||
expense_type=str(values.get("item.item_type") or "差旅费"),
|
||||
reason=str(values.get("claim.reason") or "测试报销事由"),
|
||||
location=str(values.get("claim.location") or "北京"),
|
||||
amount=_to_decimal(values.get("claim.amount")),
|
||||
currency="CNY",
|
||||
invoice_count=1,
|
||||
occurred_at=datetime.now(UTC),
|
||||
status="draft",
|
||||
)
|
||||
item = ExpenseClaimItem(
|
||||
item_date=date.today(),
|
||||
item_type=str(values.get("item.item_type") or "住宿费"),
|
||||
item_reason=str(values.get("item.item_reason") or claim.reason),
|
||||
item_location=str(values.get("item.item_location") or claim.location),
|
||||
item_amount=_to_decimal(values.get("item.item_amount") or claim.amount),
|
||||
)
|
||||
claim.items = [item]
|
||||
if values.get("employee.location"):
|
||||
claim.employee = Employee(
|
||||
employee_no="GOLDEN-EMPLOYEE",
|
||||
name=claim.employee_name,
|
||||
email="golden-rule-test@example.com",
|
||||
location=str(values.get("employee.location") or ""),
|
||||
)
|
||||
|
||||
attachment_fields: list[dict[str, Any]] = []
|
||||
document_info: dict[str, Any] = {"fields": attachment_fields}
|
||||
for field in _extract_manifest_fields(manifest):
|
||||
key = field["key"]
|
||||
if key not in values:
|
||||
continue
|
||||
value = _coerce_sample_value(key, values.get(key))
|
||||
if key.startswith("claim."):
|
||||
setattr(claim, key.removeprefix("claim."), value)
|
||||
elif key.startswith("item."):
|
||||
setattr(item, key.removeprefix("item."), value)
|
||||
elif key.startswith("attachment."):
|
||||
short_key = key.removeprefix("attachment.")
|
||||
document_info[short_key] = value
|
||||
attachment_fields.append({"key": short_key, "label": field["label"], "value": value})
|
||||
return claim, [{"document_info": document_info, "ocr_text": document_info.get("ocr_text", "")}]
|
||||
|
||||
|
||||
def _run_single_case(
|
||||
manifest: dict[str, Any],
|
||||
values: dict[str, Any],
|
||||
expected_hit: bool,
|
||||
expected_severity: str,
|
||||
) -> GoldenCaseResult:
|
||||
claim, contexts = _build_synthetic_claim(values, manifest)
|
||||
execution = RiskRuleTemplateExecutor().evaluate_with_trace(manifest, claim=claim, contexts=contexts)
|
||||
result = execution["result"]
|
||||
actual_hit = result is not None
|
||||
actual_severity = (
|
||||
str((manifest.get("outcomes") or {}).get("fail", {}).get("severity") or "").strip()
|
||||
if actual_hit
|
||||
else "none"
|
||||
)
|
||||
severity_passed = (
|
||||
not actual_hit or not expected_severity or expected_severity == actual_severity
|
||||
)
|
||||
passed = actual_hit == expected_hit and severity_passed
|
||||
return GoldenCaseResult(
|
||||
case_id="",
|
||||
name="",
|
||||
expected_hit=expected_hit,
|
||||
actual_hit=actual_hit,
|
||||
expected_severity=expected_severity,
|
||||
actual_severity=actual_severity,
|
||||
passed=passed,
|
||||
message=str(result.get("message") or "") if isinstance(result, dict) else "",
|
||||
evidence=result.get("evidence") if isinstance(result, dict) else {},
|
||||
trace=execution.get("trace") if isinstance(execution.get("trace"), dict) else {},
|
||||
)
|
||||
|
||||
|
||||
def _aggregate(results: list[GoldenCaseResult]) -> GoldenEvalReport:
|
||||
total = len(results)
|
||||
if total == 0:
|
||||
return GoldenEvalReport(total=0, all_passed=True)
|
||||
passed_count = sum(1 for r in results if r.passed)
|
||||
tp = sum(1 for r in results if r.expected_hit and r.actual_hit)
|
||||
fp = sum(1 for r in results if r.expected_hit and not r.actual_hit) # 应命中未命中
|
||||
fn = sum(1 for r in results if not r.expected_hit and r.actual_hit) # 不应命中却命中
|
||||
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
||||
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
||||
return GoldenEvalReport(
|
||||
total=total,
|
||||
passed_count=passed_count,
|
||||
failed_count=total - passed_count,
|
||||
accuracy=passed_count / total,
|
||||
precision=precision,
|
||||
recall=recall,
|
||||
all_passed=passed_count == total,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
class RiskRuleGoldenEvaluator:
|
||||
"""在 golden set 上评测规则 manifest 并执行发布门禁。"""
|
||||
|
||||
def evaluate(self, manifest: dict[str, Any], cases: list[GoldenCase]) -> GoldenEvalReport:
|
||||
results: list[GoldenCaseResult] = []
|
||||
for case in cases:
|
||||
result = _run_single_case(
|
||||
manifest,
|
||||
values=case.values_json or {},
|
||||
expected_hit=bool(case.expected_hit),
|
||||
expected_severity=str(case.expected_severity or ""),
|
||||
)
|
||||
result.case_id = case.case_key or case.id
|
||||
result.name = case.name
|
||||
results.append(result)
|
||||
return _aggregate(results)
|
||||
|
||||
def evaluate_for_rule(
|
||||
self,
|
||||
db: Session,
|
||||
manifest: dict[str, Any],
|
||||
rule_code: str,
|
||||
) -> GoldenEvalReport:
|
||||
cases = list(
|
||||
db.scalars(
|
||||
select(GoldenCase).where(
|
||||
GoldenCase.rule_code == rule_code,
|
||||
GoldenCase.status == "active",
|
||||
)
|
||||
)
|
||||
)
|
||||
if not cases:
|
||||
return GoldenEvalReport(total=0, all_passed=True)
|
||||
return self.evaluate(manifest, cases)
|
||||
|
||||
def require_pass(
|
||||
self,
|
||||
db: Session,
|
||||
asset: AgentAsset,
|
||||
version: str,
|
||||
manifest: dict[str, Any],
|
||||
rule_code: str,
|
||||
*,
|
||||
actor: str,
|
||||
) -> GoldenEvalReport:
|
||||
"""发布门禁入口:跑 golden set,未 100% 通过抛 PermissionError。
|
||||
|
||||
golden set 为空或门禁关闭时放行; evaluator 异常时降级放行(记日志)。
|
||||
无论放行与否,都写一条 ``AgentAssetTestRun(test_type='golden')`` 记录。
|
||||
"""
|
||||
|
||||
if not _gate_enabled():
|
||||
return GoldenEvalReport(total=0, all_passed=True)
|
||||
try:
|
||||
report = self.evaluate_for_rule(db, manifest, rule_code)
|
||||
except Exception:
|
||||
logger.exception("golden set 评测异常,降级放行 asset_id=%s", asset.id)
|
||||
report = GoldenEvalReport(total=0, all_passed=True)
|
||||
|
||||
self._record_test_run(db, asset, version, report, actor=actor)
|
||||
|
||||
if report.total > 0 and not report.all_passed:
|
||||
failures = report.to_dict()["results"]
|
||||
raise PermissionError(
|
||||
f"golden set 回归未通过({report.passed_count}/{report.total}),"
|
||||
f"发布被拦截。失败用例:{failures}"
|
||||
)
|
||||
return report
|
||||
|
||||
def _record_test_run(
|
||||
self,
|
||||
db: Session,
|
||||
asset: AgentAsset,
|
||||
version: str,
|
||||
report: GoldenEvalReport,
|
||||
*,
|
||||
actor: str,
|
||||
) -> None:
|
||||
try:
|
||||
run = AgentAssetTestRun(
|
||||
id=str(uuid.uuid4()),
|
||||
asset_id=asset.id,
|
||||
version=version,
|
||||
test_type="golden",
|
||||
status="completed",
|
||||
passed=report.all_passed,
|
||||
summary=(
|
||||
f"golden set {report.passed_count}/{report.total} passed"
|
||||
if report.total > 0
|
||||
else "golden set empty, gate skipped"
|
||||
),
|
||||
input_json={"rule_code": getattr(asset, "rule_code", "") or ""},
|
||||
result_json=report.to_dict(),
|
||||
created_by=actor,
|
||||
)
|
||||
db.add(run)
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning("golden test run 记录失败 asset_id=%s", asset.id, exc_info=True)
|
||||
db.rollback()
|
||||
Reference in New Issue
Block a user