"""Persistence and analytics for risk observations.

Provides ``RiskObservationService`` (upserting observations, recording reviewer
feedback, listing/querying observations, history statistics and dashboard
summaries) together with the private helper functions those methods use.
"""

from __future__ import annotations

import math
from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from decimal import Decimal, InvalidOperation
from typing import Any, ClassVar

from sqlalchemy import func, select
from sqlalchemy.orm import Session, joinedload

from app.algorithem.risk_graph import RiskHistoryStats, RiskObservationDraft
from app.db.base import Base
from app.models.financial_record import ExpenseClaim
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
from app.schemas.risk_observation import (
    RiskObservationDashboardRead,
    RiskObservationFeedbackCreate,
)
from app.services.expense_claim_risk_stage import normalize_risk_business_stage

# Risk levels counted as "high or above" in dashboards and when choosing an
# automation mode for platform rule hits.
HIGH_LEVELS = {"high", "critical"}

# Risk score assigned to a platform rule hit, keyed by normalized severity.
SEVERITY_SCORE = {
    "low": 32,
    "medium": 58,
    "high": 82,
    "critical": 100,
}

# feedback_type -> (observation.status, observation.feedback_status) written
# by ``RiskObservationService.create_feedback``.
FEEDBACK_STATUS_MAP = {
    "confirm": ("confirmed", "confirmed"),
    "false_positive": ("false_positive", "false_positive"),
    "ignore": ("ignored", "ignored"),
    "resolve": ("resolved", "resolved"),
}


