from __future__ import annotations import json from collections import defaultdict from decimal import Decimal from typing import Any from app.algorithem.employee_behavior_profile import ALGORITHM_VERSION from app.models.agent_run import AgentRun from app.models.employee import Employee from app.models.financial_record import ExpenseClaim TRAVEL_EXPENSE_TYPES = { "travel", "train_ticket", "flight_ticket", "hotel_ticket", "ride_ticket", "travel_allowance", } ENTERTAINMENT_EXPENSE_TYPES = {"meal", "entertainment"} class EmployeeBehaviorProfileMetricHelpers: def _sum_amount_by_employee(self, claims: list[ExpenseClaim]) -> dict[str, Decimal]: grouped: dict[str, Decimal] = defaultdict(Decimal) for claim in claims: grouped[self._claim_employee_key(claim)] += self._decimal(claim.amount) return dict(grouped) def _count_by_employee(self, claims: list[ExpenseClaim]) -> dict[str, int]: grouped: dict[str, int] = defaultdict(int) for claim in claims: grouped[self._claim_employee_key(claim)] += 1 return dict(grouped) def _return_count_by_employee(self, claims: list[ExpenseClaim]) -> dict[str, int]: grouped: dict[str, int] = defaultdict(int) for claim in claims: grouped[self._claim_employee_key(claim)] += self._return_count([claim]) return dict(grouped) def _claim_employee_key(self, claim: ExpenseClaim) -> str: return str(claim.employee_id or claim.employee_name or "unknown").strip() def _employee_identifiers(self, employee: Employee) -> set[str]: return { item for item in ( employee.id, employee.employee_no, employee.email, employee.name, ) if str(item or "").strip() } def _return_count(self, claims: list[ExpenseClaim]) -> int: count = 0 for claim in claims: status = str(claim.status or "").lower() if status in {"returned", "supplement", "rejected"}: count += 1 for flag in claim.risk_flags_json or []: if isinstance(flag, dict) and str(flag.get("source") or "") == "manual_return": count += 1 return count def _missing_attachment_count(self, claim: ExpenseClaim) -> int: if not claim.items: return int((claim.invoice_count or 0) <= 0) return sum(1 for item in claim.items if not str(item.invoice_id or "").strip()) def _has_amount_mismatch(self, claim: ExpenseClaim) -> bool: if not claim.items: return False item_total = sum((self._decimal(item.item_amount) for item in claim.items), Decimal("0")) return abs(item_total - self._decimal(claim.amount)) > Decimal("0.01") def _missing_context_count(self, claim: ExpenseClaim) -> int: missing = 0 for value in (claim.reason, claim.location, claim.project_code): if self._is_missing_value(value): missing += 1 for item in claim.items or []: if self._is_missing_value(item.item_reason): missing += 1 if item.item_type in TRAVEL_EXPENSE_TYPES and self._is_missing_value( item.item_location ): missing += 1 return missing def _claim_travel_days(self, claim: ExpenseClaim | None) -> Decimal: if claim is None: return Decimal("0") dates = { item.item_date for item in claim.items or [] if item.item_type in TRAVEL_EXPENSE_TYPES and item.item_date is not None } if dates: return Decimal(max(1, len(dates))) return Decimal("1") if claim.expense_type in TRAVEL_EXPENSE_TYPES else Decimal("0") def _entertainment_unit_amount(self, claim: ExpenseClaim) -> Decimal: if claim.expense_type not in ENTERTAINMENT_EXPENSE_TYPES: return Decimal("0") attendee_count = self._extract_attendee_count(claim) if attendee_count <= 0: return Decimal("0") return self._decimal(claim.amount) / Decimal(attendee_count) def _extract_attendee_count(self, claim: ExpenseClaim) -> int: text = " ".join( [claim.reason or "", *(item.item_reason or "" for item in claim.items or [])] ) for token in ("人", "位"): parts = text.split(token) for part in parts: digits = "".join(ch for ch in part[-3:] if ch.isdigit()) if digits: return max(1, int(digits)) return 0 def _resolve_scope_from_claim(self, claim_id: str | None, expense_type_scope: str) -> str: normalized = str(expense_type_scope or "overall").strip() or "overall" if normalized != "overall" or not claim_id: return normalized claim = self.db.get(ExpenseClaim, claim_id) return str(claim.expense_type or "overall").strip() if claim is not None else normalized def _is_claim_in_scope(self, claim: ExpenseClaim, expense_type_scope: str) -> bool: scope = str(expense_type_scope or "overall").strip() if scope == "overall": return True if scope == "entertainment": return claim.expense_type in ENTERTAINMENT_EXPENSE_TYPES if scope == "travel": return claim.expense_type in TRAVEL_EXPENSE_TYPES return claim.expense_type == scope def _common_metrics(self, context: dict[str, Any]) -> dict[str, Any]: return { "window_days": context["window_days"], "expense_type_scope": context["expense_type_scope"], "peer_group_key": context["peer_group_key"], "peer_group_fallback_level": context["peer_group_fallback_level"], "peer_sample_size": context["peer_sample_size"], "algorithm_version": ALGORITHM_VERSION, } def _estimate_tokens(self, runs: list[AgentRun]) -> int: total = 0 for run in runs: payload = { "ontology": run.ontology_json, "route": run.route_json, "summary": run.result_summary, "error": run.error_message, "tools": [ { "request": tool.request_json, "response": tool.response_json, "error": tool.error_message, } for tool in run.tool_calls ], } text = json.dumps(payload, ensure_ascii=False, default=str) total += max(0, len(text) // 4) return total @staticmethod def _is_missing_value(value: Any) -> bool: text = str(value or "").strip() return not text or text in {"待补充", "暂无", "无", "未知"} @staticmethod def _decimal(value: Any) -> Decimal: try: return Decimal(str(value or "0")) except Exception: return Decimal("0") @staticmethod def _format_decimal(value: Any) -> str: try: return str(Decimal(str(value or "0")).quantize(Decimal("0.0001")).normalize()) except Exception: return "0"