from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Iterable, Iterator

if TYPE_CHECKING:
    # Referenced only in annotations, which are never evaluated at runtime
    # because of ``from __future__ import annotations``.
    from app.models.financial_record import ExpenseClaim


class RiskRuleTemplateExecutor:
    """Evaluate declarative risk-rule templates against an expense claim.

    A rule *manifest* names a template (``template_key``) and carries its
    ``params``.  Every template returns ``None`` when the rule does not fire
    and a ``{"message": ..., "evidence": ...}`` dict when it does.

    Field keys used inside ``params`` are namespaced:

    * ``claim.<attr>``       -- attribute read from the claim object
    * ``item.<attr>``        -- attribute read from every entry of ``claim.items``
    * ``attachment.<field>`` -- value looked up in the attachment contexts
    """

    # Substrings of a human-readable field label that identify a logical field.
    _LABEL_HINTS: dict[str, tuple[str, ...]] = {
        "invoice_no": ("发票号", "发票号码", "票号"),
        "buyer_name": ("购买方", "抬头", "买方"),
        "goods_name": ("品名", "商品", "服务名称"),
        "issue_date": ("日期", "开票日期", "发票日期"),
        "hotel_city": ("住宿城市", "酒店城市", "酒店地点"),
        "route_cities": ("行程", "路线", "城市"),
        "city": ("城市", "地点"),
    }

    # Attachment fields that additionally pick up generic ``city`` values.
    _CITY_FIELDS = frozenset({"hotel_city", "route_cities"})

    _WHITESPACE_RE = re.compile(r"\s+")

    def evaluate(
        self,
        manifest: dict[str, Any],
        *,
        claim: ExpenseClaim,
        contexts: list[dict[str, Any]],
    ) -> dict[str, Any] | None:
        """Run the template named by *manifest* and return its finding.

        ``template_key`` is read from the manifest first and from
        ``manifest["params"]`` second.  Unknown or missing template keys
        yield ``None``.
        """
        raw_params = manifest.get("params")
        params = raw_params if isinstance(raw_params, dict) else {}
        template_key = str(
            manifest.get("template_key") or params.get("template_key") or ""
        ).strip()

        handlers = {
            "field_required_v1": self._evaluate_required_fields,
            "field_compare_v1": self._evaluate_compare_conditions,
            "keyword_match_v1": self._evaluate_keyword_match,
        }
        handler = handlers.get(template_key)
        if handler is None:
            return None
        return handler(params, claim=claim, contexts=contexts)

    def _evaluate_required_fields(
        self,
        params: dict[str, Any],
        *,
        claim: ExpenseClaim,
        contexts: list[dict[str, Any]],
    ) -> dict[str, Any] | None:
        """Fire when any field in ``required_fields`` / ``field_keys`` is empty."""
        required_fields = self._read_string_list(
            params.get("required_fields") or params.get("field_keys")
        )
        missing = [
            field_key
            for field_key in required_fields
            if not self._has_resolved_value(field_key, claim=claim, contexts=contexts)
        ]
        if not missing:
            return None
        return {
            "message": self._resolve_message(
                params,
                fallback=f"规则要求的字段未完整提供:{'、'.join(missing[:4])}。",
            ),
            "evidence": {
                "missing_fields": missing,
                "condition_summary": params.get("condition_summary"),
            },
        }

    def _evaluate_compare_conditions(
        self,
        params: dict[str, Any],
        *,
        claim: ExpenseClaim,
        contexts: list[dict[str, Any]],
    ) -> dict[str, Any] | None:
        """Fire when at least one entry of ``params["conditions"]`` fails.

        Each condition is a dict with ``left``, ``right`` (field keys) and
        ``operator`` (defaults to ``"not_overlap"``).  Non-dict entries are
        ignored.  At most five failures, with at most five sample values per
        side, are reported as evidence.
        """
        raw_conditions = params.get("conditions")
        conditions = raw_conditions if isinstance(raw_conditions, list) else []

        failures: list[dict[str, Any]] = []
        for condition in conditions:
            if not isinstance(condition, dict):
                continue
            left_key = str(condition.get("left") or "").strip()
            right_key = str(condition.get("right") or "").strip()
            operator = str(condition.get("operator") or "not_overlap").strip()
            left_values = self._resolve_values(left_key, claim=claim, contexts=contexts)
            right_values = self._resolve_values(right_key, claim=claim, contexts=contexts)
            if self._condition_passes(operator, left_values, right_values):
                continue
            failures.append(
                {
                    "left": left_key,
                    "operator": operator,
                    "right": right_key,
                    "left_values": left_values[:5],
                    "right_values": right_values[:5],
                }
            )

        if not failures:
            return None
        return {
            "message": self._resolve_message(
                params,
                fallback=(
                    "规则字段对比未通过:"
                    f"{params.get('condition_summary') or '字段关系不符合要求'}。"
                ),
            ),
            "evidence": {
                "failed_conditions": failures[:5],
                "condition_summary": params.get("condition_summary"),
            },
        }

    def _evaluate_keyword_match(
        self,
        params: dict[str, Any],
        *,
        claim: ExpenseClaim,
        contexts: list[dict[str, Any]],
    ) -> dict[str, Any] | None:
        """Fire when any of ``params["keywords"]`` occurs in the searched text.

        The text is built from ``search_fields`` / ``field_keys``.  When those
        resolve to nothing, it falls back to the claim reason and location,
        every item reason and every context's ``ocr_text``.  Matching is a
        case-sensitive substring test.
        """
        keywords = self._read_string_list(params.get("keywords"))
        search_fields = self._read_string_list(
            params.get("search_fields") or params.get("field_keys")
        )
        if not keywords:
            return None

        corpus_parts: list[str] = []
        for field_key in search_fields:
            corpus_parts.extend(
                self._resolve_values(field_key, claim=claim, contexts=contexts)
            )
        if not corpus_parts:
            corpus_parts.append(str(claim.reason or ""))
            corpus_parts.append(str(claim.location or ""))
            corpus_parts.extend(
                str(item.item_reason or "") for item in list(claim.items or [])
            )
            # Non-dict contexts carry no OCR text; skip them instead of
            # raising AttributeError on ``.get``.
            corpus_parts.extend(
                str(context.get("ocr_text") or "")
                for context in contexts
                if isinstance(context, dict)
            )

        corpus = "\n".join(corpus_parts)
        # ``dict.fromkeys`` keeps the configured order while dropping repeated
        # keywords, so each hit is reported once.
        hits = [keyword for keyword in dict.fromkeys(keywords) if keyword in corpus]
        if not hits:
            return None
        return {
            "message": self._resolve_message(
                params,
                fallback=f"识别到风险关键词:{'、'.join(hits[:5])}。",
            ),
            "evidence": {
                "keyword_hits": hits[:8],
                "search_fields": search_fields,
                "condition_summary": params.get("condition_summary"),
            },
        }

    def _resolve_values(
        self,
        field_key: str,
        *,
        claim: ExpenseClaim,
        contexts: list[dict[str, Any]],
    ) -> list[str]:
        """Return the normalized string values a namespaced *field_key* refers to.

        Keys without a known ``claim.`` / ``item.`` / ``attachment.`` prefix
        resolve to an empty list, as do attributes that do not exist.
        """
        normalized = str(field_key or "").strip()
        if not normalized:
            return []
        if normalized.startswith("claim."):
            attr = normalized.removeprefix("claim.")
            return self._normalize_values([getattr(claim, attr, "")])
        if normalized.startswith("item."):
            attr = normalized.removeprefix("item.")
            return self._normalize_values(
                [getattr(item, attr, "") for item in list(claim.items or [])]
            )
        if normalized.startswith("attachment."):
            return self._resolve_attachment_values(
                normalized.removeprefix("attachment."), contexts
            )
        return []

    def _resolve_attachment_values(
        self, field_key: str, contexts: list[dict[str, Any]]
    ) -> list[str]:
        """Collect the values of *field_key* across all attachment contexts.

        ``ocr_text`` also reads the context-level ``ocr_text`` and
        ``ocr_summary`` entries; city-like fields also pick up generic
        ``city`` values from ``document_info``.
        """
        values: list[Any] = []
        for context in contexts:
            # A non-dict context has neither OCR text nor document info.
            if not isinstance(context, dict):
                continue
            document_info = context.get("document_info")
            if not isinstance(document_info, dict):
                document_info = {}
            if field_key == "ocr_text":
                values.extend([context.get("ocr_text"), context.get("ocr_summary")])
            values.extend(self._scan_document_values(document_info, field_key))
            if field_key in self._CITY_FIELDS:
                values.extend(self._scan_document_values(document_info, "city"))
        return self._normalize_values(values)

    def _scan_document_values(
        self, document_info: dict[str, Any], field_key: str
    ) -> list[Any]:
        """Return raw values for *field_key* found in one ``document_info`` dict.

        Looks at top-level keys (as given, without underscores, and with
        hyphens instead of underscores) and then at the entries of
        ``document_info["fields"]`` whose ``key`` or ``label`` matches.
        """
        # ``dict.fromkeys`` de-duplicates the variants while keeping a
        # deterministic lookup order (a set would not).
        variants = dict.fromkeys(
            (field_key, field_key.replace("_", ""), field_key.replace("_", "-"))
        )
        values: list[Any] = [
            document_info[key] for key in variants if key in document_info
        ]

        fields = document_info.get("fields")
        if not isinstance(fields, (list, tuple)):
            return values
        for field in fields:
            if not isinstance(field, dict):
                continue
            key = str(field.get("key") or "").strip().lower()
            label = str(field.get("label") or "").strip()
            if self._field_matches(key, label, field_key):
                values.append(field.get("value"))
        return values

    @staticmethod
    def _field_matches(key: str, label: str, field_key: str) -> bool:
        """Tell whether a document field (*key*, *label*) represents *field_key*.

        Keys are compared case-insensitively with ``_`` and ``-`` removed, and
        the target only has to be a substring of the key.  Otherwise the label
        is checked against the known label hints for the target.  An empty
        *field_key* never matches.
        """
        target = str(field_key or "").strip().lower()
        compact_target = target.replace("_", "").replace("-", "")
        if not compact_target:
            # ``"" in anything`` is True; without this guard an empty key
            # would match every field of the document.
            return False
        compact_key = str(key or "").lower().replace("_", "").replace("-", "")
        if compact_target in compact_key:
            return True
        hints = RiskRuleTemplateExecutor._LABEL_HINTS.get(target, ())
        return any(hint in label for hint in hints)

    def _has_resolved_value(
        self,
        field_key: str,
        *,
        claim: ExpenseClaim,
        contexts: list[dict[str, Any]],
    ) -> bool:
        """Return ``True`` when *field_key* resolves to at least one value."""
        return bool(self._resolve_values(field_key, claim=claim, contexts=contexts))

    @staticmethod
    def _condition_passes(
        operator: str, left_values: list[str], right_values: list[str]
    ) -> bool:
        """Evaluate one comparison between two lists of resolved values.

        Comparisons are case-insensitive.  ``is_empty`` only inspects the left
        side.  Every other operator fails when either side is empty.
        Unrecognized operators behave like ``overlap``.
        """
        if operator == "is_empty":
            return not left_values
        if not left_values or not right_values:
            return False

        left_set = {value.lower() for value in left_values}
        right_set = {value.lower() for value in right_values}
        overlaps = bool(left_set & right_set)

        if operator in {"not_equals", "not_in", "not_overlap"}:
            return not overlaps
        if operator == "contains_any":
            return any(right in left for left in left_set for right in right_set)
        # "equals", "in", "overlap" and any unknown operator.
        return overlaps

    @staticmethod
    def _flatten(values: Iterable[Any]) -> Iterator[Any]:
        """Yield the scalar leaves of arbitrarily nested lists/tuples/sets."""
        for value in values:
            if isinstance(value, (set, frozenset)):
                # Sets have no stable order; sort so results are reproducible.
                yield from RiskRuleTemplateExecutor._flatten(sorted(value, key=str))
            elif isinstance(value, (list, tuple)):
                yield from RiskRuleTemplateExecutor._flatten(value)
            else:
                yield value

    @staticmethod
    def _normalize_values(values: list[Any]) -> list[str]:
        """Flatten *values* into unique, whitespace-collapsed, non-empty strings.

        First-seen order is preserved and duplicates are removed across all
        nesting levels.  Falsy scalars (``None``, ``""``, ``0``, ``False``)
        are treated as empty and dropped.
        """
        collapse = RiskRuleTemplateExecutor._WHITESPACE_RE
        seen: set[str] = set()
        normalized: list[str] = []
        for value in RiskRuleTemplateExecutor._flatten(values):
            text = collapse.sub(" ", str(value or "")).strip()
            if text and text not in seen:
                seen.add(text)
                normalized.append(text)
        return normalized

    @staticmethod
    def _read_string_list(value: Any) -> list[str]:
        """Return the stripped, non-empty string items of *value*.

        Anything that is not a ``list`` yields an empty list.
        """
        if not isinstance(value, list):
            return []
        stripped = (str(item or "").strip() for item in value)
        return [text for text in stripped if text]

    @staticmethod
    def _resolve_message(params: dict[str, Any], *, fallback: str) -> str:
        """Return ``params["message_template"]`` if non-blank, else *fallback*."""
        template = str(params.get("message_template") or "").strip()
        return template or fallback