class RiskObservationService:
    """Database-backed operations on ``RiskObservation`` rows and their feedback."""

    # Shared by all instances: identifiers of binds whose tables have already
    # been passed through ``create_all`` in this process.
    _storage_ready_cache: ClassVar[set[str]] = set()

    def __init__(self, db: Session) -> None:
        self.db = db

    def ensure_storage_ready(self) -> None:
        """Run ``create_all`` for the observation/feedback tables once per bind.

        The bind is identified by its ``url`` when it has one, otherwise by
        ``id(bind)``; subsequent calls for an already-seen bind return early.
        """
        bind = self.db.get_bind()
        # NOTE(review): if the session is bound to a Connection without a ``url``
        # attribute, the key falls back to the object id, so caching is only per
        # connection object — confirm this is acceptable for the deployment.
        cache_key = str(getattr(bind, "url", "") or id(bind))
        if cache_key in self._storage_ready_cache:
            return
        Base.metadata.create_all(
            bind=bind,
            tables=[
                RiskObservation.__table__,
                RiskObservationFeedback.__table__,
            ],
        )
        self._storage_ready_cache.add(cache_key)

    def upsert_observation(
        self,
        observation: RiskObservationDraft | dict[str, Any],
        *,
        run_id: str | None = None,
        execution_log_id: str | None = None,
    ) -> RiskObservation:
        """Insert or update the observation identified by ``observation_key``.

        All descriptive and scoring columns are overwritten from the payload.
        ``status`` and ``feedback_status`` are not assigned here, so values set
        by ``create_feedback`` are preserved when an observation is re-upserted.
        The session is flushed but not committed.

        Args:
            observation: A draft (converted with ``as_dict()``) or a plain mapping.
            run_id: Takes precedence over ``payload["run_id"]`` when truthy.
            execution_log_id: Takes precedence over ``payload["execution_log_id"]``
                when truthy.

        Returns:
            The newly created or updated ``RiskObservation``.

        Raises:
            ValueError: If the payload has no non-blank ``observation_key``.
        """
        self.ensure_storage_ready()
        payload = (
            observation.as_dict()
            if isinstance(observation, RiskObservationDraft)
            else dict(observation)
        )
        observation_key = str(payload.get("observation_key") or "").strip()
        if not observation_key:
            raise ValueError("Risk observation requires observation_key.")

        item = self.db.scalar(
            select(RiskObservation).where(RiskObservation.observation_key == observation_key)
        )
        if item is None:
            item = RiskObservation(observation_key=observation_key)
            self.db.add(item)

        item.subject_type = _text(payload.get("subject_type"))
        item.subject_key = _text(payload.get("subject_key"))
        item.subject_label = _text(payload.get("subject_label"))
        item.claim_id = _optional_text(payload.get("claim_id"))
        item.claim_no = _text(payload.get("claim_no"))
        item.run_id = _optional_text(run_id or payload.get("run_id"))
        item.execution_log_id = _optional_text(execution_log_id or payload.get("execution_log_id"))
        item.risk_type = _text(payload.get("risk_type"))
        item.risk_signal = _text(payload.get("risk_signal"))
        item.title = _text(payload.get("title"))
        item.description = _text(payload.get("description"))
        item.risk_score = _clamp_score(payload.get("risk_score"))
        item.risk_level = _text(payload.get("risk_level")) or "low"
        item.confidence_score = _float(payload.get("confidence_score"))
        item.control_stage = _text(payload.get("control_stage"))
        item.control_mode = _text(payload.get("control_mode"))
        item.automation_mode = _text(payload.get("automation_mode"))
        item.source = _text(payload.get("source"))
        item.algorithm_version = _text(payload.get("algorithm_version"))
        item.contribution_scores_json = _dict(payload.get("contribution_scores"))
        item.baseline_json = _dict(payload.get("baseline"))
        item.evidence_json = _list(payload.get("evidence"))
        item.graph_node_keys_json = _list(payload.get("graph_node_keys"))
        item.graph_edge_keys_json = _list(payload.get("graph_edge_keys"))
        item.policy_refs_json = _list(payload.get("policy_refs"))
        item.similar_case_claim_ids_json = _list(payload.get("similar_case_claim_ids"))
        item.ontology_json = _risk_ontology_payload(payload)
        item.decision_trace_json = _risk_decision_trace_payload(payload)
        self.db.flush()
        return item

    def upsert_platform_risk_flags(
        self,
        claim: ExpenseClaim,
        flags: list[dict[str, Any]],
        *,
        run_id: str | None = None,
        execution_log_id: str | None = None,
    ) -> list[RiskObservation]:
        """Upsert one observation per rule-center risk flag raised on ``claim``.

        A flag is skipped when it is not a dict, when its ``rule_type`` is set
        to anything other than ``"risk"``, when its ``hit_source`` is set to
        anything other than ``"rule_center"``, or when no risk signal can be
        derived from it. Observation keys are
        ``risk:{claim.id}:platform:{rule_code or signal}``, so re-running the
        same flags updates the existing rows instead of duplicating them.

        Returns:
            The upserted observations, in flag order.
        """
        observations: list[RiskObservation] = []
        for flag in flags:
            if not isinstance(flag, dict):
                continue
            # Compare the stripped value so " risk " is treated like "risk";
            # an empty/missing rule_type is accepted.
            rule_type = _text(flag.get("rule_type"))
            if rule_type and rule_type != "risk":
                continue
            if _text(flag.get("hit_source")) not in {"", "rule_center"}:
                continue
            signal = _risk_signal_from_flag(flag)
            if not signal:
                continue

            severity = _normalize_level(flag.get("severity"))
            score = SEVERITY_SCORE.get(severity, SEVERITY_SCORE["medium"])
            rule_code = _text(flag.get("rule_code"))
            business_stage = normalize_risk_business_stage(flag.get("business_stage"))
            observation_key = f"risk:{claim.id}:platform:{rule_code or signal}"
            observations.append(
                self.upsert_observation(
                    {
                        "observation_key": observation_key,
                        "subject_type": "expense_claim",
                        "subject_key": f"claim:{claim.id}",
                        "subject_label": claim.claim_no,
                        "claim_id": claim.id,
                        "claim_no": claim.claim_no,
                        "risk_type": signal,
                        "risk_signal": signal,
                        "title": _text(flag.get("label")) or signal,
                        "description": _text(flag.get("message")),
                        "risk_score": score,
                        "risk_level": severity,
                        "confidence_score": "0.78",
                        "control_stage": business_stage,
                        "control_mode": "risk_observation",
                        "automation_mode": (
                            "semi_auto_review" if severity in HIGH_LEVELS else "manual_review"
                        ),
                        "source": "rule_center",
                        "algorithm_version": _text(flag.get("rule_version")) or "v1.0.0",
                        "contribution_scores": {"S_rule": score},
                        "baseline": {},
                        "evidence": [
                            {
                                "code": "platform_risk_rule",
                                "title": _text(flag.get("label")) or signal,
                                "detail": _text(flag.get("message")),
                                "source": "rule_center",
                                "score": score,
                                "metadata": flag,
                            }
                        ],
                        "graph_node_keys": [f"claim:{claim.id}"],
                        "graph_edge_keys": [],
                        "policy_refs": [rule_code] if rule_code else [],
                        "similar_case_claim_ids": [],
                        "ontology_json": {},
                        "decision_trace": {
                            "rule_code": rule_code,
                            "rule_version": _text(flag.get("rule_version")),
                            "action": _text(flag.get("action")),
                        },
                    },
                    run_id=run_id,
                    execution_log_id=execution_log_id,
                )
            )
        return observations

    def build_history_stats(
        self,
        *,
        risk_signals: set[str] | None = None,
        expense_types: set[str] | None = None,
        limit: int = 2000,
    ) -> list[RiskHistoryStats]:
        """Aggregate feedback outcomes per (risk signal, expense type) pair.

        Reads the ``limit`` most recently created observations (joined to their
        claim's expense type) and counts similar cases, confirmations, false
        positives and returns for each canonicalized pair.

        Args:
            risk_signals: Optional whitelist of signals (canonicalized before use).
            expense_types: Optional whitelist of expense types (canonicalized).
                Observations with no expense type are never excluded by it.
            limit: Maximum number of observations read from the database.
        """
        self.ensure_storage_ready()
        stmt = (
            select(RiskObservation, ExpenseClaim.expense_type)
            .outerjoin(ExpenseClaim, RiskObservation.claim_id == ExpenseClaim.id)
            .order_by(RiskObservation.created_at.desc())
            .limit(limit)
        )
        rows = list(self.db.execute(stmt).all())
        # NOTE(review): the filters below run in Python after the LIMIT, so
        # ``limit`` bounds the rows scanned, not the rows that match.
        signal_filter = {_canonical_key(item) for item in (risk_signals or set()) if item}
        expense_filter = {_canonical_key(item) for item in (expense_types or set()) if item}

        grouped: dict[tuple[str, str], RiskHistoryStats] = {}
        for observation, expense_type in rows:
            signal = _canonical_key(observation.risk_signal)
            expense = _canonical_key(expense_type or "")
            if signal_filter and signal not in signal_filter:
                continue
            if expense_filter and expense and expense not in expense_filter:
                continue
            key = (signal, expense)
            stats = grouped.setdefault(
                key,
                RiskHistoryStats(risk_signal=signal, expense_type=expense),
            )
            stats.similar_case_count += 1
            feedback_status = _canonical_key(observation.feedback_status)
            if feedback_status == "confirmed":
                stats.confirmed_count += 1
            elif feedback_status == "false_positive":
                stats.false_positive_count += 1
            # NOTE(review): _has_return_feedback reads observation.feedback_items;
            # if that is a lazy relationship this issues one query per row.
            if _has_return_feedback(observation):
                stats.returned_count += 1
        return list(grouped.values())

    def list_observations(
        self,
        *,
        claim_id: str | None = None,
        run_id: str | None = None,
        execution_log_id: str | None = None,
        risk_level: str | None = None,
        risk_signal: str | None = None,
        status: str | None = None,
        source: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[RiskObservation], int]:
        """Return one page of observations plus the total matching count.

        Every truthy filter is applied as an exact-match condition. Results are
        ordered by risk score, then creation time, both descending.
        """
        self.ensure_storage_ready()
        conditions = []
        if claim_id:
            conditions.append(RiskObservation.claim_id == claim_id)
        if run_id:
            conditions.append(RiskObservation.run_id == run_id)
        if execution_log_id:
            conditions.append(RiskObservation.execution_log_id == execution_log_id)
        if risk_level:
            conditions.append(RiskObservation.risk_level == risk_level)
        if risk_signal:
            conditions.append(RiskObservation.risk_signal == risk_signal)
        if status:
            conditions.append(RiskObservation.status == status)
        if source:
            conditions.append(RiskObservation.source == source)

        count_stmt = select(func.count()).select_from(RiskObservation)
        stmt = select(RiskObservation).order_by(
            RiskObservation.risk_score.desc(),
            RiskObservation.created_at.desc(),
        )
        if conditions:
            count_stmt = count_stmt.where(*conditions)
            stmt = stmt.where(*conditions)
        total = int(self.db.scalar(count_stmt) or 0)
        items = list(self.db.scalars(stmt.offset(offset).limit(limit)).all())
        return items, total

    def get_observation(self, observation_key_or_id: str) -> RiskObservation | None:
        """Look up an observation by ``observation_key`` or primary ``id``.

        Returns ``None`` for a blank identifier or when nothing matches.
        """
        self.ensure_storage_ready()
        value = str(observation_key_or_id or "").strip()
        if not value:
            return None
        return self.db.scalar(
            select(RiskObservation).where(
                (RiskObservation.observation_key == value) | (RiskObservation.id == value)
            )
        )

    def list_claim_observations(self, claim_id: str) -> list[RiskObservation]:
        """Return up to 100 observations attached to ``claim_id``."""
        items, _ = self.list_observations(claim_id=claim_id, limit=100, offset=0)
        return items

    def list_execution_log_observations(self, execution_log_id: str) -> list[RiskObservation]:
        """Return up to 200 observations produced by ``execution_log_id``."""
        items, _ = self.list_observations(
            execution_log_id=execution_log_id,
            limit=200,
            offset=0,
        )
        return items

    def create_feedback(
        self,
        observation_key_or_id: str,
        payload: RiskObservationFeedbackCreate,
    ) -> RiskObservationFeedback:
        """Record reviewer feedback on an observation and commit.

        When ``payload.feedback_type`` appears in ``FEEDBACK_STATUS_MAP``, the
        observation's ``status`` and ``feedback_status`` are updated as well.

        Raises:
            LookupError: If no observation matches ``observation_key_or_id``.
        """
        self.ensure_storage_ready()
        observation = self.get_observation(observation_key_or_id)
        if observation is None:
            raise LookupError("Risk observation not found.")
        feedback = RiskObservationFeedback(
            observation_id=observation.id,
            feedback_type=payload.feedback_type,
            action=payload.action or "",
            actor=payload.actor or "",
            comment=payload.comment,
            payload_json=payload.payload_json,
        )
        self.db.add(feedback)
        mapped = FEEDBACK_STATUS_MAP.get(payload.feedback_type)
        if mapped:
            observation.status, observation.feedback_status = mapped
        self.db.commit()
        self.db.refresh(feedback)
        return feedback

    def summarize_dashboard(
        self,
        *,
        window_days: int = 30,
        limit: int = 500,
    ) -> RiskObservationDashboardRead:
        """Summarize the newest observations created in the last ``window_days`` days.

        At most ``limit`` observations (newest first, claims eagerly loaded)
        feed the counts, distributions, trends and top-N lists. Confirmation
        and false-positive rates are relative to reviewed observations
        (confirmed + false positive). ``feedback_sample_count`` counts every
        feedback row in the window and is not capped by ``limit``.
        """
        self.ensure_storage_ready()
        since = datetime.now(UTC) - timedelta(days=window_days)
        stmt = (
            select(RiskObservation)
            .options(joinedload(RiskObservation.claim))
            .where(RiskObservation.created_at >= since)
            .order_by(RiskObservation.created_at.desc())
            .limit(limit)
        )
        observations = list(self.db.scalars(stmt).all())
        total = len(observations)
        confirmed = sum(1 for item in observations if item.feedback_status == "confirmed")
        false_positive = sum(
            1 for item in observations if item.feedback_status == "false_positive"
        )
        pending = sum(1 for item in observations if item.status == "pending_review")
        feedback_samples = int(
            self.db.scalar(
                select(func.count())
                .select_from(RiskObservationFeedback)
                .where(RiskObservationFeedback.created_at >= since)
            )
            or 0
        )
        high_or_above = sum(1 for item in observations if item.risk_level in HIGH_LEVELS)
        score_sum = sum(int(item.risk_score or 0) for item in observations)
        reviewed = confirmed + false_positive
        signal_distribution = _count_by(observations, "risk_signal")
        # NOTE(review): amounts are summed per observation, so a claim with
        # several observations contributes its amount several times — confirm
        # this is the intended meaning of total_amount (top_* lists do the same).
        total_amount = sum((_claim_amount(item.claim) for item in observations), Decimal("0"))

        return RiskObservationDashboardRead(
            window_days=window_days,
            total_observations=total,
            pending_count=pending,
            risk_clue_count=pending,
            high_or_above_count=high_or_above,
            confirmed_count=confirmed,
            false_positive_count=false_positive,
            feedback_sample_count=feedback_samples,
            total_amount=float(total_amount),
            average_score=round(score_sum / total, 2) if total else 0.0,
            level_distribution=_count_by(observations, "risk_level"),
            status_distribution=_count_by(observations, "status"),
            signal_distribution=signal_distribution,
            risk_type_distribution=_count_by(observations, "risk_type"),
            source_distribution=_count_by(observations, "source"),
            automation_distribution=_count_by(observations, "automation_mode"),
            department_distribution=_claim_distribution(
                observations,
                lambda claim: claim.department_name if claim else "",
            ),
            expense_type_distribution=_claim_distribution(
                observations,
                lambda claim: claim.expense_type if claim else "",
            ),
            supplier_distribution=_supplier_distribution(observations),
            employee_grade_distribution=_claim_distribution(
                observations,
                lambda claim: claim.employee_grade if claim else "",
            ),
            daily_trend=_daily_trend(observations),
            top_risk_signals=_top_counts(signal_distribution),
            top_departments=_top_claim_dimension(
                observations,
                lambda claim: claim.department_name if claim else "",
            ),
            top_employees=_top_claim_dimension(
                observations,
                lambda claim: claim.employee_name if claim else "",
            ),
            top_suppliers=_top_suppliers(observations),
            top_expense_types=_top_claim_dimension(
                observations,
                lambda claim: claim.expense_type if claim else "",
            ),
            top_rules=_top_rules(observations),
            candidate_rule_count=0,
            confirmation_rate=round(confirmed / reviewed, 4) if reviewed else 0.0,
            false_positive_rate=round(false_positive / reviewed, 4) if reviewed else 0.0,
            recent_high_observations=[
                item for item in observations if item.risk_level in HIGH_LEVELS
            ][:10],
        )


