from __future__ import annotations

import re
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal, InvalidOperation
from typing import Any

from sqlalchemy import or_, select

from app.core.agent_enums import (
    AgentAssetDomain,
    AgentAssetStatus,
    AgentAssetType,
    AgentReviewStatus,
)
from app.models.agent_asset import AgentAsset, AgentAssetReview, AgentAssetTestRun
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.agent_asset import (
    AgentAssetRiskRuleLatestTestSummary,
    AgentAssetRiskRuleReportRequest,
    AgentAssetRiskRuleSampleCase,
    AgentAssetRiskRuleSampleTestRequest,
    AgentAssetRiskRuleScenarioTestRequest,
    AgentAssetRiskRuleTestRunRead,
)
from app.services.expense_claims import ExpenseClaimService
from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor
from app.services.risk_rule_manifest_normalizer import normalize_risk_rule_manifest


def _digits_to_int(value: Any, default: int) -> int:
    """Return ``int(value)`` when its text form is all digits, else *default*.

    ``str.isdigit`` also accepts characters (for example superscript digits)
    that ``int`` refuses to parse, so the conversion itself is guarded too.
    """
    text = str(value or "")
    if not text.isdigit():
        return default
    try:
        return int(text)
    except ValueError:
        return default


class AgentAssetRiskRuleTestingMixin:
    """Testing, review and publishing operations for JSON risk-rule assets.

    This is a mixin: it uses ``self.repository``, ``self.db``,
    ``self.audit_service`` and ``self.rule_library_manager`` as well as the
    methods ``_asset_snapshot``, ``_resolve_json_risk_rule_document`` and
    ``_resolve_working_version``, none of which are defined in this class and
    must be supplied by the host class.
    """

    def get_latest_risk_rule_test_summary(
        self,
        asset_or_id: AgentAsset | str,
        *,
        version: str | None = None,
    ) -> AgentAssetRiskRuleLatestTestSummary:
        """Return the latest sample, scenario and passed-report runs of a version.

        Args:
            asset_or_id: An asset instance or the id of one.
            version: Version to inspect; defaults to the asset's working version.

        Raises:
            LookupError: If the asset id cannot be resolved.
            ValueError: If no target version can be determined.
        """
        asset = self._resolve_asset(asset_or_id)
        target_version = self._resolve_target_version(asset, version)
        sample = self.repository.get_latest_test_run(
            asset.id, version=target_version, test_type="sample"
        )
        scenario = self.repository.get_latest_test_run(
            asset.id, version=target_version, test_type="scenario"
        )
        # Only a report stored with status "passed" is looked up here.
        report = self.repository.get_latest_test_run(
            asset.id, version=target_version, test_type="report", status="passed"
        )
        return AgentAssetRiskRuleLatestTestSummary(
            version=target_version,
            sample=self._serialize_test_run(sample),
            scenario=self._serialize_test_run(scenario),
            report=self._serialize_test_run(report),
            test_passed=bool(report and report.passed),
        )

    def run_risk_rule_sample_test(
        self,
        asset_id: str,
        body: AgentAssetRiskRuleSampleTestRequest,
        *,
        actor: str,
        request_id: str | None = None,
    ) -> AgentAssetRiskRuleTestRunRead:
        """Run the quick sample test and persist the outcome as a test run.

        Cases from ``body.cases`` are used when provided; otherwise default
        cases are generated from the manifest. The run passes only when there
        is at least one case and every case passes.
        """
        asset, version, manifest = self._load_risk_rule_for_test(asset_id, body.version)
        cases = body.cases or self._build_default_sample_cases(manifest)
        results = [self._run_sample_case(manifest, case) for case in cases]
        passed = bool(results) and all(item["passed"] for item in results)
        summary = f"快速样例测试 {'通过' if passed else '未通过'},共 {len(results)} 条。"
        return self._create_test_run(
            asset,
            version=version,
            test_type="sample",
            passed=passed,
            summary=summary,
            input_json={"cases": [case.model_dump() for case in cases]},
            result_json={"cases": results, "case_count": len(results)},
            actor=actor,
            request_id=request_id,
        )

    def run_risk_rule_scenario_test(
        self,
        asset_id: str,
        body: AgentAssetRiskRuleScenarioTestRequest,
        *,
        actor: str,
        request_id: str | None = None,
    ) -> AgentAssetRiskRuleTestRunRead:
        """Evaluate the rule against stored expense claims and persist the result.

        The run counts as passed as soon as at least one claim was evaluated,
        regardless of how many claims hit the rule.

        Raises:
            ValueError: If the asset's domain is not the expense domain.
        """
        asset, version, manifest = self._load_risk_rule_for_test(asset_id, body.version)
        if asset.domain != AgentAssetDomain.EXPENSE.value:
            raise ValueError("一期真实场景试运行仅支持报销业务域。")

        parsed_scope = self._parse_scenario_scope(body.intent, body.filters)
        claims = self._query_expense_claim_samples(parsed_scope)
        claim_results = [self._run_claim_scenario(manifest, claim) for claim in claims]
        hit_items = [item for item in claim_results if item["hit"]]

        severity_counts: dict[str, int] = {}
        for item in hit_items:
            severity = str(item.get("severity") or "unknown")
            severity_counts[severity] = severity_counts.get(severity, 0) + 1

        passed = bool(claim_results)
        summary = (
            f"真实场景试运行完成,样本 {len(claim_results)} 条,命中 {len(hit_items)} 条。"
            if passed
            else "真实场景试运行未找到可测样本。"
        )
        return self._create_test_run(
            asset,
            version=version,
            test_type="scenario",
            passed=passed,
            summary=summary,
            input_json={
                "intent": body.intent,
                "filters": body.filters,
                "parsed_scope": parsed_scope,
            },
            result_json={
                "total_count": len(claim_results),
                "hit_count": len(hit_items),
                "severity_counts": severity_counts,
                # Only the first 50 per-claim results are stored.
                "items": claim_results[:50],
            },
            actor=actor,
            request_id=request_id,
        )

    def confirm_risk_rule_test_report(
        self,
        asset_id: str,
        body: AgentAssetRiskRuleReportRequest,
        *,
        actor: str,
        request_id: str | None = None,
    ) -> AgentAssetRiskRuleTestRunRead:
        """Store a confirmed test report for the given version.

        A passed sample run is mandatory; the scenario run is optional and only
        changes the wording of the stored summary.

        Raises:
            ValueError: If there is no passed sample run, or if
                ``body.confirm_passed`` is falsy.
        """
        asset, version, _ = self._load_risk_rule_for_test(asset_id, body.version)
        sample = self.repository.get_latest_test_run(
            asset.id, version=version, test_type="sample", status="passed"
        )
        scenario = self.repository.get_latest_test_run(
            asset.id, version=version, test_type="scenario"
        )
        if sample is None:
            raise ValueError("提交审核前必须先完成快速样例测试。")
        if not body.confirm_passed:
            raise ValueError("请确认测试通过后再保存测试报告。")

        summary = "测试报告已确认,当前版本可上线。"
        if scenario is None:
            summary = "快速样例测试已确认通过,真实场景试运行未执行。"
        elif not scenario.passed:
            summary = "快速样例测试已确认通过,真实场景试运行未找到可测样本。"

        self._mark_risk_rule_operation(asset, action="test", actor=actor)
        return self._create_test_run(
            asset,
            version=version,
            test_type="report",
            passed=True,
            summary=summary,
            input_json={"confirm_passed": True, "note": body.note or ""},
            result_json={
                "sample_test_run_id": sample.id,
                "scenario_test_run_id": scenario.id if scenario else "",
                "sample_summary": sample.summary,
                "scenario_summary": scenario.summary if scenario else "",
            },
            actor=actor,
            request_id=request_id,
        )

    def delete_unpublished_asset(
        self,
        asset_id: str,
        *,
        actor: str,
        request_id: str | None = None,
    ) -> None:
        """Delete a JSON risk rule that has never been published.

        Raises:
            LookupError: If the asset does not exist.
            ValueError: If the asset is not a JSON risk rule.
            PermissionError: If the asset has a non-blank ``published_version``.
        """
        asset = self._resolve_asset(asset_id)
        self._require_json_risk_asset(asset)
        if str(asset.published_version or "").strip():
            raise PermissionError("已发布过的风险规则不能删除。")

        # Snapshot first: the asset is gone once the repository deletes it.
        before = self._asset_snapshot(asset)
        self._delete_risk_rule_json_file(asset)
        self.repository.delete_asset(asset)
        self.audit_service.log_action(
            actor=actor,
            action="delete_agent_asset",
            resource_type=AgentAssetType.RULE.value,
            resource_id=asset_id,
            before_json=before,
            after_json={"deleted": True},
            request_id=request_id,
        )

    def return_risk_rule(
        self,
        asset_id: str,
        *,
        note: str,
        actor: str,
        request_id: str | None = None,
    ) -> AgentAssetRiskRuleLatestTestSummary:
        """Reject a rule that is under review and move it back to draft.

        Raises:
            LookupError: If the asset does not exist.
            ValueError: If the asset is not a JSON risk rule, has no target
                version, or is not currently in review status.
        """
        asset = self._resolve_asset(asset_id)
        self._require_json_risk_asset(asset)
        version = self._resolve_target_version(asset, None)
        if asset.status != AgentAssetStatus.REVIEW.value:
            raise ValueError("只有待审核风险规则可以回退。")

        before = self._asset_snapshot(asset)
        review = AgentAssetReview(
            asset_id=asset.id,
            version=version,
            reviewer=actor,
            review_status=AgentReviewStatus.REJECTED.value,
            # A blank or missing note falls back to the default wording.
            review_note=str(note or "审核回退").strip() or "审核回退",
            reviewed_at=datetime.now(UTC),
        )
        self.db.add(review)
        asset.reviewer = actor
        asset.status = AgentAssetStatus.DRAFT.value
        self.db.add(asset)
        self.db.commit()
        self.audit_service.log_action(
            actor=actor,
            action="return_agent_asset",
            resource_type=AgentAssetType.RULE.value,
            resource_id=asset.id,
            before_json=before,
            after_json={"version": version, "status": asset.status, "note": note},
            request_id=request_id,
        )
        return self.get_latest_risk_rule_test_summary(asset)

    def publish_risk_rule(
        self,
        asset_id: str,
        *,
        actor: str,
        request_id: str | None = None,
    ) -> AgentAsset:
        """Publish a reviewed rule version and mark the asset active.

        An approved review record is created when none exists yet for the
        version being published.

        Raises:
            LookupError: If the asset does not exist (before or after saving).
            ValueError: If the asset is not a JSON risk rule, has no target
                version, or is not currently in review status.
            PermissionError: If the version has no passed test report.
        """
        asset = self._resolve_asset(asset_id)
        self._require_json_risk_asset(asset)
        version = self._resolve_target_version(asset, None)
        if asset.status != AgentAssetStatus.REVIEW.value:
            raise ValueError("只有待审核风险规则可以发布上线。")
        if not self.get_latest_risk_rule_test_summary(asset, version=version).test_passed:
            raise PermissionError("当前规则版本尚未完成测试通过确认,不能发布。")

        before = self._asset_snapshot(asset)
        approved_review = self.repository.get_review(
            asset.id, version, AgentReviewStatus.APPROVED.value
        )
        if approved_review is None:
            self.db.add(
                AgentAssetReview(
                    asset_id=asset.id,
                    version=version,
                    reviewer=actor,
                    review_status=AgentReviewStatus.APPROVED.value,
                    review_note="发布上线前审核通过。",
                    reviewed_at=datetime.now(UTC),
                )
            )
        asset.reviewer = actor
        asset.published_version = version
        asset.status = AgentAssetStatus.ACTIVE.value
        self.db.add(asset)
        self.db.commit()
        self.audit_service.log_action(
            actor=actor,
            action="publish_agent_asset",
            resource_type=AgentAssetType.RULE.value,
            resource_id=asset.id,
            before_json=before,
            after_json=self._asset_snapshot(asset),
            request_id=request_id,
        )
        refreshed = self.repository.get(asset.id)
        if refreshed is None:
            raise LookupError("Asset not found")
        return refreshed

    def set_risk_rule_enabled(
        self,
        asset_id: str,
        *,
        enabled: bool,
        actor: str,
        request_id: str | None = None,
    ) -> AgentAsset:
        """Switch a JSON risk rule online or offline.

        The ``enabled`` flag is written both to the rule-library JSON document
        and to the asset's ``config_json``; the asset status is updated to
        match and the change is audited.

        Raises:
            LookupError: If the asset does not exist.
            ValueError: If the asset is not a JSON risk rule.
        """
        asset = self._resolve_asset(asset_id)
        self._require_json_risk_asset(asset)
        before = self._asset_snapshot(asset)

        rule_library, file_name = self._resolve_json_risk_rule_document(asset)
        manifest = self.rule_library_manager.read_rule_library_json(
            library=rule_library,
            file_name=file_name,
        )
        manifest["enabled"] = bool(enabled)
        self.rule_library_manager.write_rule_library_json(
            library=rule_library,
            file_name=file_name,
            payload=manifest,
        )

        # Work on a copy so the stored config is replaced by assignment rather
        # than mutated in place.
        config_json = dict(asset.config_json or {})
        config_json["enabled"] = bool(enabled)
        self._set_risk_rule_status_for_online_toggle(asset, enabled=enabled, actor=actor)
        config_json["last_operation"] = self._build_last_operation(
            action="online" if enabled else "offline",
            actor=actor,
        )
        asset.config_json = config_json
        updated = self.repository.save_asset(asset)
        self.audit_service.log_action(
            actor=actor,
            action="set_risk_rule_enabled",
            resource_type=AgentAssetType.RULE.value,
            resource_id=asset.id,
            before_json=before,
            after_json=self._asset_snapshot(updated),
            request_id=request_id,
        )
        return updated

    def _set_risk_rule_status_for_online_toggle(
        self,
        asset: AgentAsset,
        *,
        enabled: bool,
        actor: str,
    ) -> None:
        """Update status fields on *asset* for an online/offline toggle.

        Disabling only sets the status to disabled. Enabling marks the working
        version as published and active, and adds an approved review record
        when the version does not have one yet.
        """
        if not enabled:
            asset.status = AgentAssetStatus.DISABLED.value
            return

        version = self._resolve_target_version(asset, None)
        approved_review = self.repository.get_review(
            asset.id, version, AgentReviewStatus.APPROVED.value
        )
        if approved_review is None:
            self.db.add(
                AgentAssetReview(
                    asset_id=asset.id,
                    version=version,
                    reviewer=actor,
                    review_status=AgentReviewStatus.APPROVED.value,
                    review_note="直接上线风险规则。",
                    reviewed_at=datetime.now(UTC),
                )
            )
        asset.published_version = version
        asset.reviewer = actor
        asset.status = AgentAssetStatus.ACTIVE.value

    def _mark_risk_rule_operation(self, asset: AgentAsset, *, action: str, actor: str) -> None:
        """Record ``last_operation`` in the asset config and add it to the session.

        This method does not commit.
        """
        config_json = dict(asset.config_json or {})
        config_json["last_operation"] = self._build_last_operation(action=action, actor=actor)
        asset.config_json = config_json
        self.db.add(asset)

    @staticmethod
    def _build_last_operation(*, action: str, actor: str) -> dict[str, str]:
        """Build the ``last_operation`` payload; a blank actor becomes "system"."""
        return {
            "action": action,
            "actor": str(actor or "system").strip() or "system",
            "at": datetime.now(UTC).isoformat(),
        }

    def _load_risk_rule_for_test(
        self, asset_id: str, version: str | None
    ) -> tuple[AgentAsset, str, dict[str, Any]]:
        """Resolve the asset, target version and normalized manifest for a test.

        Raises:
            LookupError: If the asset or the target version does not exist.
            ValueError: If the asset is not a JSON risk rule or has no target
                version.
        """
        asset = self._resolve_asset(asset_id)
        self._require_json_risk_asset(asset)
        target_version = self._resolve_target_version(asset, version)
        if self.repository.get_version(asset.id, target_version) is None:
            raise LookupError(f"版本 {target_version} 不存在")

        rule_library, file_name = self._resolve_json_risk_rule_document(asset)
        manifest = self.rule_library_manager.read_rule_library_json(
            library=rule_library,
            file_name=file_name,
        )
        manifest = normalize_risk_rule_manifest(manifest)
        return asset, target_version, manifest

    def _create_test_run(
        self,
        asset: AgentAsset,
        *,
        version: str,
        test_type: str,
        passed: bool,
        summary: str,
        input_json: dict[str, Any],
        result_json: dict[str, Any],
        actor: str,
        request_id: str | None,
    ) -> AgentAssetRiskRuleTestRunRead:
        """Persist a test run, audit it, and return its read schema."""
        status = "passed" if passed else "failed"
        created = self.repository.create_test_run(
            AgentAssetTestRun(
                asset_id=asset.id,
                version=version,
                test_type=test_type,
                status=status,
                passed=passed,
                summary=summary,
                input_json=input_json,
                result_json=result_json,
                created_by=actor,
            )
        )
        self.audit_service.log_action(
            actor=actor,
            action=f"risk_rule_test_{test_type}",
            resource_type=AgentAssetType.RULE.value,
            resource_id=asset.id,
            before_json=None,
            after_json={"version": version, "status": status, "summary": summary},
            request_id=request_id,
        )
        return AgentAssetRiskRuleTestRunRead.model_validate(created)

    @staticmethod
    def _resolve_fail_severity(manifest: dict[str, Any]) -> str:
        """Return ``manifest["outcomes"]["fail"]["severity"]`` as text, or "".

        Missing or non-dict intermediate levels yield an empty string instead
        of raising.
        """
        outcomes = manifest.get("outcomes")
        fail = outcomes.get("fail") if isinstance(outcomes, dict) else None
        severity = fail.get("severity") if isinstance(fail, dict) else None
        return str(severity or "")

    def _run_sample_case(
        self,
        manifest: dict[str, Any],
        case: AgentAssetRiskRuleSampleCase,
    ) -> dict[str, Any]:
        """Evaluate one sample case against the manifest and describe the outcome.

        A case passes when the actual hit flag equals the expected one and,
        for hits with a non-empty expected severity, the severities match.
        """
        claim, contexts = self._build_synthetic_claim(case.values, manifest)
        execution = RiskRuleTemplateExecutor().evaluate_with_trace(
            manifest, claim=claim, contexts=contexts
        )
        result = execution["result"]
        # Any non-None executor result counts as a hit.
        actual_hit = result is not None
        actual_severity = (
            self._resolve_fail_severity(manifest).strip() if actual_hit else "none"
        )
        expected_severity = str(case.expected_severity or "").strip()
        severity_passed = (
            not actual_hit or not expected_severity or expected_severity == actual_severity
        )
        passed = actual_hit == case.expected_hit and severity_passed
        return {
            "case_id": case.case_id or "",
            "name": case.name,
            "values": case.values,
            "expected_hit": case.expected_hit,
            "expected_severity": expected_severity,
            "actual_hit": actual_hit,
            "actual_severity": actual_severity,
            "passed": passed,
            "message": str(result.get("message") or "") if isinstance(result, dict) else "",
            "evidence": result.get("evidence") if isinstance(result, dict) else {},
            "trace": execution["trace"] if isinstance(execution.get("trace"), dict) else {},
        }

    def _run_claim_scenario(self, manifest: dict[str, Any], claim: ExpenseClaim) -> dict[str, Any]:
        """Evaluate the manifest against one stored claim and describe the outcome."""
        contexts = ExpenseClaimService(self.db)._build_claim_attachment_contexts(claim)
        execution = RiskRuleTemplateExecutor().evaluate_with_trace(
            manifest, claim=claim, contexts=contexts
        )
        result = execution["result"]
        hit = result is not None
        return {
            "claim_id": claim.id,
            "claim_no": claim.claim_no,
            "employee_name": claim.employee_name,
            "department_name": claim.department_name,
            "expense_type": claim.expense_type,
            "amount": float(claim.amount or 0),
            "status": claim.status,
            "occurred_at": claim.occurred_at.isoformat() if claim.occurred_at else "",
            "hit": hit,
            "severity": self._resolve_fail_severity(manifest) if hit else "none",
            "message": str(result.get("message") or "") if isinstance(result, dict) else "",
            "evidence": result.get("evidence") if isinstance(result, dict) else {},
            "trace": execution["trace"] if isinstance(execution.get("trace"), dict) else {},
        }

    def _build_synthetic_claim(
        self,
        values: dict[str, Any],
        manifest: dict[str, Any],
    ) -> tuple[ExpenseClaim, list[dict[str, Any]]]:
        """Build an in-memory claim plus one attachment context from sample values.

        Keys in *values* use ``claim.``, ``item.``, ``attachment.`` or
        ``employee.`` prefixes. Only keys that are also declared as manifest
        input fields are copied onto the claim, the item or the attachment
        document.
        """
        claim = ExpenseClaim(
            claim_no="TEST-RISK-RULE",
            employee_name=str(values.get("claim.employee_name") or "测试员工"),
            department_name=str(values.get("claim.department_name") or "测试部门"),
            expense_type=str(values.get("item.item_type") or "差旅费"),
            reason=str(values.get("claim.reason") or "测试报销事由"),
            location=str(values.get("claim.location") or "北京"),
            amount=self._to_decimal(values.get("claim.amount")),
            currency="CNY",
            invoice_count=1,
            occurred_at=datetime.now(UTC),
            status="draft",
        )
        item = ExpenseClaimItem(
            item_date=date.today(),
            item_type=str(values.get("item.item_type") or "住宿费"),
            item_reason=str(values.get("item.item_reason") or claim.reason),
            item_location=str(values.get("item.item_location") or claim.location),
            item_amount=self._to_decimal(values.get("item.item_amount") or claim.amount),
        )
        claim.items = [item]
        if values.get("employee.location"):
            claim.employee = Employee(
                employee_no="TEST-EMPLOYEE",
                name=claim.employee_name,
                email="risk-rule-test@example.com",
                location=str(values.get("employee.location") or ""),
            )

        attachment_fields: list[dict[str, Any]] = []
        document_info: dict[str, Any] = {"fields": attachment_fields}
        for field in self._extract_manifest_fields(manifest):
            key = field["key"]
            if key not in values:
                continue
            value = self._coerce_sample_value(key, values.get(key))
            if key.startswith("claim."):
                target, attr_name = claim, key.removeprefix("claim.")
            elif key.startswith("item."):
                target, attr_name = item, key.removeprefix("item.")
            elif key.startswith("attachment."):
                short_key = key.removeprefix("attachment.")
                document_info[short_key] = value
                attachment_fields.append(
                    {"key": short_key, "label": field["label"], "value": value}
                )
                continue
            else:
                continue
            # Attribute names come from the manifest; never let them overwrite
            # private or ORM-internal attributes of the model instance.
            if not attr_name or attr_name.startswith("_"):
                continue
            setattr(target, attr_name, value)

        return claim, [
            {
                "document_info": document_info,
                "ocr_text": document_info.get("ocr_text", ""),
            }
        ]

    def _build_default_sample_cases(
        self,
        manifest: dict[str, Any],
    ) -> list[AgentAssetRiskRuleSampleCase]:
        """Generate default "hit" and "pass" cases for the manifest.

        A third "missing" case with every value blank is added for the
        ``field_required_v1`` template.
        """
        fields = self._extract_manifest_fields(manifest)
        severity = self._resolve_fail_severity(manifest)
        template_key = str(manifest.get("template_key") or "").strip()
        hit_values = self._find_case_values_for_expected(manifest, fields, expected_hit=True)
        pass_values = self._find_case_values_for_expected(manifest, fields, expected_hit=False)
        cases = [
            AgentAssetRiskRuleSampleCase(
                case_id="hit",
                name="应该命中风险",
                values=hit_values,
                expected_hit=True,
                expected_severity=severity,
                note="验证规则能识别异常样本。",
            ),
            AgentAssetRiskRuleSampleCase(
                case_id="pass",
                name="应该不命中",
                values=pass_values,
                expected_hit=False,
                expected_severity="none",
                note="验证正常样本不会误触发。",
            ),
        ]
        if template_key == "field_required_v1":
            cases.append(
                AgentAssetRiskRuleSampleCase(
                    case_id="missing",
                    name="关键字段缺失",
                    values={key: "" for key in hit_values},
                    expected_hit=True,
                    expected_severity=severity,
                    note="验证缺字段时会进入复核。",
                )
            )
        return cases

    def _find_case_values_for_expected(
        self,
        manifest: dict[str, Any],
        fields: list[dict[str, str]],
        *,
        expected_hit: bool,
    ) -> dict[str, Any]:
        """Pick sample values that make the rule produce *expected_hit*.

        Several candidate value sets are executed in order and the first one
        whose actual hit flag matches is returned. When none matches, the
        template-specific candidate (the first one) is returned anyway.
        """
        candidates = [
            self._build_case_values(manifest, fields, hit=expected_hit),
            {field["key"]: self._default_value_for_field(field["key"]) for field in fields},
            {
                field["key"]: ("上海" if index == 0 else "北京")
                for index, field in enumerate(fields)
            },
            {field["key"]: "北京" for field in fields},
            {field["key"]: "" for field in fields},
        ]
        severity = self._resolve_fail_severity(manifest)
        for values in candidates:
            probe = AgentAssetRiskRuleSampleCase(
                name="默认样例探测",
                values=values,
                expected_hit=expected_hit,
                expected_severity=severity if expected_hit else "none",
            )
            result = self._run_sample_case(manifest, probe)
            if bool(result["actual_hit"]) == expected_hit:
                return values
        return candidates[0]

    def _build_case_values(
        self,
        manifest: dict[str, Any],
        fields: list[dict[str, str]],
        *,
        hit: bool,
    ) -> dict[str, Any]:
        """Build template-aware sample values intended to hit (or not hit) the rule.

        Every declared field starts from its default value; the branch for the
        manifest's ``template_key`` then overrides the fields that decide the
        outcome.
        """
        values = {field["key"]: self._default_value_for_field(field["key"]) for field in fields}
        template_key = str(manifest.get("template_key") or "").strip()
        params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {}

        if template_key == "field_compare_v1":
            semantic_type = str(params.get("semantic_type") or "").strip()
            if semantic_type in {"travel_city_consistency", "travel_route_city_consistency"}:
                values.update(
                    {
                        "attachment.hotel_city": "上海" if hit else "北京",
                        "attachment.route_cities": ["上海"] if hit else ["北京"],
                        "claim.location": "北京",
                        "item.item_location": "北京",
                    }
                )
                return values

            # ``conditions`` may be absent, None or malformed in the manifest.
            conditions = params.get("conditions")
            if not isinstance(conditions, (list, tuple)):
                conditions = []
            condition = next((item for item in conditions if isinstance(item, dict)), {})
            left = str(condition.get("left") or "").strip()
            right = str(condition.get("right") or "").strip()
            operator = str(condition.get("operator") or "overlap").strip()
            if left and operator == "is_empty":
                values[left] = "测试值" if hit else ""
            elif left and right and operator in {"not_equals", "not_in", "not_overlap"}:
                values[left] = "北京" if hit else "上海"
                values[right] = "北京"
            elif left and right:
                values[left] = "上海" if hit else "北京"
                values[right] = "北京"
        elif template_key == "field_required_v1" and hit and fields:
            values[fields[0]["key"]] = ""
        elif template_key == "keyword_match_v1":
            keywords = params.get("keywords") if isinstance(params.get("keywords"), list) else []
            # Use the first non-empty string keyword; a malformed leading entry
            # such as None or {} must not end up in the sample text.
            keyword = next(
                (item for item in keywords if isinstance(item, str) and item),
                "咨询费",
            )
            target_key = fields[0]["key"] if fields else "claim.reason"
            values[target_key] = f"本次报销包含{keyword}" if hit else "正常差旅报销"
        return values

    @staticmethod
    def _default_value_for_field(field_key: str) -> Any:
        """Return a plausible default sample value chosen from the field key.

        The checks are ordered; the first matching rule wins.
        """
        if field_key.endswith("amount"):
            return "100.00"
        date_suffixes = (
            "issue_date",
            "stay_start_date",
            "stay_end_date",
            "trip_start_date",
            "trip_end_date",
            "item_date",
        )
        if field_key.endswith(date_suffixes):
            return date.today().isoformat()
        if field_key.endswith("route_cities"):
            return ["北京"]
        if field_key.endswith("ocr_text"):
            return "正常发票内容"
        if "city" in field_key or "location" in field_key:
            return "北京"
        if field_key.endswith("item_type"):
            return "住宿费"
        return "测试值"

    def _query_expense_claim_samples(self, parsed_scope: dict[str, Any]) -> list[ExpenseClaim]:
        """Query recent expense claims that match the parsed scenario scope.

        Claims are filtered by creation time, optionally by an expense keyword
        (type or reason) and by up to eight cities (location or reason), newest
        first. The limit is clamped to the range 1..200.
        """
        days = int(parsed_scope.get("days") or 30)
        limit = min(max(int(parsed_scope.get("limit") or 50), 1), 200)
        since = datetime.now(UTC) - timedelta(days=days)
        stmt = select(ExpenseClaim).where(ExpenseClaim.created_at >= since)

        expense_keyword = str(parsed_scope.get("expense_keyword") or "").strip()
        if expense_keyword:
            like_keyword = f"%{expense_keyword}%"
            stmt = stmt.where(
                or_(
                    ExpenseClaim.expense_type.ilike(like_keyword),
                    ExpenseClaim.reason.ilike(like_keyword),
                )
            )

        # Strip before filtering: a whitespace-only entry would otherwise turn
        # into the pattern "%%", which matches every row and defeats the filter.
        stripped_cities = (str(item or "").strip() for item in parsed_scope.get("cities") or [])
        cities = [city for city in stripped_cities if city]
        if cities:
            city_filters = []
            for city in cities[:8]:
                like_city = f"%{city}%"
                city_filters.extend(
                    [
                        ExpenseClaim.location.ilike(like_city),
                        ExpenseClaim.reason.ilike(like_city),
                    ]
                )
            stmt = stmt.where(or_(*city_filters))

        stmt = stmt.order_by(ExpenseClaim.created_at.desc()).limit(limit)
        return list(self.db.scalars(stmt).all())

    @staticmethod
    def _parse_scenario_scope(intent: str, filters: dict[str, Any] | None) -> dict[str, Any]:
        """Derive the dry-run query scope from free-text intent and filters.

        A day count written in the intent text ("最近 N 天") overrides the one
        from *filters*. Days are clamped to 1..365 and the limit to 1..200.
        Non-numeric ``days``/``limit`` values fall back to 30 and 50, and a
        missing *filters* mapping is treated as empty.
        """
        filters = filters or {}
        text = str(intent or "")

        raw_days = filters.get("days") or filters.get("recent_days")
        days = _digits_to_int(raw_days, 30)
        match = re.search(r"最近\s*(\d{1,3})\s*天", text)
        if match:
            days = int(match.group(1))
        limit = _digits_to_int(filters.get("limit"), 50)

        expense_keyword = str(filters.get("expense_keyword") or "").strip()
        if not expense_keyword and any(keyword in text for keyword in ("酒店", "住宿")):
            expense_keyword = "住宿"

        city_candidates = ("北京", "上海", "广州", "深圳", "武汉", "杭州", "成都", "南京")
        # Built once instead of once per candidate city.
        requested_cities = {str(item) for item in filters.get("cities", []) or []}
        cities = [
            city for city in city_candidates if city in text or city in requested_cities
        ]
        return {
            "business_domain": "expense",
            "days": max(1, min(days, 365)),
            "limit": max(1, min(limit, 200)),
            "expense_keyword": expense_keyword,
            "cities": cities,
            "execution_mode": "dry_run",
        }

    @staticmethod
    def _extract_manifest_fields(manifest: dict[str, Any]) -> list[dict[str, str]]:
        """Return ``[{"key", "label"}, ...]`` for the manifest's input fields.

        Entries that are not dicts or have a blank key are skipped; a missing
        label falls back to the key.
        """
        inputs = manifest.get("inputs") if isinstance(manifest.get("inputs"), dict) else {}
        fields = inputs.get("fields") if isinstance(inputs.get("fields"), list) else []
        normalized = []
        for item in fields:
            if not isinstance(item, dict):
                continue
            key = str(item.get("key") or "").strip()
            if key:
                normalized.append({"key": key, "label": str(item.get("label") or key).strip()})
        return normalized

    @staticmethod
    def _coerce_sample_value(field_key: str, value: Any) -> Any:
        """Split a ``route_cities`` string into a list; return other values as is.

        Separators are the ASCII comma, the fullwidth comma (U+FF0C), the
        ideographic comma, the slash and the space.
        """
        if field_key.endswith("route_cities") and isinstance(value, str):
            # NOTE(review): the original class listed the ASCII comma twice,
            # which looks like a lost fullwidth comma - confirm.
            return [item.strip() for item in re.split(r"[,\uff0c、/ ]+", value) if item.strip()]
        return value

    @staticmethod
    def _to_decimal(value: Any) -> Decimal:
        """Convert *value* to a finite ``Decimal``; anything else becomes 0."""
        try:
            parsed = Decimal(str(value or "0"))
        except (InvalidOperation, ValueError):
            return Decimal("0")
        # "NaN" and "Infinity" parse successfully but are not usable amounts.
        return parsed if parsed.is_finite() else Decimal("0")

    def _resolve_asset(self, asset_or_id: AgentAsset | str) -> AgentAsset:
        """Return the asset itself, or load it by id.

        Raises:
            LookupError: If no asset exists for the given id.
        """
        if isinstance(asset_or_id, AgentAsset):
            return asset_or_id
        asset = self.repository.get(str(asset_or_id))
        if asset is None:
            raise LookupError("Asset not found")
        return asset

    @staticmethod
    def _require_json_risk_asset(asset: AgentAsset) -> None:
        """Ensure *asset* is a rule asset whose ``detail_mode`` is ``json_risk``.

        Raises:
            ValueError: If either condition is not met.
        """
        config_json = asset.config_json if isinstance(asset.config_json, dict) else {}
        if asset.asset_type != AgentAssetType.RULE.value:
            raise ValueError("仅规则资产支持风险规则操作。")
        if str(config_json.get("detail_mode") or "").strip().lower() != "json_risk":
            raise ValueError("仅 JSON 风险规则支持该操作。")

    def _resolve_target_version(self, asset: AgentAsset, version: str | None) -> str:
        """Return the explicit version, or the asset's working version.

        Raises:
            ValueError: If neither yields a non-blank version.
        """
        target = str(version or self._resolve_working_version(asset) or "").strip()
        if not target:
            raise ValueError("当前规则尚未配置工作版本。")
        return target

    def _delete_risk_rule_json_file(self, asset: AgentAsset) -> None:
        """Remove the rule's JSON document from the rule library, best effort.

        A missing file or a ``ValueError`` while resolving the path is ignored
        on purpose so that the asset record can still be deleted.
        """
        try:
            rule_library, file_name = self._resolve_json_risk_rule_document(asset)
            target = self.rule_library_manager.resolve_rule_library_path(
                library=rule_library,
                file_name=file_name,
            )
            target.unlink(missing_ok=True)
        except (FileNotFoundError, ValueError):
            return

    @staticmethod
    def _serialize_test_run(
        run: AgentAssetTestRun | None,
    ) -> AgentAssetRiskRuleTestRunRead | None:
        """Convert a test-run row to its read schema; ``None`` stays ``None``."""
        return AgentAssetRiskRuleTestRunRead.model_validate(run) if run is not None else None