from __future__ import annotations import re from typing import Any from sqlalchemy import select from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus, AgentAssetType from app.models.agent_asset import AgentAsset from app.models.financial_record import ExpenseClaim, ExpenseClaimItem from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY from app.services.expense_rule_runtime import ( RuntimeTravelPolicy, ) from app.services.risk_rule_manifest_normalizer import normalize_risk_rule_manifest from app.services.risk_rule_template_executor import RiskRuleTemplateExecutor class ExpenseClaimPlatformRiskMixin: def evaluate_platform_risk_rules( self, claim: ExpenseClaim, *, rule_codes: list[str] | None = None, ) -> dict[str, list[Any]]: manifests = self._load_platform_risk_rule_manifests(rule_codes=rule_codes) if not manifests: return {"flags": [], "blocking_reasons": []} contexts = self._build_claim_attachment_contexts(claim) flags: list[dict[str, Any]] = [] blocking_reasons: list[str] = [] for manifest in manifests: if not self._risk_manifest_applies_to_claim(manifest, claim=claim, contexts=contexts): continue flag = self._evaluate_platform_risk_manifest( manifest, claim=claim, contexts=contexts, ) if flag is None: continue flags.append(flag) severity = str(flag.get("severity") or "").strip().lower() action = str(flag.get("action") or "").strip().lower() if severity in {"high", "critical"} or action == "block": blocking_reasons.append(str(flag.get("message") or flag.get("label") or "").strip()) deduplicated_reasons = list(dict.fromkeys(reason for reason in blocking_reasons if reason)) return {"flags": flags, "blocking_reasons": deduplicated_reasons} def _load_platform_risk_rule_manifests( self, *, rule_codes: list[str] | None, ) -> list[dict[str, Any]]: code_filter = { str(code or "").strip() for code in list(rule_codes or []) if str(code or "").strip() } manifests_by_code: dict[str, dict[str, Any]] = {} assets = list( self.db.scalars( select(AgentAsset) .where(AgentAsset.asset_type == AgentAssetType.RULE.value) .where(AgentAsset.status == AgentAssetStatus.ACTIVE.value) .where(AgentAsset.domain == AgentAssetDomain.EXPENSE.value) .order_by(AgentAsset.updated_at.desc(), AgentAsset.created_at.desc()) ).all() ) library_manager = AgentAssetRuleLibraryManager() for asset in assets: config_json = asset.config_json if isinstance(asset.config_json, dict) else {} if str(config_json.get("detail_mode") or "").strip().lower() != "json_risk": continue rule_code = str(asset.code or "").strip() if code_filter and rule_code not in code_filter: continue rule_document = config_json.get("rule_document") if not isinstance(rule_document, dict): continue file_name = str(rule_document.get("file_name") or "").strip() rule_library = ( str(config_json.get("rule_library") or RISK_RULES_LIBRARY).strip() or RISK_RULES_LIBRARY ) if not file_name: continue try: payload = library_manager.read_rule_library_json( library=rule_library, file_name=file_name, ) except (FileNotFoundError, ValueError): continue payload = normalize_risk_rule_manifest(payload) manifest_code = str(payload.get("rule_code") or rule_code).strip() if not manifest_code or (code_filter and manifest_code not in code_filter): continue if payload.get("enabled") is False: continue payload = dict(payload) payload.setdefault("rule_code", manifest_code) payload["_rule_version"] = str( asset.published_version or asset.current_version or "v1.0.0" ) payload["_rule_asset_id"] = asset.id manifests_by_code[manifest_code] = payload missing_codes = code_filter - set(manifests_by_code) should_load_fallback = not code_filter or bool(missing_codes) if should_load_fallback: try: files = library_manager.list_rule_library_json_files(library=RISK_RULES_LIBRARY) except ValueError: files = [] for file_name in files: try: payload = library_manager.read_rule_library_json( library=RISK_RULES_LIBRARY, file_name=file_name, ) except (FileNotFoundError, ValueError): continue payload = normalize_risk_rule_manifest(payload) rule_code = str(payload.get("rule_code") or "").strip() if not rule_code or rule_code in manifests_by_code: continue if code_filter and rule_code not in missing_codes: continue if payload.get("enabled") is False: continue payload = dict(payload) payload["_rule_version"] = "v1.0.0" manifests_by_code[rule_code] = payload return list(manifests_by_code.values()) def _risk_manifest_applies_to_claim( self, manifest: dict[str, Any], *, claim: ExpenseClaim, contexts: list[dict[str, Any]], ) -> bool: applies_to = manifest.get("applies_to") if not isinstance(applies_to, dict): applies_to = {} try: min_attachments = int(applies_to.get("min_attachments") or 0) except (TypeError, ValueError): min_attachments = 0 if min_attachments and int(claim.invoice_count or 0) < min_attachments and not contexts: return False expense_types = { str(claim.expense_type or "").strip().lower(), *{ str(item.item_type or "").strip().lower() for item in list(claim.items or []) if str(item.item_type or "").strip() }, } domains = { str(value or "").strip().lower() for value in list(applies_to.get("domains") or []) if str(value or "").strip() } configured_expense_types = { str(value or "").strip().lower() for value in list(applies_to.get("expense_types") or []) if str(value or "").strip() } if configured_expense_types and not (expense_types & configured_expense_types): return False if domains and not self._risk_domains_match_claim( domains, expense_types=expense_types, contexts=contexts, ): return False return True def _risk_domains_match_claim( self, domains: set[str], *, expense_types: set[str], contexts: list[dict[str, Any]], ) -> bool: normalized_contexts: list[dict[str, str]] = [] for context in contexts: document_info = context.get("document_info") or {} normalized_contexts.append( { "scene_code": str(document_info.get("scene_code") or "").strip().lower(), "document_type": str(document_info.get("document_type") or "").strip().lower(), "item_type": str(getattr(context.get("item"), "item_type", "") or "") .strip() .lower(), } ) if "travel" in domains: if expense_types & {"travel", "hotel", "transport"}: return True if any( item["scene_code"] in {"travel", "hotel", "transport"} or item["document_type"] in { "flight_itinerary", "train_ticket", "hotel_invoice", "taxi_receipt", } for item in normalized_contexts ): return True if "meal" in domains: if expense_types & {"meal", "entertainment"}: return True if any( item["scene_code"] == "meal" or item["document_type"] == "meal_receipt" for item in normalized_contexts ): return True return bool(domains & expense_types) def _evaluate_platform_risk_manifest( self, manifest: dict[str, Any], *, claim: ExpenseClaim, contexts: list[dict[str, Any]], ) -> dict[str, Any] | None: evaluator = str(manifest.get("evaluator") or "").strip().lower() if evaluator == "reason_too_brief": return self._evaluate_reason_too_brief_risk(manifest, claim=claim) if evaluator == "entertainment_reason_missing": return self._evaluate_entertainment_reason_missing_risk(manifest, claim=claim) if evaluator == "document_expense_mismatch": return self._evaluate_document_expense_mismatch_risk( manifest, claim=claim, contexts=contexts, ) if evaluator == "location_consistency": return self._evaluate_location_consistency_risk( manifest, claim=claim, contexts=contexts, ) if evaluator == "duplicate_invoice": return self._evaluate_duplicate_invoice_risk(manifest, claim=claim, contexts=contexts) if evaluator == "identity_consistency": return self._evaluate_identity_consistency_risk( manifest, claim=claim, contexts=contexts, ) if evaluator == "cross_year_invoice": return self._evaluate_cross_year_invoice_risk(manifest, claim=claim, contexts=contexts) if evaluator == "void_or_red_invoice": return self._evaluate_text_keyword_risk( manifest, contexts=contexts, keywords=["作废", "红冲", "红字", "冲红"], fallback_message="票据文本中出现作废、红冲或红字发票相关信息,建议退回补充或人工复核。", ) if evaluator == "vague_goods_description": return self._evaluate_text_keyword_risk( manifest, contexts=contexts, keywords=["详见清单", "服务费", "咨询费", "其他", "办公用品"], fallback_message="票据商品或服务描述较笼统,建议审批人核对真实用途和明细清单。", ) if evaluator == "multi_city_reason_required": return self._evaluate_multi_city_reason_required_risk( manifest, claim=claim, contexts=contexts, ) if evaluator == "template_rule": result = RiskRuleTemplateExecutor().evaluate( manifest, claim=claim, contexts=contexts, ) if result is None: return None return self._build_platform_risk_flag( manifest, message=str(result.get("message") or "自然语言风险规则命中。"), evidence=result.get("evidence") if isinstance(result.get("evidence"), dict) else {}, ) return None def _evaluate_reason_too_brief_risk( self, manifest: dict[str, Any], *, claim: ExpenseClaim, ) -> dict[str, Any] | None: params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {} try: min_reason_length = max(1, int(params.get("min_reason_length") or 6)) except (TypeError, ValueError): min_reason_length = 6 reason_corpus = re.sub(r"\s+", "", self._build_scene_reason_corpus(claim)) if len(reason_corpus) >= min_reason_length: return None return self._build_platform_risk_flag( manifest, message=f"报销事由有效描述不足 {min_reason_length} 个字符,暂不足以支撑真实性判断。", evidence={"reason_length": len(reason_corpus), "min_reason_length": min_reason_length}, ) def _evaluate_entertainment_reason_missing_risk( self, manifest: dict[str, Any], *, claim: ExpenseClaim, ) -> dict[str, Any] | None: expense_types = { str(claim.expense_type or "").strip().lower(), *{str(item.item_type or "").strip().lower() for item in list(claim.items or [])}, } reason_corpus = self._build_scene_reason_corpus(claim) compact_reason = re.sub(r"\s+", "", reason_corpus) looks_like_entertainment = ( "entertainment" in expense_types or "招待" in compact_reason or "客户" in compact_reason ) if not looks_like_entertainment: return None required_keywords = ("客户", "项目", "参与", "人员", "对象", "商务", "会议") has_detail = any(keyword in compact_reason for keyword in required_keywords) if has_detail: return None return self._build_platform_risk_flag( manifest, message="招待或餐饮类费用未识别到客户、项目、参与人员等必要说明,建议补充后再流转。", evidence={"reason": reason_corpus[:300]}, ) def _evaluate_document_expense_mismatch_risk( self, manifest: dict[str, Any], *, claim: ExpenseClaim, contexts: list[dict[str, Any]], ) -> dict[str, Any] | None: mismatches: list[str] = [] for context in contexts: item = context["item"] item_type = ( str(item.item_type or claim.expense_type or "other").strip().lower() or "other" ) policy = self._get_expense_scene_policy(item_type) if policy is None: continue document_info = context.get("document_info") or {} recognized_scene_code = ( str(document_info.get("scene_code") or "other").strip().lower() or "other" ) recognized_document_type = ( str(document_info.get("document_type") or "other").strip().lower() or "other" ) if recognized_scene_code in set( policy.allowed_scene_codes ) or recognized_document_type in set(policy.allowed_document_types): continue recognized_label = str( document_info.get("document_type_label") or recognized_document_type or "未知票据" ) mismatches.append( f"第 {context['index']} 条明细为{policy.label},附件识别为{recognized_label}" ) if not mismatches: return None return self._build_platform_risk_flag( manifest, message=";".join(mismatches[:3]) + ",与当前费用场景不匹配。", evidence={"mismatches": mismatches[:5]}, ) def _evaluate_location_consistency_risk( self, manifest: dict[str, Any], *, claim: ExpenseClaim, contexts: list[dict[str, Any]], ) -> dict[str, Any] | None: policy = self._get_expense_rule_catalog().travel_policy if policy is None: return None declared_cities = self._extract_known_cities_from_text( " ".join( [ str(claim.location or ""), *[str(item.item_location or "") for item in list(claim.items or [])], ] ), policy, ) evidence_cities = self._collect_attachment_cities(contexts, policy) if not declared_cities or not evidence_cities: return None if set(declared_cities) & set(evidence_cities): return None declared_text = "、".join(declared_cities) evidence_text = "、".join(evidence_cities[:5]) return self._build_platform_risk_flag( manifest, message=( f"申报地点 {declared_text} 与票据识别地点 {evidence_text} 不一致," "建议补充异地说明或更换附件。" ), evidence={"declared_cities": declared_cities, "evidence_cities": evidence_cities}, ) def _evaluate_duplicate_invoice_risk( self, manifest: dict[str, Any], *, claim: ExpenseClaim, contexts: list[dict[str, Any]], ) -> dict[str, Any] | None: invoice_keys = self._collect_invoice_keys_from_contexts(contexts) duplicate_keys = [ key for key, count in self._count_values(invoice_keys).items() if count > 1 ] if duplicate_keys: return self._build_platform_risk_flag( manifest, message=f"当前报销单内存在重复票据号码:{'、'.join(duplicate_keys[:3])}。", evidence={"duplicate_invoice_keys": duplicate_keys[:5]}, ) if not invoice_keys: return None other_items = list( self.db.scalars( select(ExpenseClaimItem) .where(ExpenseClaimItem.claim_id != claim.id) .where(ExpenseClaimItem.invoice_id.is_not(None)) ).all() ) matched_claim_ids: set[str] = set() for other_item in other_items: other_path = self._attachment_storage.resolve_path(other_item.invoice_id) if other_path is None or not other_path.exists(): continue other_meta = self._attachment_storage.read_meta(other_path) other_document_info = other_meta.get("document_info") if not isinstance(other_document_info, dict): continue other_keys = self._collect_invoice_keys_from_document_info(other_document_info) if set(invoice_keys) & set(other_keys): matched_claim_ids.add(str(other_item.claim_id or "")) if not matched_claim_ids: return None return self._build_platform_risk_flag( manifest, message=f"票据号码已在其他报销单中出现,疑似重复报销:{'、'.join(invoice_keys[:3])}。", evidence={ "invoice_keys": invoice_keys[:5], "matched_claim_ids": sorted(matched_claim_ids)[:5], }, ) def _evaluate_identity_consistency_risk( self, manifest: dict[str, Any], *, claim: ExpenseClaim, contexts: list[dict[str, Any]], ) -> dict[str, Any] | None: params = manifest.get("params") if isinstance(manifest.get("params"), dict) else {} allow_keywords = [ str(value) for value in list(params.get("allow_keywords") or []) if str(value).strip() ] claimant = str(claim.employee_name or "").strip() if not claimant: return None mismatched_buyers: list[str] = [] for context in contexts: buyer = self._resolve_first_document_field_value( context.get("document_info") or {}, keys={"buyer_name", "buyer", "purchaser_name", "claimant"}, labels={"购买方", "抬头", "买方", "购方"}, ) if not buyer: continue if claimant in buyer or any(keyword in buyer for keyword in allow_keywords): continue mismatched_buyers.append(buyer) if not mismatched_buyers: return None return self._build_platform_risk_flag( manifest, message=f"发票抬头 {mismatched_buyers[0]} 与报销人 {claimant} 不一致,建议人工复核。", evidence={"claimant": claimant, "buyers": mismatched_buyers[:5]}, ) def _evaluate_cross_year_invoice_risk( self, manifest: dict[str, Any], *, claim: ExpenseClaim, contexts: list[dict[str, Any]], ) -> dict[str, Any] | None: claim_year = claim.occurred_at.year if claim.occurred_at is not None else None if claim_year is None: return None issue_years: list[int] = [] for context in contexts: text = " ".join( [ self._resolve_first_document_field_value( context.get("document_info") or {}, keys={"date", "issue_date", "invoice_date"}, labels={"日期", "开票日期", "发生时间"}, ), str(context.get("ocr_summary") or ""), str(context.get("ocr_text") or ""), ] ) for match in re.findall(r"(20\d{2}|19\d{2})[年/\-.]", text): try: issue_years.append(int(match)) except ValueError: continue mismatch_years = sorted({year for year in issue_years if year != claim_year}) if not mismatch_years: return None return self._build_platform_risk_flag( manifest, message=( f"票据年份 {mismatch_years[0]} 与费用发生年份 {claim_year} 不一致," "建议确认是否跨年报销。" ), evidence={"claim_year": claim_year, "invoice_years": mismatch_years}, ) def _evaluate_text_keyword_risk( self, manifest: dict[str, Any], *, contexts: list[dict[str, Any]], keywords: list[str], fallback_message: str, ) -> dict[str, Any] | None: matched: list[str] = [] for context in contexts: text = f"{context.get('ocr_summary') or ''}\n{context.get('ocr_text') or ''}" for keyword in keywords: if keyword in text and keyword not in matched: matched.append(keyword) if not matched: return None return self._build_platform_risk_flag( manifest, message=fallback_message, evidence={"matched_keywords": matched}, ) def _evaluate_multi_city_reason_required_risk( self, manifest: dict[str, Any], *, claim: ExpenseClaim, contexts: list[dict[str, Any]], ) -> dict[str, Any] | None: policy = self._get_expense_rule_catalog().travel_policy if policy is None: return None cities = self._collect_attachment_cities(contexts, policy) for item in list(claim.items or []): for city in self._extract_known_cities_from_text(str(item.item_location or ""), policy): if city not in cities: cities.append(city) if len(cities) <= 2: return None reason_corpus = self._build_travel_reason_corpus(claim) if self._text_contains_keywords(reason_corpus, policy.route_exception_keywords): return None return self._build_platform_risk_flag( manifest, message=f"本次报销识别到多城市行程({'、'.join(cities[:5])}),但事由中未说明中转、多地拜访或改签原因。", evidence={"cities": cities[:8]}, ) def _build_platform_risk_flag( self, manifest: dict[str, Any], *, message: str, evidence: dict[str, Any], ) -> dict[str, Any]: outcomes = manifest.get("outcomes") if isinstance(manifest.get("outcomes"), dict) else {} fail_outcome = outcomes.get("fail") if isinstance(outcomes.get("fail"), dict) else {} severity = str(fail_outcome.get("severity") or "medium").strip().lower() or "medium" default_action = "block" if severity in {"high", "critical"} else "manual_review" action = str(fail_outcome.get("action") or default_action).strip() label = str(manifest.get("name") or manifest.get("rule_code") or "风险规则命中").strip() return { "source": "submission_review", "hit_source": "rule_center", "rule_type": "risk", "rule_code": str(manifest.get("rule_code") or "").strip(), "rule_version": str(manifest.get("_rule_version") or "v1.0.0").strip(), "severity": severity, "action": action, "label": label, "message": message, "evidence": evidence, } @staticmethod def _count_values(values: list[str]) -> dict[str, int]: counts: dict[str, int] = {} for value in values: normalized = str(value or "").strip() if not normalized: continue counts[normalized] = counts.get(normalized, 0) + 1 return counts def _collect_invoice_keys_from_contexts(self, contexts: list[dict[str, Any]]) -> list[str]: invoice_keys: list[str] = [] for context in contexts: document_info = context.get("document_info") or {} for key in self._collect_invoice_keys_from_document_info(document_info): if key not in invoice_keys: invoice_keys.append(key) return invoice_keys def _collect_invoice_keys_from_document_info(self, document_info: dict[str, Any]) -> list[str]: keys: list[str] = [] for field in list(document_info.get("fields") or []): if not isinstance(field, dict): continue field_key = str(field.get("key") or "").strip().lower().replace("_", "") label = str(field.get("label") or "").replace(" ", "") value = str(field.get("value") or "").strip() if not value: continue if field_key in {"invoiceno", "invoicenumber", "number", "code"} or any( token in label for token in ("发票号码", "票号", "发票代码", "号码") ): normalized = re.sub(r"\s+", "", value) if normalized and normalized not in keys: keys.append(normalized) return keys def _collect_attachment_cities( self, contexts: list[dict[str, Any]], policy: RuntimeTravelPolicy, ) -> list[str]: cities: list[str] = [] for context in contexts: document_info = context.get("document_info") or {} parts = [ str(context.get("ocr_summary") or ""), str(context.get("ocr_text") or ""), str(context.get("item").item_location if context.get("item") is not None else ""), ] for field in list(document_info.get("fields") or []): if isinstance(field, dict): parts.append(str(field.get("value") or "")) for city in self._extract_known_cities_from_text(" ".join(parts), policy): if city not in cities: cities.append(city) return cities @staticmethod def _extract_known_cities_from_text(text: str, policy: RuntimeTravelPolicy) -> list[str]: normalized = str(text or "").strip() if not normalized: return [] cities: list[str] = [] for city in sorted(policy.city_tiers.keys(), key=lambda item: len(item), reverse=True): if city in normalized and city not in cities: cities.append(city) return cities @staticmethod def _resolve_first_document_field_value( document_info: dict[str, Any], *, keys: set[str], labels: set[str], ) -> str: normalized_keys = {key.replace("_", "").lower() for key in keys} for field in list(document_info.get("fields") or []): if not isinstance(field, dict): continue field_key = str(field.get("key") or "").strip().lower().replace("_", "") label = str(field.get("label") or "").replace(" ", "") value = str(field.get("value") or "").strip() if not value: continue if field_key in normalized_keys or any(token in label for token in labels): return value return ""