def _count_by(items: list[RiskObservation], field: str) -> dict[str, int]:
    """Count observations by the stripped text of attribute ``field`` ("unknown" if blank)."""
    counts: dict[str, int] = {}
    for item in items:
        value = _text(getattr(item, field, "")) or "unknown"
        counts[value] = counts.get(value, 0) + 1
    return counts


def _claim_distribution(
    items: list[RiskObservation],
    getter: Callable[[ExpenseClaim | None], Any],
) -> dict[str, int]:
    """Count observations by a value extracted from each observation's claim."""
    counts: dict[str, int] = {}
    for item in items:
        value = _text(getter(item.claim)) or "unknown"
        counts[value] = counts.get(value, 0) + 1
    return counts


def _supplier_distribution(items: list[RiskObservation]) -> dict[str, int]:
    """Count observations per supplier name found by ``_supplier_names``."""
    counts: dict[str, int] = {}
    for item in items:
        for supplier in _supplier_names(item):
            counts[supplier] = counts.get(supplier, 0) + 1
    return counts


def _top_claim_dimension(
    items: list[RiskObservation],
    getter: Callable[[ExpenseClaim | None], Any],
    *,
    limit: int = 5,
) -> list[dict[str, Any]]:
    """Rank claim-derived buckets by observation count, then summed claim amount."""
    buckets: dict[str, dict[str, Any]] = {}
    for item in items:
        name = _text(getter(item.claim)) or "unknown"
        bucket = buckets.setdefault(name, {"name": name, "count": 0, "amount": Decimal("0")})
        bucket["count"] += 1
        bucket["amount"] += _claim_amount(item.claim)
    return _top_dimension_rows(buckets, limit=limit)


