Files
X-Financial/server/src/app/services/expense_claim_policy_review.py
2026-06-15 20:53:48 +08:00

733 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import re
from collections import defaultdict
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal
from types import SimpleNamespace
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy import inspect as sqlalchemy_inspect
from app.api.deps import CurrentUserContext
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus, AgentAssetType
from app.models.agent_asset import AgentAsset
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.reimbursement import TravelReimbursementCalculatorRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
from app.services.expense_claim_constants import (
AI_REVIEW_LOOKBACK_DAYS,
AI_REVIEW_REPEAT_RISK_BLOCK_COUNT,
AI_REVIEW_REPEAT_RISK_WARNING_COUNT,
DOCUMENT_FACT_ITEM_TYPES,
LOCATION_REQUIRED_EXPENSE_TYPES,
SYSTEM_GENERATED_ITEM_TYPES,
TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES,
TRAVEL_POLICY_HOTEL_NIGHT_PATTERN,
)
from app.services.expense_claim_risk_stage import with_risk_business_stage
from app.services.expense_rule_runtime import (
ExpenseRuleRuntimeService,
RuntimeTravelPolicy,
build_default_expense_rule_catalog,
)
from app.services.travel_policy_grades import (
resolve_travel_policy_grade_key,
travel_policy_grade_key_candidates,
)
class ExpenseClaimPolicyReviewMixin:
def _run_scene_policy_review(self, claim: ExpenseClaim) -> dict[str, list[Any]]:
catalog = self._get_expense_rule_catalog()
flags: list[dict[str, Any]] = []
blocking_reasons: list[str] = []
reason_corpus = self._build_scene_reason_corpus(claim)
scene_totals: dict[str, Decimal] = defaultdict(lambda: Decimal("0.00"))
scene_warned: set[str] = set()
for item in claim.items:
item_type = str(item.item_type or claim.expense_type or "other").strip().lower() or "other"
policy = catalog.get_scene_policy(item_type)
if policy is None:
continue
scene_totals[item_type] += Decimal(item.item_amount or Decimal("0.00")).quantize(Decimal("0.01"))
if policy.always_warn and item_type not in scene_warned:
scene_warned.add(item_type)
flags.append(
{
"source": "submission_review",
"severity": "medium",
"label": f"{policy.label}人工重点复核",
"message": policy.always_warn_message or f"{policy.label}默认需要人工重点复核。",
"rule_code": policy.rule_code,
}
)
item_limit = policy.item_amount_limit
item_amount = Decimal(item.item_amount or Decimal("0.00")).quantize(Decimal("0.01"))
if item_limit is not None and item_amount > Decimal("0.00"):
exceeded = self._evaluate_amount_limit(
amount=item_amount,
limit_config=item_limit,
reason_text="\n".join(
part
for part in [
reason_corpus,
str(item.item_reason or "").strip(),
str(item.item_note or "").strip(),
]
if part
),
)
if exceeded is not None:
severity, threshold = exceeded
label = (
f"{policy.label}金额超标待说明"
if severity == "high"
else f"{policy.label}金额超标提醒"
)
message = (
f"{policy.label}当前识别金额为 {item_amount} 元,"
f"已超过制度阈值 {threshold} 元。"
)
if severity == "high":
message += " 当前未识别到例外说明,请先补充原因。"
blocking_reasons.append(f"{policy.label}金额超出制度阈值,且未补充例外说明。")
else:
message += " 已识别到例外说明,请审批人重点复核。"
flags.append(
{
"source": "submission_review",
"severity": severity,
"label": label,
"message": message,
"rule_code": policy.rule_code,
}
)
for scene_code, total_amount in scene_totals.items():
policy = catalog.get_scene_policy(scene_code)
if policy is None or policy.claim_amount_limit is None or total_amount <= Decimal("0.00"):
continue
exceeded = self._evaluate_amount_limit(
amount=total_amount,
limit_config=policy.claim_amount_limit,
reason_text=reason_corpus,
)
if exceeded is None:
continue
severity, threshold = exceeded
label = f"{policy.label}合计超标待说明" if severity == "high" else f"{policy.label}合计超标提醒"
message = (
f"{policy.label}当前合计金额为 {total_amount} 元,"
f"已超过制度阈值 {threshold} 元。"
)
if severity == "high":
message += " 当前未识别到例外说明,请先补充原因。"
blocking_reasons.append(f"{policy.label}合计金额超出制度阈值,且未补充例外说明。")
else:
message += " 已识别到例外说明,请审批人重点复核。"
flags.append(
{
"source": "submission_review",
"severity": severity,
"label": label,
"message": message,
"rule_code": policy.rule_code,
}
)
return {
"flags": [with_risk_business_stage(flag, "reimbursement") for flag in flags],
"blocking_reasons": list(dict.fromkeys(reason for reason in blocking_reasons if reason)),
}
def _evaluate_amount_limit(
self,
*,
amount: Decimal,
limit_config: Any,
reason_text: str,
) -> tuple[str, Decimal] | None:
block_amount = getattr(limit_config, "block_amount", None)
warn_amount = getattr(limit_config, "warn_amount", None)
exception_keywords = list(getattr(limit_config, "exception_keywords", []) or [])
has_exception = self._text_contains_keywords(reason_text, exception_keywords)
if block_amount is not None and amount > Decimal(block_amount):
return ("medium" if has_exception else "high", Decimal(block_amount))
if warn_amount is not None and amount > Decimal(warn_amount):
return ("medium", Decimal(warn_amount))
return None
def _run_travel_policy_review(self, claim: ExpenseClaim) -> dict[str, list[Any]]:
policy = self._get_expense_rule_catalog().travel_policy
if policy is None:
return {"flags": [], "blocking_reasons": []}
contexts = [
context
for context in self._build_claim_attachment_contexts(claim)
if self._is_travel_policy_relevant_context(context, policy)
]
if not contexts:
return {"flags": [], "blocking_reasons": []}
reason_corpus = self._build_travel_reason_corpus(claim)
has_route_exception = self._text_contains_keywords(
reason_corpus,
policy.route_exception_keywords,
)
has_standard_exception = self._text_contains_keywords(
reason_corpus,
policy.standard_exception_keywords,
)
grade_band = self._resolve_travel_policy_band(claim.employee_grade)
band_label = policy.band_labels.get(grade_band or "", str(claim.employee_grade or "").strip() or "当前职级")
itinerary_segments: list[dict[str, Any]] = []
itinerary_cities: list[str] = []
hotel_contexts: list[dict[str, Any]] = []
flags: list[dict[str, Any]] = []
blocking_reasons: list[str] = []
for context in contexts:
route_segment = self._extract_route_segment(context, policy)
if route_segment and self._is_long_distance_travel_context(context, policy):
itinerary_segments.append(
{
"item": context["item"],
"origin": route_segment[0],
"destination": route_segment[1],
}
)
itinerary_cities.extend([route_segment[0], route_segment[1]])
scene_code = str(context["document_info"].get("scene_code") or "").strip().lower()
document_type = str(context["document_info"].get("document_type") or "").strip().lower()
item_type = str(context["item"].item_type or "").strip().lower()
if "hotel" in {scene_code, document_type, item_type} or document_type == "hotel_invoice":
hotel_contexts.append(context)
unique_itinerary_cities = list(dict.fromkeys(city for city in itinerary_cities if city))
expected_destination_city = self._resolve_expected_travel_city(
claim,
contexts,
unique_itinerary_cities,
policy,
)
if itinerary_segments:
unique_destinations = list(
dict.fromkeys(segment["destination"] for segment in itinerary_segments if segment["destination"])
)
first_origin = str(itinerary_segments[0]["origin"] or "").strip()
last_destination = str(itinerary_segments[-1]["destination"] or "").strip()
for previous, current in zip(itinerary_segments, itinerary_segments[1:]):
previous_destination = str(previous["destination"] or "").strip()
current_origin = str(current["origin"] or "").strip()
if previous_destination and current_origin and previous_destination != current_origin:
message = (
f"差旅行程未形成连续链路:上一段到达 {previous_destination}"
f"下一段却从 {current_origin} 出发,请补充中转或改签说明。"
)
flags.append(
self._with_related_item_ids(
{
"source": "submission_review",
"severity": "high",
"label": "行程闭环异常",
"message": message,
"rule_code": policy.rule_code,
},
self._itinerary_segment_item_ids([previous, current]),
)
)
blocking_reasons.append("差旅行程未形成连续闭环,请补充中转、改签或异地出发原因。")
break
if (
expected_destination_city
and last_destination
and last_destination not in {expected_destination_city, first_origin}
):
message = (
f"差旅行程终点识别为 {last_destination}"
f"与申报目的地 {expected_destination_city} 不一致,请补充多地出差或后续行程说明。"
)
flags.append(
self._with_related_item_ids(
{
"source": "submission_review",
"severity": "high",
"label": "行程终点异常",
"message": message,
"rule_code": policy.rule_code,
},
self._itinerary_segment_item_ids(itinerary_segments),
)
)
blocking_reasons.append("差旅行程终点与申报目的地不一致,请补充多地出差说明或补齐后续票据。")
expected_city_set = {
city
for city in (expected_destination_city, first_origin)
if city
}
extra_destinations = [
city
for city in unique_destinations
if city and city not in expected_city_set
]
if extra_destinations and not has_route_exception:
destinations_text = "".join(extra_destinations[:3])
affected_segments = [
segment
for segment in itinerary_segments
if segment["origin"] in extra_destinations or segment["destination"] in extra_destinations
]
flags.append(
self._with_related_item_ids(
{
"source": "submission_review",
"severity": "high",
"label": "多城市行程待说明",
"message": (
f"检测到本次差旅涉及 {destinations_text} 多个目的地,"
"但当前报销事由未说明中转、多地拜访或改签原因。"
),
"rule_code": policy.rule_code,
},
self._itinerary_segment_item_ids(affected_segments or itinerary_segments),
)
)
blocking_reasons.append("检测到多城市差旅行程,但当前未补充中转或多地出差说明。")
allowed_hotel_cities = {
city
for city in [expected_destination_city, *unique_itinerary_cities]
if city
}
for context in hotel_contexts:
hotel_city = self._extract_hotel_city(context, policy)
if hotel_city and allowed_hotel_cities and hotel_city not in allowed_hotel_cities:
expected_text = "".join(sorted(allowed_hotel_cities))
flags.append(
self._with_related_item_ids(
{
"source": "submission_review",
"severity": "high",
"label": "酒店地点异常",
"message": (
f"酒店票据识别城市为 {hotel_city}"
f"与当前差旅目的地/行程城市 {expected_text} 不一致,请补充异地住宿原因。"
),
"rule_code": policy.rule_code,
},
[self._context_item_id(context)],
)
)
blocking_reasons.append("酒店票据地点与差旅目的地不一致,请补充异地住宿原因或更换附件。")
if grade_band is None:
continue
baseline_city = hotel_city or expected_destination_city
standard = self._resolve_travel_policy_hotel_standard(
policy=policy,
grade_band=grade_band,
city=baseline_city,
)
if standard is None:
continue
cap, standard_label = standard
night_count = self._extract_hotel_night_count(context)
item_amount = Decimal(context["item"].item_amount or Decimal("0.00")).quantize(Decimal("0.01"))
nightly_amount = (item_amount / Decimal(max(night_count, 1))).quantize(Decimal("0.01"))
if nightly_amount <= cap:
continue
hotel_message = (
f"{band_label} 职级在{standard_label}的住宿标准为 {cap} 元/晚,"
f"当前酒店识别金额约 {nightly_amount} 元/晚。"
)
item_reason = " ".join(
[
str(context["item"].item_reason or "").strip(),
str(context["item"].item_note or "").strip(),
]
).strip()
item_has_exception = self._text_contains_keywords(item_reason, policy.standard_exception_keywords)
if has_standard_exception or item_has_exception:
flags.append(
self._with_related_item_ids(
{
"source": "submission_review",
"severity": "medium",
"label": "住宿超标提醒",
"message": hotel_message + " 已识别到补充说明,请直属领导重点复核。",
"rule_code": policy.rule_code,
},
[self._context_item_id(context)],
)
)
else:
flags.append(
self._with_related_item_ids(
{
"source": "submission_review",
"severity": "high",
"label": "住宿超标待说明",
"message": hotel_message + " 当前未识别到超标说明,请先补充原因。",
"rule_code": policy.rule_code,
},
[self._context_item_id(context)],
)
)
blocking_reasons.append("住宿金额超出当前职级差标,且未补充超标说明。")
if grade_band is not None:
for context in contexts:
transport_class = self._detect_transport_class(context, policy)
if transport_class is None:
continue
transport_kind, class_label, class_level = transport_class
allowed_level = self._resolve_travel_policy_transport_level(
policy,
grade_band=grade_band,
transport_kind=transport_kind,
)
if allowed_level is None or class_level <= allowed_level:
continue
item_reason = " ".join(
[
str(context["item"].item_reason or "").strip(),
str(context["item"].item_note or "").strip(),
]
).strip()
item_has_exception = self._text_contains_keywords(item_reason, policy.standard_exception_keywords)
message = f"{band_label} 职级当前默认不可报销 {class_label}"
if has_standard_exception or item_has_exception:
flags.append(
self._with_related_item_ids(
{
"source": "submission_review",
"severity": "medium",
"label": "交通舱位超标提醒",
"message": message + " 已识别到补充说明,请审批人重点复核。",
"rule_code": policy.rule_code,
},
[self._context_item_id(context)],
)
)
else:
flags.append(
self._with_related_item_ids(
{
"source": "submission_review",
"severity": "high",
"label": "交通舱位超标待说明",
"message": message + " 当前未识别到例外说明,请先补充原因。",
"rule_code": policy.rule_code,
},
[self._context_item_id(context)],
)
)
blocking_reasons.append("交通舱位或席别超出当前职级差标,且未补充例外说明。")
return {
"flags": [with_risk_business_stage(flag, "reimbursement") for flag in flags],
"blocking_reasons": list(dict.fromkeys(reason for reason in blocking_reasons if reason)),
}
def _build_claim_attachment_contexts(self, claim: ExpenseClaim) -> list[dict[str, Any]]:
contexts: list[dict[str, Any]] = []
ordered_items = sorted(
claim.items,
key=lambda item: (
item.item_date or date.max,
self._normalize_sort_datetime(item.created_at),
),
)
for index, item in enumerate(ordered_items, start=1):
file_path = self._attachment_storage.resolve_path(item.invoice_id)
if file_path is None or not file_path.exists():
continue
metadata = self._attachment_storage.read_meta(file_path)
document_info = metadata.get("document_info")
contexts.append(
{
"index": index,
"item": item,
"document_info": document_info if isinstance(document_info, dict) else {},
"ocr_text": str(metadata.get("ocr_text") or ""),
"ocr_summary": str(metadata.get("ocr_summary") or ""),
}
)
return contexts
@staticmethod
def _context_item_id(context: dict[str, Any]) -> str:
item = context.get("item") if isinstance(context, dict) else None
return str(getattr(item, "id", "") or "").strip()
@classmethod
def _itinerary_segment_item_ids(cls, segments: list[dict[str, Any]]) -> list[str]:
item_ids: list[str] = []
seen: set[str] = set()
for segment in list(segments or []):
item = segment.get("item") if isinstance(segment, dict) else None
item_id = str(getattr(item, "id", "") or "").strip()
if item_id and item_id not in seen:
seen.add(item_id)
item_ids.append(item_id)
return item_ids
@staticmethod
def _with_related_item_ids(flag: dict[str, Any], item_ids: list[str]) -> dict[str, Any]:
normalized_item_ids = list(
dict.fromkeys(str(item_id or "").strip() for item_id in list(item_ids or []) if str(item_id or "").strip())
)
if not normalized_item_ids:
return flag
flag["item_ids"] = normalized_item_ids
if len(normalized_item_ids) == 1:
flag["item_id"] = normalized_item_ids[0]
return flag
def _is_travel_policy_relevant_context(
self,
context: dict[str, Any],
policy: RuntimeTravelPolicy,
) -> bool:
item = context.get("item")
document_info = context.get("document_info") or {}
item_type = str(getattr(item, "item_type", "") or "").strip().lower()
scene_code = str(document_info.get("scene_code") or "").strip().lower()
document_type = str(document_info.get("document_type") or "").strip().lower()
return (
item_type in set(policy.relevant_expense_types)
or scene_code in set(policy.relevant_expense_types)
or document_type in {"hotel_invoice", *set(policy.long_distance_document_types)}
)
@staticmethod
def _resolve_document_field_value(document_info: dict[str, Any], key: str) -> str:
normalized_key = str(key or "").strip().lower()
for field in list(document_info.get("fields") or []):
if not isinstance(field, dict):
continue
field_key = str(field.get("key") or "").strip().lower()
if field_key == normalized_key:
return str(field.get("value") or "").strip()
return ""
@staticmethod
def _text_contains_keywords(text: str, keywords: tuple[str, ...] | list[str]) -> bool:
compact = re.sub(r"\s+", "", str(text or ""))
if not compact:
return False
return any(keyword in compact for keyword in keywords)
def _build_travel_reason_corpus(self, claim: ExpenseClaim) -> str:
parts = [str(claim.reason or "").strip(), str(claim.location or "").strip()]
for item in claim.items:
parts.append(str(item.item_reason or "").strip())
parts.append(str(item.item_note or "").strip())
parts.append(str(item.item_location or "").strip())
return "\n".join(part for part in parts if part)
@staticmethod
def _resolve_travel_policy_band(grade: str | None) -> str | None:
    """Map an employee grade to its travel-policy grade key.

    This is a thin delegate to ``resolve_travel_policy_grade_key`` and returns
    its result unchanged. Callers in this mixin treat ``None`` as "no band".
    """
    band_key = resolve_travel_policy_grade_key(grade)
    return band_key
