# Source: X-Financial/server/src/app/services/expense_claim_item_sync.py
from __future__ import annotations
import re
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal, InvalidOperation
from types import SimpleNamespace
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy import inspect as sqlalchemy_inspect
from app.api.deps import CurrentUserContext
from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus, AgentAssetType
from app.models.agent_asset import AgentAsset
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.schemas.reimbursement import TravelReimbursementCalculatorRequest
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY
from app.services.expense_claim_constants import (
AI_REVIEW_LOOKBACK_DAYS,
AI_REVIEW_REPEAT_RISK_BLOCK_COUNT,
AI_REVIEW_REPEAT_RISK_WARNING_COUNT,
DOCUMENT_FACT_ITEM_TYPES,
LOCATION_REQUIRED_EXPENSE_TYPES,
OPTIONAL_ATTACHMENT_ITEM_TYPES,
STANDARD_ADJUSTMENT_RISK_SOURCE,
SYSTEM_GENERATED_ITEM_TYPES,
TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES,
TRAVEL_POLICY_HOTEL_NIGHT_PATTERN,
)
from app.services.expense_claim_risk_stage import with_risk_business_stage
from app.services.expense_rule_runtime import (
ExpenseRuleRuntimeService,
RuntimeTravelPolicy,
build_default_expense_rule_catalog,
)
class ExpenseClaimItemSyncMixin:
def _sync_travel_allowance_item(self, claim: ExpenseClaim) -> None:
    """Create, refresh, or remove the system-generated travel allowance item.

    Items typed ``travel_allowance`` are the allowance; every other item is a
    business item. When the claim is neither a ``travel`` claim nor contains a
    trigger item type, every allowance item is discarded. Otherwise the
    allowance is recalculated through ``TravelReimbursementCalculatorService``
    and written onto the first existing allowance item (or a newly added one),
    and any additional allowance items are discarded.

    The method returns without changing existing items when the employee
    grade, the allowance location, or a day count of at least one cannot be
    resolved, when the calculation raises ``ValueError``, or when the
    calculated amount or daily rate is not positive.
    """
    allowance_items: list[ExpenseClaimItem] = []
    business_items: list[ExpenseClaimItem] = []
    business_types: set[str] = set()
    for entry in list(claim.items or []):
        entry_type = str(entry.item_type or "").strip().lower()
        if entry_type == "travel_allowance":
            allowance_items.append(entry)
        else:
            business_items.append(entry)
            business_types.add(entry_type)

    claim_is_travel = str(claim.expense_type or "").strip().lower() == "travel"
    has_trigger_item = bool(business_types & TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES)
    if not (claim_is_travel or has_trigger_item):
        # Not a travel scenario: previously generated allowance rows are stale.
        for stale in allowance_items:
            self._discard_claim_item(claim, stale)
        return

    grade = self._resolve_claim_employee_grade(claim)
    if not grade:
        return

    allowance_location = self._resolve_travel_allowance_location_from_claim(
        claim=claim,
        business_items=business_items,
    )
    if not allowance_location:
        return

    existing_allowance = allowance_items[0] if allowance_items else None
    days, _start_date, end_date = self._resolve_travel_allowance_days_from_claim(
        claim=claim,
        business_items=business_items,
        existing_allowance=existing_allowance,
    )
    if days < 1:
        return

    try:
        # The calculator service is imported at call time rather than at module level.
        from app.services.travel_reimbursement_calculator import (
            TravelReimbursementCalculatorService,
        )

        calculator = TravelReimbursementCalculatorService(self.db)
        request = TravelReimbursementCalculatorRequest(
            days=days,
            location=allowance_location,
            grade=grade,
        )
        acting_user = CurrentUserContext(
            username=str(claim.employee_id or claim.employee_name or "system"),
            name=str(claim.employee_name or ""),
            role_codes=[],
            is_admin=False,
        )
        result = calculator.calculate(request, acting_user)
    except ValueError:
        return

    cent = Decimal("0.01")
    zero = Decimal("0.00")
    allowance_amount = Decimal(result.allowance_amount or zero).quantize(cent)
    allowance_rate = Decimal(result.total_allowance_rate or zero).quantize(cent)
    if not (allowance_amount > zero and allowance_rate > zero):
        return

    allowance_item = existing_allowance
    if allowance_item is None:
        allowance_item = ExpenseClaimItem(claim_id=claim.id)
        claim.items.append(allowance_item)
        self.db.add(allowance_item)
    # Keep exactly one allowance row per claim.
    for duplicate in allowance_items[1:]:
        self._discard_claim_item(claim, duplicate)

    allowance_item.item_date = end_date
    allowance_item.item_type = "travel_allowance"
    allowance_item.item_reason = (
        f"系统自动计算出差补贴:{result.matched_city}{days}天,"
        f"{allowance_rate:.2f}元/天"
    )
    allowance_item.item_location = str(result.allowance_region or allowance_location).strip()
    allowance_item.item_amount = allowance_amount
    allowance_item.invoice_id = None