def _top_suppliers(items: list[RiskObservation], *, limit: int = 5) -> list[dict[str, Any]]:
    """Rank suppliers by observation count, then summed claim amount."""
    buckets: dict[str, dict[str, Any]] = {}
    for item in items:
        suppliers = _supplier_names(item)
        if not suppliers:
            continue
        amount = _claim_amount(item.claim)
        for supplier in suppliers:
            bucket = buckets.setdefault(
                supplier,
                {"name": supplier, "count": 0, "amount": Decimal("0")},
            )
            bucket["count"] += 1
            bucket["amount"] += amount
    return _top_dimension_rows(buckets, limit=limit)


def _top_rules(items: list[RiskObservation], *, limit: int = 5) -> list[dict[str, Any]]:
    """Rank policy/rule references by observation count, then summed claim amount.

    Rule-center observations without policy refs fall back to their risk signal.
    """
    buckets: dict[str, dict[str, Any]] = {}
    for item in items:
        rules = [_text(value) for value in (item.policy_refs_json or []) if _text(value)]
        if not rules and item.source == "rule_center":
            rules = [_text(item.risk_signal)]
        for rule in rules:
            bucket = buckets.setdefault(rule, {"name": rule, "count": 0, "amount": Decimal("0")})
            bucket["count"] += 1
            bucket["amount"] += _claim_amount(item.claim)
    return _top_dimension_rows(buckets, limit=limit)