@staticmethod
def _resolve_travel_policy_transport_level(
    policy: RuntimeTravelPolicy,
    *,
    grade_band: str,
    transport_kind: str,
) -> int | None:
    """Return the highest allowed class level for ``transport_kind`` at ``grade_band``.

    ``policy.transport_limits`` is read as ``{grade_key: {transport_kind:
    level}}``. Grade keys are tried in the order given by
    ``travel_policy_grade_key_candidates(grade_band)``, and the first
    non-``None`` level wins. Returns ``None`` when no candidate has a limit.
    """
    candidate_levels = (
        (getattr(policy, "transport_limits", {}) or {}).get(candidate, {}).get(transport_kind)
        for candidate in travel_policy_grade_key_candidates(grade_band)
    )
    return next((level for level in candidate_levels if level is not None), None)
def _resolve_expected_travel_city(
self,
claim: ExpenseClaim,
contexts: list[dict[str, Any]],
itinerary_cities: list[str],
policy: RuntimeTravelPolicy,
) -> str:
claim_city = self._extract_city_from_text(str(claim.location or ""), policy)
if claim_city:
return claim_city
for context in contexts:
hotel_city = self._extract_hotel_city(context, policy)
if hotel_city:
return hotel_city
if len(itinerary_cities) >= 2 and itinerary_cities[1]:
return itinerary_cities[1]
for city in itinerary_cities:
if city:
return city
return ""
def _extract_route_segment(
self,
context: dict[str, Any],
policy: RuntimeTravelPolicy,
) -> tuple[str, str] | None:
document_info = context["document_info"]
route_value = self._resolve_document_field_value(document_info, "route")
if not route_value or "-" not in route_value:
return None
origin_text, destination_text = [segment.strip() for segment in route_value.split("-", 1)]
origin_city = self._extract_city_from_text(origin_text, policy)
destination_city = self._extract_city_from_text(destination_text, policy)
if not origin_city or not destination_city or origin_city == destination_city:
return None
return origin_city, destination_city
def _extract_hotel_city(self, context: dict[str, Any], policy: RuntimeTravelPolicy) -> str:
document_info = context["document_info"]
item = context["item"]
merchant_name = self._resolve_document_field_value(document_info, "merchant_name")
for candidate in (
merchant_name,
str(item.item_location or ""),
str(context.get("ocr_summary") or ""),
str(context.get("ocr_text") or ""),
):
city = self._extract_city_from_text(candidate, policy)
if city:
return city
return ""
@staticmethod
def _format_travel_policy_city_tier(city_tier: str) -> str:
return {
"tier_1": "一线城市",
"tier_2": "重点城市",
"tier_3": "其他城市",
}.get(str(city_tier or "").strip(), "当前城市")
def _resolve_travel_policy_hotel_standard(
    self,
    *,
    policy: RuntimeTravelPolicy,
    grade_band: str,
    city: str,
) -> tuple[Decimal, str] | None:
    """Find the nightly hotel cap for ``grade_band`` in ``city``.

    Lookup order:

    1. ``policy.hotel_city_limits[city][grade_key]`` for the first grade-key
       candidate with a non-``None`` value; labelled with the city name.
    2. ``policy.hotel_limits[grade_key][tier]``, where ``tier`` comes from
       ``policy.city_tiers`` (default ``"tier_3"``); labelled with the tier's
       display name.

    Grade-key candidates come from ``travel_policy_grade_key_candidates``.
    Caps are quantized to two decimals. Returns ``None`` when neither table
    has an entry.
    """
    cent = Decimal("0.01")
    normalized_city = str(city or "").strip()
    per_city_limits = getattr(policy, "hotel_city_limits", {}) or {}
    city_caps = per_city_limits.get(normalized_city) if normalized_city else None
    if city_caps:
        for grade_key in travel_policy_grade_key_candidates(grade_band):
            if city_caps.get(grade_key) is not None:
                return Decimal(city_caps[grade_key]).quantize(cent), normalized_city
    tier = (getattr(policy, "city_tiers", {}) or {}).get(normalized_city, "tier_3")
    for grade_key in travel_policy_grade_key_candidates(grade_band):
        tier_cap = (getattr(policy, "hotel_limits", {}) or {}).get(grade_key, {}).get(tier)
        if tier_cap is not None:
            return Decimal(tier_cap).quantize(cent), self._format_travel_policy_city_tier(tier)
    return None