def _resolve_claim_employee_grade(self, claim: ExpenseClaim) -> str:
grade = str(claim.employee_grade or "").strip()
if grade:
return grade
employee_id = str(claim.employee_id or "").strip()
if not employee_id:
return ""
employee = self.db.get(Employee, employee_id)
return str(employee.grade if employee is not None and employee.grade else "").strip()
def _discard_claim_item(self, claim: ExpenseClaim, item: ExpenseClaimItem) -> None:
    """Detach ``item`` from ``claim.items`` and drop it from the session.

    A persistent item is marked for deletion; a pending item is expunged.
    Items in any other ORM state are only removed from the collection.
    """
    siblings = claim.items
    if item in siblings:
        siblings.remove(item)

    orm_state = sqlalchemy_inspect(item)
    if orm_state.persistent:
        self.db.delete(item)
        return
    if orm_state.pending:
        self.db.expunge(item)
def _resolve_travel_allowance_days_from_claim(
self,
*,
claim: ExpenseClaim,
business_items: list[ExpenseClaimItem],
existing_allowance: ExpenseClaimItem | None,
) -> tuple[int, date, date]:
dated_items = sorted(
[item.item_date for item in business_items if item.item_date is not None]
)
if dated_items:
start_date = dated_items[0]
end_date = dated_items[-1]
elif claim.occurred_at is not None:
start_date = claim.occurred_at.date()
end_date = start_date
else:
start_date = date.today()
end_date = start_date
days = (end_date - start_date).days + 1
application_days = self._resolve_travel_allowance_days_from_application_link(claim)
explicit_days = max(
(self._extract_travel_day_count(item.item_reason) for item in business_items),
default=0,
)
unique_dates = {value for value in dated_items}
if application_days is not None and application_days[0] > days and len(unique_dates) <= 1:
return application_days
if explicit_days > 0:
days = explicit_days
end_date = start_date + timedelta(days=days - 1)
if application_days is not None and application_days[0] > days and len(unique_dates) <= 1:
return application_days
return max(1, days), start_date, end_date
existing_days = self._extract_travel_allowance_days(existing_allowance)
if existing_days > days and len(unique_dates) <= 1:
days = existing_days
end_date = start_date + timedelta(days=days - 1)
return max(1, days), start_date, end_date
def _resolve_travel_allowance_days_from_application_link(
self,
claim: ExpenseClaim,
) -> tuple[int, date, date] | None:
values = self._collect_application_link_values(claim)
if not values:
return None
time_text = str(
values.get("application_business_time")
or values.get("business_time")
or values.get("time_range")
or values.get("application_time")
or values.get("time")
or ""
).strip()
dates = self._extract_application_link_dates(time_text)
if len(dates) >= 2:
start_date, end_date = dates[0], dates[-1]
if end_date < start_date:
start_date, end_date = end_date, start_date
return max(1, (end_date - start_date).days + 1), start_date, end_date
days = self._extract_travel_day_count(
str(values.get("application_days") or values.get("days") or "").strip()
)
if days <= 0:
return None
start_date = dates[0] if dates else claim.occurred_at.date() if claim.occurred_at is not None else date.today()
end_date = start_date + timedelta(days=days - 1)
return days, start_date, end_date
def _collect_application_link_values(self, claim: ExpenseClaim) -> dict[str, Any]:
values: dict[str, Any] = {}
for flag in list(claim.risk_flags_json or []):
if not isinstance(flag, dict):
continue
if str(flag.get("source") or "").strip() not in {"application_link", "application_handoff"}:
continue
for source in (
flag.get("expense_scene_selection"),
flag.get("review_form_values"),
flag.get("application_detail"),
flag,
):
if isinstance(source, dict):
values.update(source)
linked_detail = self._resolve_linked_application_detail_values(values)
for key, value in linked_detail.items():
values.setdefault(key, value)
return values
def _resolve_linked_application_detail_values(self, values: dict[str, Any]) -> dict[str, Any]:
application_claim = self._find_linked_application_claim(values)
if application_claim is None:
return {}
detail: dict[str, Any] = {}
for flag in list(application_claim.risk_flags_json or []):
if not isinstance(flag, dict) or str(flag.get("source") or "").strip() != "application_detail":
continue
payload = flag.get("application_detail") or flag.get("applicationDetail") or {}
if isinstance(payload, dict):
detail.update(payload)
if detail.get("time"):
detail.setdefault("application_time", detail.get("time"))
if detail.get("days"):
detail.setdefault("application_days", detail.get("days"))
if detail.get("transport_mode"):
detail.setdefault("application_transport_mode", detail.get("transport_mode"))
if detail.get("location"):
detail.setdefault("application_location", detail.get("location"))
if detail.get("reason"):
detail.setdefault("application_reason", detail.get("reason"))
if application_claim.occurred_at is not None:
detail.setdefault("application_time", application_claim.occurred_at.date().isoformat())
detail.setdefault("time", application_claim.occurred_at.date().isoformat())
detail.setdefault("application_reason", str(application_claim.reason or "").strip())
detail.setdefault("application_location", str(application_claim.location or "").strip())
return {str(key): value for key, value in detail.items() if str(value or "").strip()}
def _find_linked_application_claim(self, values: dict[str, Any]) -> ExpenseClaim | None:
application_claim_id = str(
values.get("application_claim_id")
or values.get("applicationClaimId")
or ""
).strip()
if application_claim_id:
linked_claim = self.db.get(ExpenseClaim, application_claim_id)
if linked_claim is not None:
return linked_claim
application_claim_no = str(
values.get("application_claim_no")
or values.get("applicationClaimNo")
or ""
).strip()
if not application_claim_no:
return None
return self.db.scalar(
select(ExpenseClaim).where(ExpenseClaim.claim_no == application_claim_no)
)
@staticmethod
def _extract_application_link_dates(value: str) -> list[date]:
dates: list[date] = []
for matched in re.findall(r"\d{4}-\d{2}-\d{2}", str(value or "")):
try:
dates.append(date.fromisoformat(matched))
except ValueError:
continue
return dates
@staticmethod
def _extract_travel_allowance_days(item: ExpenseClaimItem | None) -> int:
if item is None:
return 0
match = re.search(r"(\d+)\s*天", str(item.item_reason or ""))
if not match:
return 0
try:
return max(0, int(match.group(1)))
except ValueError:
return 0
def _resolve_travel_allowance_location_from_claim(
self,
*,
claim: ExpenseClaim,
business_items: list[ExpenseClaimItem],
) -> str:
claim_location = str(claim.location or "").strip()
if claim_location and claim_location not in {"待补充", "未知", "暂无", "非必填"}:
return claim_location
sorted_items = sorted(
business_items,
key=lambda item: (item.item_date or date.max, self._normalize_sort_datetime(item.created_at)),
)
for item in sorted_items:
location = str(item.item_location or "").strip()
if location and location not in {"待补充", "未知", "暂无", "非必填"}:
return location
reason = str(item.item_reason or "").strip()
for separator in ("-", "", "", "", "->"):
if separator in reason:
destination = reason.split(separator)[-1].strip()
if destination:
return destination
return ""
@staticmethod
def _parse_standard_adjustment_amount(value: Any) -> Decimal | None:
try:
raw_value = "" if value is None else value
amount = Decimal(str(raw_value)).quantize(Decimal("0.01"))
except (InvalidOperation, ValueError):
return None
return amount if amount >= Decimal("0.00") else None
def _collect_standard_adjusted_amounts(self, claim: ExpenseClaim) -> dict[str, Decimal]:
    """Map item id to the reimbursable amount recorded by standard-adjustment flags.

    Only dict flags whose ``source`` equals ``STANDARD_ADJUSTMENT_RISK_SOURCE``
    and that carry an item id (``item_id`` / ``itemId``) are considered. The
    amount is read from ``reimbursable_amount`` and then ``reimbursableAmount``;
    the first one that parses to a valid amount is used. When several flags
    target the same item, the last one wins.
    """
    adjusted_amounts: dict[str, Decimal] = {}
    for flag in list(claim.risk_flags_json or []):
        if not isinstance(flag, dict):
            continue
        if str(flag.get("source") or "").strip() != STANDARD_ADJUSTMENT_RISK_SOURCE:
            continue
        item_id = str(flag.get("item_id") or flag.get("itemId") or "").strip()
        if not item_id:
            continue

        # Do not chain the two keys with ``or``: a numeric 0 is a legitimate
        # adjusted amount (nothing reimbursable) and must not be treated as missing.
        amount = None
        for amount_key in ("reimbursable_amount", "reimbursableAmount"):
            amount = self._parse_standard_adjustment_amount(flag.get(amount_key))
            if amount is not None:
                break
        if amount is None:
            continue
        adjusted_amounts[item_id] = amount
    return adjusted_amounts