def _top_dimension_rows(
    buckets: dict[str, dict[str, Any]],
    *,
    limit: int,
) -> list[dict[str, Any]]:
    """Return the ``limit`` largest buckets by (count, amount) with float amounts."""
    ranked = sorted(
        buckets.values(),
        key=lambda item: (item["count"], item["amount"]),
        reverse=True,
    )[:limit]
    return [
        {
            "name": item["name"],
            "count": item["count"],
            "amount": float(item["amount"]),
        }
        for item in ranked
    ]


def _supplier_names(item: RiskObservation) -> list[str]:
    """Collect unique supplier names referenced by an observation, in first-seen order.

    Names come from graph node keys prefixed ``supplier:``/``vendor:``/``merchant:``
    and from supplier-like keys on evidence entries or their ``metadata`` dicts.
    """
    names: list[str] = []
    for value in item.graph_node_keys_json or []:
        text = _text(value)
        lowered = text.lower()
        if lowered.startswith(("supplier:", "vendor:", "merchant:")):
            names.append(text.split(":", 1)[1] or text)
    for evidence in item.evidence_json or []:
        if isinstance(evidence, dict):
            metadata = (
                evidence.get("metadata") if isinstance(evidence.get("metadata"), dict) else {}
            )
            for key in ("supplier_name", "vendor_name", "merchant_name", "supplier", "vendor"):
                name = _text(evidence.get(key)) or _text(metadata.get(key))
                if name:
                    names.append(name)
    # dict.fromkeys de-duplicates while keeping insertion order.
    return list(dict.fromkeys(names))