@staticmethod
def _extract_city_from_text(text: str, policy: RuntimeTravelPolicy) -> str:
normalized = str(text or "").strip()
if not normalized:
return ""
city_names = set(policy.city_tiers.keys())
city_names.update((getattr(policy, "hotel_city_limits", {}) or {}).keys())
city_match_order = sorted(city_names, key=lambda item: len(item), reverse=True)
for city in city_match_order:
if city in normalized:
return city
return ""
@staticmethod
def _extract_hotel_night_count(context: dict[str, Any]) -> int:
    """Read a night count from the context's OCR summary and text.

    Searches ``ocr_summary + " " + ocr_text`` with
    ``TRAVEL_POLICY_HOTEL_NIGHT_PATTERN`` and returns its first capture group
    as an int, floored at 1. The group is presumably the number of nights;
    confirm against the pattern definition. Returns 1 when there is no match
    or the group is not int-convertible.
    """
    combined = " ".join(
        str(context.get(key) or "").strip() for key in ("ocr_summary", "ocr_text")
    ).strip()
    found = TRAVEL_POLICY_HOTEL_NIGHT_PATTERN.search(combined)
    if found is None:
        return 1
    try:
        nights = int(found.group(1))
    except (TypeError, ValueError):
        return 1
    return max(1, nights)
def _detect_transport_class(
self,
context: dict[str, Any],
policy: RuntimeTravelPolicy,
) -> tuple[str, str, int] | None:
document_info = context["document_info"]
document_type = str(document_info.get("document_type") or "").strip().lower()
text = " ".join(
[
str(context.get("ocr_summary") or "").strip(),
str(context.get("ocr_text") or "").strip(),
]
).strip()
compact_text = re.sub(r"\s+", "", text)
if not compact_text:
return None
if document_type == "flight_itinerary":
for config in policy.flight_classes:
label = str(config.keyword or "").strip()
level = int(config.level)
if label in compact_text:
return "flight", label, level
return None
if document_type == "train_ticket":
for config in policy.train_classes:
label = str(config.keyword or "").strip()
level = int(config.level)
if label in compact_text:
return "train", label, level
return None
return None
def _is_long_distance_travel_context(
self,
context: dict[str, Any],
policy: RuntimeTravelPolicy,
) -> bool:
document_info = context["document_info"]
document_type = str(document_info.get("document_type") or "").strip().lower()
scene_code = str(document_info.get("scene_code") or "").strip().lower()
if document_type in set(policy.long_distance_document_types):
return True
return scene_code == "travel"