def _resolve_item_amount_for_claim_total(
self,
item: ExpenseClaimItem,
adjusted_amounts: dict[str, Decimal],
) -> Decimal:
original_amount = Decimal(item.item_amount or Decimal("0.00")).quantize(Decimal("0.01"))
item_id = str(item.id or "").strip()
adjusted_amount = adjusted_amounts.get(item_id)
if adjusted_amount is None:
return original_amount
return min(max(adjusted_amount, Decimal("0.00")), original_amount)
def _sync_claim_from_items(self, claim: ExpenseClaim) -> None:
    """Recompute the claim's header fields and risk flags from its items.

    The travel allowance item is synchronised first. With no items left, the
    amount and invoice count are zeroed and the attachment and platform
    preview flags are cleared. Otherwise the items are ordered by date
    (undated last) and creation time, and the first one drives the claim's
    occurrence date, expense type, reason and location. The total uses
    standard-adjusted amounts where present. A draft claim has its approval
    stage reset to ``待提交``.
    """
    self._sync_travel_allowance_item(claim)
    if not claim.items:
        claim.amount = Decimal("0.00")
        claim.invoice_count = 0
        claim.risk_flags_json = self._merge_claim_attachment_risk_flags(claim, [])
        claim.risk_flags_json = self._merge_claim_platform_risk_preview_flags(claim, [])
        return

    ordered_items = sorted(
        claim.items,
        key=lambda item: (
            item.item_date or date.max,
            self._normalize_sort_datetime(item.created_at),
        ),
    )
    primary_item = ordered_items[0]

    adjusted_amounts = self._collect_standard_adjusted_amounts(claim)
    total_amount = sum(
        (self._resolve_item_amount_for_claim_total(item, adjusted_amounts) for item in ordered_items),
        Decimal("0.00"),
    )
    claim.amount = total_amount.quantize(Decimal("0.01"))
    claim.invoice_count = sum(1 for item in ordered_items if str(item.invoice_id or "").strip())

    # Undated items sort last, so the primary item is undated only when every
    # item is. In that case keep the existing occurred_at instead of failing on
    # ``None.year``; submission validation still reports the missing dates.
    primary_date = primary_item.item_date
    if primary_date is not None:
        claim.occurred_at = datetime(
            primary_date.year,
            primary_date.month,
            primary_date.day,
            tzinfo=UTC,
        )

    claim.expense_type = self._resolve_claim_expense_type_from_items(
        ordered_items,
        fallback=str(primary_item.item_type or claim.expense_type or "other").strip() or "other",
    )

    primary_item_type = str(primary_item.item_type or "").strip()
    if primary_item_type not in DOCUMENT_FACT_ITEM_TYPES:
        claim.reason = (
            self._normalize_optional_text(primary_item.item_reason, fallback=claim.reason or "待补充")
            or "待补充"
        )
        # NOTE(review): the flattened source does not show whether the location
        # assignment belongs inside this guard; it is kept next to the reason
        # assignment here — confirm against the upstream file.
        claim.location = (
            self._normalize_optional_text(primary_item.item_location, fallback=claim.location or "待补充")
            or "待补充"
        )

    claim.risk_flags_json = self._merge_claim_attachment_risk_flags(
        claim,
        self._build_claim_attachment_risk_flags(ordered_items),
    )
    self._refresh_claim_platform_risk_preview_flags(claim)
    if str(claim.status or "").strip().lower() == "draft":
        claim.approval_stage = "待提交"