def _claim_amount(claim: ExpenseClaim | None) -> Decimal:
    """Return the claim amount as a finite Decimal, or ``Decimal("0")`` if unusable."""
    if claim is None:
        return Decimal("0")
    try:
        amount = Decimal(str(claim.amount or "0"))
    except (InvalidOperation, TypeError, ValueError):
        return Decimal("0")
    # NaN/Infinity would poison the sums, and ordering comparisons on a NaN
    # Decimal (used when ranking buckets) raise InvalidOperation.
    return amount if amount.is_finite() else Decimal("0")


def _daily_trend(items: list[RiskObservation]) -> list[dict[str, Any]]:
    """Count observations (and high-or-above ones) per creation date, oldest first."""
    grouped: dict[str, dict[str, Any]] = {}
    for item in items:
        day = item.created_at.date().isoformat() if item.created_at else "unknown"
        bucket = grouped.setdefault(day, {"date": day, "total": 0, "high_or_above": 0})
        bucket["total"] += 1
        if item.risk_level in HIGH_LEVELS:
            bucket["high_or_above"] += 1
    return [grouped[key] for key in sorted(grouped)]


def _top_counts(counts: dict[str, int], limit: int = 10) -> list[dict[str, Any]]:
    """Return the ``limit`` highest counts as ``{"name", "count"}`` rows."""
    return [
        {"name": key, "count": value}
        for key, value in sorted(counts.items(), key=lambda item: item[1], reverse=True)[:limit]
    ]


