# File: X-Financial/server/src/app/services/employee_behavior_profile_helpers.py
# (Python, 208 lines, 7.8 KiB as listed by the repository file viewer)

from __future__ import annotations

import json
import re
from collections import defaultdict
from decimal import Decimal
from typing import Any

from app.algorithem.employee_behavior_profile import ALGORITHM_VERSION
from app.models.agent_run import AgentRun
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
# Expense types treated as travel spending: used for the "travel" scope,
# travel-day counting, and the item-location completeness check.
# A frozenset so the shared lookup table cannot be mutated by importers.
TRAVEL_EXPENSE_TYPES: frozenset[str] = frozenset(
    {
        "travel",
        "train_ticket",
        "flight_ticket",
        "hotel_ticket",
        "ride_ticket",
        "travel_allowance",
    }
)
# Expense types treated as entertainment spending: used for the
# "entertainment" scope and per-attendee unit amounts.
ENTERTAINMENT_EXPENSE_TYPES: frozenset[str] = frozenset({"meal", "entertainment"})
class EmployeeBehaviorProfileMetricHelpers:
    """Metric helpers used when building employee behavior profiles.

    The class defines no ``__init__``. ``_resolve_scope_from_claim`` reads
    ``self.db``, which the host class must provide.
    NOTE(review): ``self.db.get(Model, pk)`` looks like a SQLAlchemy
    ``Session`` -- confirm against the composing service.
    """

    # Claim statuses that count as one return / send-back event.
    _RETURN_STATUSES = frozenset({"returned", "supplement", "rejected"})
    # Upper bound applied to a run's wall-clock duration: 24 hours in ms.
    _MAX_RUN_DURATION_MS = 24 * 60 * 60 * 1000
    # Suffixes that follow an attendee headcount in claim reasons (e.g. "5人").
    # NOTE(review): these literals arrived as empty strings, and an empty
    # separator makes ``str.split`` raise ValueError. "人"/"位" are the assumed
    # intended Chinese headcount suffixes -- confirm.
    _ATTENDEE_COUNT_SUFFIXES = ("人", "位")

    def _sum_amount_by_employee(self, claims: list[ExpenseClaim]) -> dict[str, Decimal]:
        """Return the total ``claim.amount`` per employee key."""
        grouped: dict[str, Decimal] = defaultdict(Decimal)
        for claim in claims:
            grouped[self._claim_employee_key(claim)] += self._decimal(claim.amount)
        return dict(grouped)

    def _count_by_employee(self, claims: list[ExpenseClaim]) -> dict[str, int]:
        """Return the number of claims per employee key."""
        grouped: dict[str, int] = defaultdict(int)
        for claim in claims:
            grouped[self._claim_employee_key(claim)] += 1
        return dict(grouped)

    def _return_count_by_employee(self, claims: list[ExpenseClaim]) -> dict[str, int]:
        """Return return-event counts per employee key (see ``_return_count``).

        Every employee that has at least one claim gets an entry, even when
        the count is 0.
        """
        grouped: dict[str, int] = defaultdict(int)
        for claim in claims:
            grouped[self._claim_employee_key(claim)] += self._return_count([claim])
        return dict(grouped)

    def _claim_employee_key(self, claim: ExpenseClaim) -> str:
        """Return the key used to group ``claim`` by employee.

        The key is the first non-blank (after stripping) of ``employee_id``
        and ``employee_name``, falling back to ``"unknown"``.
        """
        for candidate in (claim.employee_id, claim.employee_name):
            text = str(candidate or "").strip()
            if text:
                return text
        return "unknown"

    def _employee_identifiers(self, employee: Employee) -> set[str]:
        """Return the non-blank identifiers of ``employee`` as stripped strings.

        Covers ``id``, ``employee_no``, ``email`` and ``name``. Values are
        normalized like ``_claim_employee_key`` so the two compare directly.
        """
        candidates = (employee.id, employee.employee_no, employee.email, employee.name)
        normalized = (str(item or "").strip() for item in candidates)
        return {text for text in normalized if text}

    def _return_count(self, claims: list[ExpenseClaim]) -> int:
        """Count return events across ``claims``.

        A claim contributes one event when its (case-insensitive) status is
        returned / supplement / rejected. It also contributes one per entry in
        ``risk_flags_json`` that is a dict with ``source == "manual_return"``.
        Both can apply to the same claim.
        """
        count = 0
        for claim in claims:
            if str(claim.status or "").strip().lower() in self._RETURN_STATUSES:
                count += 1
            for flag in claim.risk_flags_json or []:
                if isinstance(flag, dict) and str(flag.get("source") or "") == "manual_return":
                    count += 1
        return count

    def _missing_attachment_count(self, claim: ExpenseClaim) -> int:
        """Count missing attachments on ``claim``.

        With line items, each item lacking a non-blank ``invoice_id`` counts
        once. Without items, the claim counts as 1 when ``invoice_count`` is
        missing or not positive, else 0.
        """
        if not claim.items:
            return int((claim.invoice_count or 0) <= 0)
        return sum(1 for item in claim.items if not str(item.invoice_id or "").strip())

    def _has_amount_mismatch(self, claim: ExpenseClaim) -> bool:
        """Return True when item amounts differ from ``claim.amount`` by > 0.01.

        Claims without line items are never reported as mismatched.
        """
        if not claim.items:
            return False
        item_total = sum((self._decimal(item.item_amount) for item in claim.items), Decimal("0"))
        return abs(item_total - self._decimal(claim.amount)) > Decimal("0.01")

    def _missing_context_count(self, claim: ExpenseClaim) -> int:
        """Count missing descriptive fields on ``claim`` and its items.

        Checks the claim's reason, location and project code, each item's
        reason, and the location of travel-type items (see ``_is_missing_value``).
        """
        missing = 0
        for value in (claim.reason, claim.location, claim.project_code):
            if self._is_missing_value(value):
                missing += 1
        for item in claim.items or []:
            if self._is_missing_value(item.item_reason):
                missing += 1
            if item.item_type in TRAVEL_EXPENSE_TYPES and self._is_missing_value(
                item.item_location
            ):
                missing += 1
        return missing

    def _claim_travel_days(self, claim: ExpenseClaim | None) -> Decimal:
        """Estimate the travel days covered by ``claim``.

        The estimate is the number of distinct dates on travel-type items.
        When there are no such dates, it is 1 for a travel-type claim and 0
        otherwise. Returns 0 for ``None``.
        """
        if claim is None:
            return Decimal("0")
        travel_dates = {
            item.item_date
            for item in claim.items or []
            if item.item_type in TRAVEL_EXPENSE_TYPES and item.item_date is not None
        }
        if travel_dates:
            return Decimal(len(travel_dates))
        return Decimal("1") if claim.expense_type in TRAVEL_EXPENSE_TYPES else Decimal("0")

    def _entertainment_unit_amount(self, claim: ExpenseClaim) -> Decimal:
        """Return the per-attendee amount of an entertainment claim.

        Returns 0 for non-entertainment claims and when no attendee count can
        be parsed from the reasons.
        """
        if claim.expense_type not in ENTERTAINMENT_EXPENSE_TYPES:
            return Decimal("0")
        attendee_count = self._extract_attendee_count(claim)
        if attendee_count <= 0:
            return Decimal("0")
        return self._decimal(claim.amount) / Decimal(attendee_count)

    def _extract_attendee_count(self, claim: ExpenseClaim) -> int:
        """Parse an attendee headcount from the claim and item reasons.

        Looks for 1-3 digits directly before one of ``_ATTENDEE_COUNT_SUFFIXES``
        (whitespace in between is allowed). Suffixes are tried in order and the
        first occurrence wins. The result is at least 1 when a match is found,
        otherwise 0. Digits that are not followed by a suffix (e.g. an amount
        like "100") are ignored.
        """
        text = " ".join(
            [claim.reason or "", *(item.item_reason or "" for item in claim.items or [])]
        )
        for suffix in self._ATTENDEE_COUNT_SUFFIXES:
            if not suffix:
                # An empty suffix cannot mark a headcount; skip it.
                continue
            match = re.search(r"(\d{1,3})\s*" + re.escape(suffix), text)
            if match:
                # \d only matches decimal digits, which int() always accepts.
                return max(1, int(match.group(1)))
        return 0

    def _resolve_scope_from_claim(self, claim_id: str | None, expense_type_scope: str) -> str:
        """Resolve the effective expense-type scope.

        A blank scope is treated as ``"overall"``. An explicit non-overall
        scope is returned as-is. For ``"overall"`` with a ``claim_id``, the
        claim is loaded via ``self.db`` and its ``expense_type`` becomes the
        scope (``"overall"`` when that type is blank). If the claim is not
        found, ``"overall"`` is kept.
        """
        normalized = str(expense_type_scope or "overall").strip() or "overall"
        if normalized != "overall" or not claim_id:
            return normalized
        claim = self.db.get(ExpenseClaim, claim_id)
        if claim is None:
            return normalized
        return str(claim.expense_type or "overall").strip() or "overall"

    def _is_claim_in_scope(self, claim: ExpenseClaim, expense_type_scope: str) -> bool:
        """Return True when ``claim`` belongs to ``expense_type_scope``.

        ``"overall"`` (or a blank scope) matches every claim.
        ``"entertainment"`` and ``"travel"`` match their expense-type groups.
        Any other scope must equal ``claim.expense_type`` exactly.
        """
        scope = str(expense_type_scope or "overall").strip() or "overall"
        if scope == "overall":
            return True
        if scope == "entertainment":
            return claim.expense_type in ENTERTAINMENT_EXPENSE_TYPES
        if scope == "travel":
            return claim.expense_type in TRAVEL_EXPENSE_TYPES
        return claim.expense_type == scope

    def _common_metrics(self, context: dict[str, Any]) -> dict[str, Any]:
        """Return the metric fields shared by every profile metric.

        Copies window, scope and peer-group fields from ``context`` (raising
        KeyError if any is absent) and adds ``ALGORITHM_VERSION``.
        """
        return {
            "window_days": context["window_days"],
            "expense_type_scope": context["expense_type_scope"],
            "peer_group_key": context["peer_group_key"],
            "peer_group_fallback_level": context["peer_group_fallback_level"],
            "peer_sample_size": context["peer_sample_size"],
            "algorithm_version": ALGORITHM_VERSION,
        }

    def _estimate_tokens(self, runs: list[AgentRun]) -> int:
        """Roughly estimate the tokens consumed by ``runs``.

        Each run's ontology, route, summary, error and tool-call
        request/response/error payloads are serialized to JSON (non-JSON
        values via ``str``). Each run contributes ``len(json) // 4``.
        """
        total = 0
        for run in runs:
            payload = {
                "ontology": run.ontology_json,
                "route": run.route_json,
                "summary": run.result_summary,
                "error": run.error_message,
                "tools": [
                    {
                        "request": tool.request_json,
                        "response": tool.response_json,
                        "error": tool.error_message,
                    }
                    for tool in run.tool_calls or []
                ],
            }
            text = json.dumps(payload, ensure_ascii=False, default=str)
            total += len(text) // 4
        return total

    def _sum_agent_run_duration_ms(self, runs: list[AgentRun]) -> int:
        """Return the summed duration of ``runs`` in milliseconds."""
        return sum(self._agent_run_duration_ms(run) for run in runs)

    def _agent_run_duration_ms(self, run: AgentRun) -> int:
        """Return the duration of ``run`` in milliseconds.

        Uses ``finished_at - started_at`` (capped at 24 hours) when both are
        set and ``finished_at`` is later. Otherwise it falls back to the sum of
        non-negative tool-call ``duration_ms`` values.
        """
        if run.started_at is not None and run.finished_at is not None:
            try:
                if run.finished_at > run.started_at:
                    elapsed_ms = int((run.finished_at - run.started_at).total_seconds() * 1000)
                    return min(elapsed_ms, self._MAX_RUN_DURATION_MS)
            except TypeError:
                # e.g. comparing naive and timezone-aware datetimes; fall back
                # to the tool-call durations below.
                pass
        return sum(max(0, int(tool.duration_ms or 0)) for tool in run.tool_calls or [])

    @staticmethod
    def _is_missing_value(value: Any) -> bool:
        """Return True when ``value`` is blank or a known placeholder text."""
        text = str(value or "").strip()
        # NOTE(review): the "" entry is redundant with `not text` and may be a
        # placeholder character lost in transcription -- confirm.
        return not text or text in {"待补充", "暂无", "", "未知"}

    @staticmethod
    def _decimal(value: Any) -> Decimal:
        """Convert ``value`` to a finite Decimal, using 0 when that is impossible.

        Falsy, unparseable, NaN and infinite inputs yield ``Decimal("0")``.
        NaN would otherwise make later ordering comparisons raise
        ``InvalidOperation``.
        """
        try:
            result = Decimal(str(value or "0"))
        except Exception:
            # Best-effort conversion: any unconvertible input counts as zero.
            return Decimal("0")
        return result if result.is_finite() else Decimal("0")

    @staticmethod
    def _format_decimal(value: Any) -> str:
        """Format ``value`` rounded to 4 decimal places, without trailing zeros.

        Always uses positional notation (``"100"``, never ``"1E+2"``). Zero
        (including negative zero) and unparseable or non-finite values yield
        ``"0"``, as do values too large to quantize in the current decimal
        context.
        """
        try:
            quantized = Decimal(str(value or "0")).quantize(Decimal("0.0001"))
        except Exception:
            return "0"
        if not quantized.is_finite() or quantized.is_zero():
            return "0"
        # normalize() drops trailing zeros but may switch to exponent form
        # (Decimal("100.0000").normalize() == Decimal("1E+2")); the "f" format
        # renders the result back in plain positional notation.
        return format(quantized.normalize(), "f")