@staticmethod
def _resolve_claim_expense_type_from_items(
    items: list[ExpenseClaimItem],
    *,
    fallback: str,
) -> str:
    """Return ``"travel"`` when any item is travel-related, else the fallback.

    An item is travel-related when its normalised type is a travel allowance
    trigger type or ``travel_allowance``. A blank fallback becomes ``"other"``.
    """
    fallback_type = str(fallback or "").strip() or "other"
    present_types = {str(entry.item_type or "").strip().lower() for entry in items}
    travel_markers = TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES | {"travel_allowance"}
    if present_types.isdisjoint(travel_markers):
        return fallback_type
    return "travel"
def _refresh_item_attachment_analysis(self, item: ExpenseClaimItem) -> None:
file_path = self._attachment_storage.resolve_path(item.invoice_id)
if file_path is None or not file_path.exists():
return
metadata = self._attachment_storage.read_meta(file_path)
media_type = str(metadata.get("media_type") or self._attachment_presentation.resolve_media_type(file_path.name)).strip()
ocr_status = str(metadata.get("ocr_status") or "").strip().lower()
if ocr_status == "failed":
analysis = self._build_failed_ocr_attachment_analysis(
media_type=media_type,
error_message=str(metadata.get("ocr_error") or ""),
item=item,
)
elif ocr_status == "recognized" or any(
(
str(metadata.get("ocr_text") or "").strip(),
str(metadata.get("ocr_summary") or "").strip(),
int(metadata.get("ocr_line_count") or 0),
list(metadata.get("ocr_warnings") or []),
)
):
stored_document_info = metadata.get("document_info")
if not isinstance(stored_document_info, dict):
stored_document_info = {}
document = SimpleNamespace(
filename=str(metadata.get("file_name") or file_path.name),
text=str(metadata.get("ocr_text") or ""),
summary=str(metadata.get("ocr_summary") or ""),
avg_score=float(metadata.get("ocr_avg_score") or 0.0),
line_count=int(metadata.get("ocr_line_count") or 0),
document_type=str(stored_document_info.get("document_type") or ""),
document_type_label=str(stored_document_info.get("document_type_label") or ""),
scene_code=str(stored_document_info.get("scene_code") or ""),
scene_label=str(stored_document_info.get("scene_label") or ""),
document_fields=list(stored_document_info.get("fields") or []),
warnings=[str(value) for value in list(metadata.get("ocr_warnings") or []) if str(value).strip()],
)
document_info = self._build_attachment_document_info(document)
requirement_check = self._build_attachment_requirement_check(
item=item,
document_info=document_info,
)
analysis = self._build_attachment_analysis(
document=document,
item=item,
claim=getattr(item, "claim", None),
document_info=document_info,
requirement_check=requirement_check,
)
metadata["document_info"] = document_info
metadata["requirement_check"] = requirement_check
else:
analysis = self._build_fallback_attachment_analysis(media_type=media_type, item=item)
metadata["analysis"] = analysis
self._attachment_storage.write_meta(file_path, metadata)
def _build_claim_attachment_risk_flags(
self, ordered_items: list[ExpenseClaimItem]
) -> list[dict[str, Any]]:
derived_flags: list[dict[str, Any]] = []
for index, item in enumerate(ordered_items, start=1):
file_path = self._attachment_storage.resolve_path(item.invoice_id)
if file_path is None or not file_path.exists():
continue
metadata = self._attachment_storage.read_meta(file_path)
analysis = metadata.get("analysis")
if not isinstance(analysis, dict):
continue
severity = str(analysis.get("severity") or "").strip().lower()
if severity in {"", "pass", "low"}:
continue
summary = (
str(analysis.get("summary") or analysis.get("headline") or "").strip()
or "附件存在待核对风险。"
)
points = [
str(point or "").strip()
for point in list(analysis.get("points") or [])
if str(point or "").strip()
]
message_detail = "".join(points[:3]) if points else summary
label = str(
analysis.get("label") or ("高风险" if severity == "high" else "中风险")
).strip()
derived_flags.append(
with_risk_business_stage(
{
"source": "attachment_analysis",
"item_id": item.id,
"severity": severity,
"label": label,
"message": f"费用明细第 {index} 条:{message_detail}",
"summary": summary,
"points": points,
},
"reimbursement",
)
)
return derived_flags
def _get_expense_rule_catalog(self) -> Any:
cached = getattr(self, "_expense_rule_catalog", None)
if cached is not None:
return cached
db = getattr(self, "db", None)
if db is None:
catalog = build_default_expense_rule_catalog()
else:
catalog = ExpenseRuleRuntimeService(db).load_catalog()
setattr(self, "_expense_rule_catalog", catalog)
return catalog
def _get_expense_scene_policy(self, expense_type: str | None) -> Any | None:
return self._get_expense_rule_catalog().get_scene_policy(expense_type)
def _resolve_min_attachment_count(self, expense_type: str | None) -> int:
policy = self._get_expense_scene_policy(expense_type)
if policy is None:
return 1
return max(0, int(policy.min_attachment_count or 0))
@staticmethod
def _is_attachment_required_item_type(item_type: str | None) -> bool:
    """Return True unless the item type is system-generated or attachment-optional."""
    normalized = str(item_type or "").strip().lower()
    exempt = normalized in SYSTEM_GENERATED_ITEM_TYPES or normalized in OPTIONAL_ATTACHMENT_ITEM_TYPES
    return not exempt