def _risk_signal_from_flag(flag: dict[str, Any]) -> str:
    """Derive a canonical risk signal from a flag's signal, rule code, or label.

    Dotted identifiers are reduced to their last segment.
    """
    raw = _text(flag.get("risk_signal")) or _text(flag.get("rule_code")) or _text(flag.get("label"))
    if not raw:
        return ""
    if "." in raw:
        raw = raw.split(".")[-1]
    return _canonical_key(raw)


def _normalize_level(value: Any) -> str:
    """Canonicalize a severity, defaulting anything unrecognized to ``"medium"``."""
    normalized = _canonical_key(value)
    return normalized if normalized in {"low", "medium", "high", "critical"} else "medium"


def _has_return_feedback(observation: RiskObservation) -> bool:
    """Report whether the observation was returned or sent back for supplementation."""
    if _canonical_key(observation.status) in {"returned", "supplement_required"}:
        return True
    for feedback in list(observation.feedback_items or []):
        action = _canonical_key(feedback.action)
        feedback_type = _canonical_key(feedback.feedback_type)
        if action in {"return", "returned", "supplement", "supplement_required"}:
            return True
        if feedback_type in {"return", "returned"}:
            return True
    return False


def _text(value: Any) -> str:
    """Return ``str(value)`` stripped; falsy values become ``""``."""
    return str(value or "").strip()


def _canonical_key(value: Any) -> str:
    """Lower-case the text and join its whitespace-separated words with ``_``."""
    return "_".join(_text(value).lower().split())


def _optional_text(value: Any) -> str | None:
    """Like ``_text`` but returns ``None`` instead of an empty string."""
    normalized = _text(value)
    return normalized or None


def _dict(value: Any) -> dict[str, Any]:
    """Shallow-copy ``value`` if it is a dict, else return an empty dict."""
    return dict(value) if isinstance(value, dict) else {}


def _list(value: Any) -> list[Any]:
    """Shallow-copy ``value`` if it is a list, else return an empty list."""
    return list(value) if isinstance(value, list) else []


def _risk_ontology_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Merge non-empty ontology fields from ``payload`` into ``payload["ontology_json"]``."""
    ontology = _dict(payload.get("ontology_json"))
    for key in (
        "ontology_parse_id",
        "ontology_version",
        "domain",
        "scenario",
        "intent",
        "ontology_entities_json",
        "risk_signals_json",
        "canonical_subject_key",
    ):
        value = payload.get(key)
        if value not in (None, "", [], {}):
            ontology[key] = value
    return ontology


def _risk_decision_trace_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Merge non-empty trace fields from ``payload`` into ``payload["decision_trace"]``."""
    decision_trace = _dict(payload.get("decision_trace"))
    for key in ("sampling_strategy", "evaluation_case_id"):
        value = payload.get(key)
        if value not in (None, "", [], {}):
            decision_trace[key] = value
    return decision_trace


def _float(value: Any) -> float:
    """Convert to float, returning 0.0 for falsy or unconvertible input."""
    try:
        return float(value or 0)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int too large for a double.
        return 0.0


def _clamp_score(value: Any) -> int:
    """Convert to an integer score truncated toward zero and clamped to [0, 100].

    Unconvertible input and NaN yield 0; +/-infinity clamp to 100/0.
    """
    try:
        numeric = float(value or 0)
    except (TypeError, ValueError, OverflowError):
        return 0
    if math.isnan(numeric):
        return 0
    # Clamp before truncating: int() of an infinite float raises OverflowError.
    # For finite values this matches truncating first and clamping afterwards.
    return int(max(0.0, min(100.0, numeric)))