def _resolve_claim_required_attachment_count(self, claim: ExpenseClaim) -> int:
required_items = [
item for item in list(claim.items or [])
if self._is_attachment_required_item_type(item.item_type)
]
if not required_items:
return 0
return min(self._resolve_min_attachment_count(claim.expense_type), len(required_items))
def _build_scene_reason_corpus(self, claim: ExpenseClaim) -> str:
parts = [str(claim.reason or "").strip(), str(claim.location or "").strip()]
for item in claim.items:
parts.append(str(item.item_reason or "").strip())
parts.append(str(item.item_location or "").strip())
return "\n".join(part for part in parts if part)
@staticmethod
def _merge_claim_attachment_risk_flags(
claim: ExpenseClaim,
attachment_risk_flags: list[dict[str, Any]],
) -> list[Any]:
preserved_flags = [
flag
for flag in list(claim.risk_flags_json or [])
if not (isinstance(flag, dict) and str(flag.get("source") or "").strip() == "attachment_analysis")
]
return preserved_flags + attachment_risk_flags
def _refresh_claim_platform_risk_preview_flags(self, claim: ExpenseClaim) -> None:
if str(claim.expense_type or "").strip().lower().endswith("_application"):
return
evaluator = getattr(self, "evaluate_platform_risk_rules", None)
if not callable(evaluator):
return
try:
review = evaluator(claim, business_stage="reimbursement")
except Exception:
return
platform_flags = list(review.get("flags") or []) if isinstance(review, dict) else []
claim.risk_flags_json = self._merge_claim_platform_risk_preview_flags(
claim,
platform_flags,
)
@staticmethod
def _merge_claim_platform_risk_preview_flags(
claim: ExpenseClaim,
platform_flags: list[dict[str, Any]],
) -> list[Any]:
preserved_flags = [
flag
for flag in list(claim.risk_flags_json or [])
if not (
isinstance(flag, dict)
and str(flag.get("source") or "").strip() == "submission_review"
and str(flag.get("hit_source") or "").strip() == "rule_center"
)
]
return preserved_flags + platform_flags
@staticmethod
def _format_submission_blocked_message(issues: list[str]) -> str:
normalized_issues = [str(issue or "").strip() for issue in issues if str(issue or "").strip()]
if not normalized_issues:
return "自动检测未通过,但没有返回明确原因,请刷新草稿后重试。"
return "自动检测暂未通过,原因如下:\n" + "\n".join(
f"{index}. {issue}" for index, issue in enumerate(normalized_issues, start=1)
)
def _validate_claim_for_submission(self, claim: ExpenseClaim) -> list[str]:
issues: list[str] = []
claim_location_required = self._is_location_required_expense_type(claim.expense_type)
claim_min_attachment_count = self._resolve_claim_required_attachment_count(claim)
substantive_items = [
item
for item in list(claim.items or [])
if str(item.item_type or "").strip().lower() not in SYSTEM_GENERATED_ITEM_TYPES
and not self._is_submission_placeholder_item(item)
]
if self._is_missing_value(claim.employee_name):
issues.append("申请人未完善")
if self._is_missing_value(claim.department_name):
issues.append("所属部门未完善")
if self._is_missing_value(claim.expense_type):
issues.append("报销类型未完善")
if self._is_missing_value(claim.reason):
issues.append("报销事由未完善")
if claim_location_required and self._is_missing_value(claim.location):
issues.append("业务地点未完善")
if claim.amount is None or claim.amount <= Decimal("0.00"):
issues.append("报销金额未完善")
if claim.occurred_at is None:
issues.append("发生时间未完善")
if int(claim.invoice_count or 0) < claim_min_attachment_count:
issues.append("票据附件数量不足")
if not substantive_items:
issues.append("费用明细不能为空")
for index, item in enumerate(claim.items, start=1):
prefix = f"费用明细第 {index}"
is_system_generated = str(item.item_type or "").strip().lower() in SYSTEM_GENERATED_ITEM_TYPES
if is_system_generated or self._is_submission_placeholder_item(item):
continue
item_location_required = self._is_location_required_expense_type(item.item_type or claim.expense_type)
item_has_attachment = not self._is_missing_value(item.invoice_id)
if not item_has_attachment and item.item_date is None:
issues.append(f"{prefix}缺少日期")
if self._is_missing_value(item.item_type):
issues.append(f"{prefix}缺少费用项目")
if not item_has_attachment and self._is_missing_value(item.item_reason):
issues.append(f"{prefix}缺少说明")
if not item_has_attachment and item_location_required and self._is_missing_value(item.item_location):
issues.append(f"{prefix}缺少地点")
if not item_has_attachment and (item.item_amount is None or item.item_amount <= Decimal("0.00")):
issues.append(f"{prefix}缺少金额")
if self._is_attachment_required_item_type(item.item_type) and not item_has_attachment:
issues.append(f"{prefix}缺少票据标识")
return issues
def _is_submission_placeholder_item(self, item: ExpenseClaimItem) -> bool:
if not self._is_missing_value(item.invoice_id):
return False
missing_reason = self._is_missing_value(item.item_reason)
missing_location = self._is_missing_value(item.item_location)
missing_amount = item.item_amount is None or item.item_amount <= Decimal("0.00")
return missing_reason and missing_location and missing_amount
def _is_location_required_expense_type(self, expense_type: str | None) -> bool:
policy = self._get_expense_scene_policy(expense_type)
if policy is None:
return str(expense_type or "").strip().lower() in LOCATION_REQUIRED_EXPENSE_TYPES
return bool(policy.location_required)