from __future__ import annotations import re from datetime import UTC, date, datetime, timedelta from decimal import Decimal, InvalidOperation from types import SimpleNamespace from typing import Any from sqlalchemy import or_, select from sqlalchemy import inspect as sqlalchemy_inspect from app.api.deps import CurrentUserContext from app.core.agent_enums import AgentAssetDomain, AgentAssetStatus, AgentAssetType from app.models.agent_asset import AgentAsset from app.models.employee import Employee from app.models.financial_record import ExpenseClaim, ExpenseClaimItem from app.schemas.reimbursement import TravelReimbursementCalculatorRequest from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager from app.services.agent_asset_spreadsheet import RISK_RULES_LIBRARY from app.services.expense_claim_constants import ( AI_REVIEW_LOOKBACK_DAYS, AI_REVIEW_REPEAT_RISK_BLOCK_COUNT, AI_REVIEW_REPEAT_RISK_WARNING_COUNT, DOCUMENT_FACT_ITEM_TYPES, LOCATION_REQUIRED_EXPENSE_TYPES, OPTIONAL_ATTACHMENT_ITEM_TYPES, STANDARD_ADJUSTMENT_RISK_SOURCE, SYSTEM_GENERATED_ITEM_TYPES, TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES, TRAVEL_POLICY_HOTEL_NIGHT_PATTERN, ) from app.services.expense_claim_risk_stage import with_risk_business_stage from app.services.expense_rule_runtime import ( ExpenseRuleRuntimeService, RuntimeTravelPolicy, build_default_expense_rule_catalog, ) class ExpenseClaimItemSyncMixin: def _sync_travel_allowance_item(self, claim: ExpenseClaim) -> None: items = list(claim.items or []) allowance_items = [ item for item in items if str(item.item_type or "").strip().lower() == "travel_allowance" ] business_items = [ item for item in items if str(item.item_type or "").strip().lower() != "travel_allowance" ] business_types = {str(item.item_type or "").strip().lower() for item in business_items} is_travel_claim = str(claim.expense_type or "").strip().lower() == "travel" has_travel_detail = bool(business_types & TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES) if not is_travel_claim and not has_travel_detail: for item in allowance_items: self._discard_claim_item(claim, item) return grade = self._resolve_claim_employee_grade(claim) if not grade: return allowance_location = self._resolve_travel_allowance_location_from_claim( claim=claim, business_items=business_items, ) if not allowance_location: return existing_allowance = allowance_items[0] if allowance_items else None days, start_date, end_date = self._resolve_travel_allowance_days_from_claim( claim=claim, business_items=business_items, existing_allowance=existing_allowance, ) if days < 1: return try: from app.services.travel_reimbursement_calculator import ( TravelReimbursementCalculatorService, ) result = TravelReimbursementCalculatorService(self.db).calculate( TravelReimbursementCalculatorRequest( days=days, location=allowance_location, grade=grade, ), CurrentUserContext( username=str(claim.employee_id or claim.employee_name or "system"), name=str(claim.employee_name or ""), role_codes=[], is_admin=False, ), ) except ValueError: return allowance_amount = Decimal(result.allowance_amount or Decimal("0.00")).quantize(Decimal("0.01")) allowance_rate = Decimal(result.total_allowance_rate or Decimal("0.00")).quantize(Decimal("0.01")) if allowance_amount <= Decimal("0.00") or allowance_rate <= Decimal("0.00"): return item = existing_allowance if item is None: item = ExpenseClaimItem(claim_id=claim.id) claim.items.append(item) self.db.add(item) for duplicate in allowance_items[1:]: self._discard_claim_item(claim, duplicate) item.item_date = end_date item.item_type = "travel_allowance" item.item_reason = ( f"系统自动计算出差补贴:{result.matched_city},{days}天," f"{allowance_rate:.2f}元/天" ) item.item_location = str(result.allowance_region or allowance_location).strip() item.item_amount = allowance_amount item.invoice_id = None def _resolve_claim_employee_grade(self, claim: ExpenseClaim) -> str: grade = str(claim.employee_grade or "").strip() if grade: return grade employee_id = str(claim.employee_id or "").strip() if not employee_id: return "" employee = self.db.get(Employee, employee_id) return str(employee.grade if employee is not None and employee.grade else "").strip() def _discard_claim_item(self, claim: ExpenseClaim, item: ExpenseClaimItem) -> None: if item in claim.items: claim.items.remove(item) state = sqlalchemy_inspect(item) if state.persistent: self.db.delete(item) elif state.pending: self.db.expunge(item) def _resolve_travel_allowance_days_from_claim( self, *, claim: ExpenseClaim, business_items: list[ExpenseClaimItem], existing_allowance: ExpenseClaimItem | None, ) -> tuple[int, date, date]: dated_items = sorted( [item.item_date for item in business_items if item.item_date is not None] ) if dated_items: start_date = dated_items[0] end_date = dated_items[-1] elif claim.occurred_at is not None: start_date = claim.occurred_at.date() end_date = start_date else: start_date = date.today() end_date = start_date days = (end_date - start_date).days + 1 application_days = self._resolve_travel_allowance_days_from_application_link(claim) explicit_days = max( (self._extract_travel_day_count(item.item_reason) for item in business_items), default=0, ) unique_dates = {value for value in dated_items} if application_days is not None and application_days[0] > days and len(unique_dates) <= 1: return application_days if explicit_days > 0: days = explicit_days end_date = start_date + timedelta(days=days - 1) if application_days is not None and application_days[0] > days and len(unique_dates) <= 1: return application_days return max(1, days), start_date, end_date existing_days = self._extract_travel_allowance_days(existing_allowance) if existing_days > days and len(unique_dates) <= 1: days = existing_days end_date = start_date + timedelta(days=days - 1) return max(1, days), start_date, end_date def _resolve_travel_allowance_days_from_application_link( self, claim: ExpenseClaim, ) -> tuple[int, date, date] | None: values = self._collect_application_link_values(claim) if not values: return None time_text = str( values.get("application_business_time") or values.get("business_time") or values.get("time_range") or values.get("application_time") or values.get("time") or "" ).strip() dates = self._extract_application_link_dates(time_text) if len(dates) >= 2: start_date, end_date = dates[0], dates[-1] if end_date < start_date: start_date, end_date = end_date, start_date return max(1, (end_date - start_date).days + 1), start_date, end_date days = self._extract_travel_day_count( str(values.get("application_days") or values.get("days") or "").strip() ) if days <= 0: return None start_date = dates[0] if dates else claim.occurred_at.date() if claim.occurred_at is not None else date.today() end_date = start_date + timedelta(days=days - 1) return days, start_date, end_date def _collect_application_link_values(self, claim: ExpenseClaim) -> dict[str, Any]: values: dict[str, Any] = {} for flag in list(claim.risk_flags_json or []): if not isinstance(flag, dict): continue if str(flag.get("source") or "").strip() not in {"application_link", "application_handoff"}: continue for source in ( flag.get("expense_scene_selection"), flag.get("review_form_values"), flag.get("application_detail"), flag, ): if isinstance(source, dict): values.update(source) linked_detail = self._resolve_linked_application_detail_values(values) for key, value in linked_detail.items(): values.setdefault(key, value) return values def _resolve_linked_application_detail_values(self, values: dict[str, Any]) -> dict[str, Any]: application_claim = self._find_linked_application_claim(values) if application_claim is None: return {} detail: dict[str, Any] = {} for flag in list(application_claim.risk_flags_json or []): if not isinstance(flag, dict) or str(flag.get("source") or "").strip() != "application_detail": continue payload = flag.get("application_detail") or flag.get("applicationDetail") or {} if isinstance(payload, dict): detail.update(payload) if detail.get("time"): detail.setdefault("application_time", detail.get("time")) if detail.get("days"): detail.setdefault("application_days", detail.get("days")) if detail.get("transport_mode"): detail.setdefault("application_transport_mode", detail.get("transport_mode")) if detail.get("location"): detail.setdefault("application_location", detail.get("location")) if detail.get("reason"): detail.setdefault("application_reason", detail.get("reason")) if application_claim.occurred_at is not None: detail.setdefault("application_time", application_claim.occurred_at.date().isoformat()) detail.setdefault("time", application_claim.occurred_at.date().isoformat()) detail.setdefault("application_reason", str(application_claim.reason or "").strip()) detail.setdefault("application_location", str(application_claim.location or "").strip()) return {str(key): value for key, value in detail.items() if str(value or "").strip()} def _find_linked_application_claim(self, values: dict[str, Any]) -> ExpenseClaim | None: application_claim_id = str( values.get("application_claim_id") or values.get("applicationClaimId") or "" ).strip() if application_claim_id: linked_claim = self.db.get(ExpenseClaim, application_claim_id) if linked_claim is not None: return linked_claim application_claim_no = str( values.get("application_claim_no") or values.get("applicationClaimNo") or "" ).strip() if not application_claim_no: return None return self.db.scalar( select(ExpenseClaim).where(ExpenseClaim.claim_no == application_claim_no) ) @staticmethod def _extract_application_link_dates(value: str) -> list[date]: dates: list[date] = [] for matched in re.findall(r"\d{4}-\d{2}-\d{2}", str(value or "")): try: dates.append(date.fromisoformat(matched)) except ValueError: continue return dates @staticmethod def _extract_travel_allowance_days(item: ExpenseClaimItem | None) -> int: if item is None: return 0 match = re.search(r"(\d+)\s*天", str(item.item_reason or "")) if not match: return 0 try: return max(0, int(match.group(1))) except ValueError: return 0 def _resolve_travel_allowance_location_from_claim( self, *, claim: ExpenseClaim, business_items: list[ExpenseClaimItem], ) -> str: claim_location = str(claim.location or "").strip() if claim_location and claim_location not in {"待补充", "未知", "暂无", "非必填"}: return claim_location sorted_items = sorted( business_items, key=lambda item: (item.item_date or date.max, self._normalize_sort_datetime(item.created_at)), ) for item in sorted_items: location = str(item.item_location or "").strip() if location and location not in {"待补充", "未知", "暂无", "非必填"}: return location reason = str(item.item_reason or "").strip() for separator in ("-", "至", "到", "→", "->"): if separator in reason: destination = reason.split(separator)[-1].strip() if destination: return destination return "" @staticmethod def _parse_standard_adjustment_amount(value: Any) -> Decimal | None: try: raw_value = "" if value is None else value amount = Decimal(str(raw_value)).quantize(Decimal("0.01")) except (InvalidOperation, ValueError): return None return amount if amount >= Decimal("0.00") else None def _collect_standard_adjusted_amounts(self, claim: ExpenseClaim) -> dict[str, Decimal]: adjusted_amounts: dict[str, Decimal] = {} for flag in list(claim.risk_flags_json or []): if not isinstance(flag, dict): continue if str(flag.get("source") or "").strip() != STANDARD_ADJUSTMENT_RISK_SOURCE: continue item_id = str(flag.get("item_id") or flag.get("itemId") or "").strip() if not item_id: continue amount = self._parse_standard_adjustment_amount( flag.get("reimbursable_amount") or flag.get("reimbursableAmount") ) if amount is None: continue adjusted_amounts[item_id] = amount return adjusted_amounts def _resolve_item_amount_for_claim_total( self, item: ExpenseClaimItem, adjusted_amounts: dict[str, Decimal], ) -> Decimal: original_amount = Decimal(item.item_amount or Decimal("0.00")).quantize(Decimal("0.01")) item_id = str(item.id or "").strip() adjusted_amount = adjusted_amounts.get(item_id) if adjusted_amount is None: return original_amount return min(max(adjusted_amount, Decimal("0.00")), original_amount) def _sync_claim_from_items(self, claim: ExpenseClaim) -> None: self._sync_travel_allowance_item(claim) if not claim.items: claim.amount = Decimal("0.00") claim.invoice_count = 0 claim.risk_flags_json = self._merge_claim_attachment_risk_flags(claim, []) claim.risk_flags_json = self._merge_claim_platform_risk_preview_flags(claim, []) return ordered_items = sorted( claim.items, key=lambda item: ( item.item_date or date.max, self._normalize_sort_datetime(item.created_at), ), ) primary_item = ordered_items[0] adjusted_amounts = self._collect_standard_adjusted_amounts(claim) total_amount = sum( (self._resolve_item_amount_for_claim_total(item, adjusted_amounts) for item in ordered_items), Decimal("0.00"), ) claim.amount = total_amount.quantize(Decimal("0.01")) claim.invoice_count = sum(1 for item in ordered_items if str(item.invoice_id or "").strip()) claim.occurred_at = datetime( primary_item.item_date.year, primary_item.item_date.month, primary_item.item_date.day, tzinfo=UTC, ) claim.expense_type = self._resolve_claim_expense_type_from_items( ordered_items, fallback=str(primary_item.item_type or claim.expense_type or "other").strip() or "other", ) primary_item_type = str(primary_item.item_type or "").strip() if primary_item_type not in DOCUMENT_FACT_ITEM_TYPES: claim.reason = ( self._normalize_optional_text(primary_item.item_reason, fallback=claim.reason or "待补充") or "待补充" ) claim.location = ( self._normalize_optional_text(primary_item.item_location, fallback=claim.location or "待补充") or "待补充" ) claim.risk_flags_json = self._merge_claim_attachment_risk_flags( claim, self._build_claim_attachment_risk_flags(ordered_items), ) self._refresh_claim_platform_risk_preview_flags(claim) if str(claim.status or "").strip().lower() == "draft": claim.approval_stage = "待提交" @staticmethod def _resolve_claim_expense_type_from_items( items: list[ExpenseClaimItem], *, fallback: str, ) -> str: fallback_type = str(fallback or "").strip() or "other" item_types = {str(item.item_type or "").strip().lower() for item in items} if item_types & (TRAVEL_ALLOWANCE_TRIGGER_ITEM_TYPES | {"travel_allowance"}): return "travel" return fallback_type def _refresh_item_attachment_analysis(self, item: ExpenseClaimItem) -> None: file_path = self._attachment_storage.resolve_path(item.invoice_id) if file_path is None or not file_path.exists(): return metadata = self._attachment_storage.read_meta(file_path) media_type = str(metadata.get("media_type") or self._attachment_presentation.resolve_media_type(file_path.name)).strip() ocr_status = str(metadata.get("ocr_status") or "").strip().lower() if ocr_status == "failed": analysis = self._build_failed_ocr_attachment_analysis( media_type=media_type, error_message=str(metadata.get("ocr_error") or ""), item=item, ) elif ocr_status == "recognized" or any( ( str(metadata.get("ocr_text") or "").strip(), str(metadata.get("ocr_summary") or "").strip(), int(metadata.get("ocr_line_count") or 0), list(metadata.get("ocr_warnings") or []), ) ): stored_document_info = metadata.get("document_info") if not isinstance(stored_document_info, dict): stored_document_info = {} document = SimpleNamespace( filename=str(metadata.get("file_name") or file_path.name), text=str(metadata.get("ocr_text") or ""), summary=str(metadata.get("ocr_summary") or ""), avg_score=float(metadata.get("ocr_avg_score") or 0.0), line_count=int(metadata.get("ocr_line_count") or 0), document_type=str(stored_document_info.get("document_type") or ""), document_type_label=str(stored_document_info.get("document_type_label") or ""), scene_code=str(stored_document_info.get("scene_code") or ""), scene_label=str(stored_document_info.get("scene_label") or ""), document_fields=list(stored_document_info.get("fields") or []), warnings=[str(value) for value in list(metadata.get("ocr_warnings") or []) if str(value).strip()], ) document_info = self._build_attachment_document_info(document) requirement_check = self._build_attachment_requirement_check( item=item, document_info=document_info, ) analysis = self._build_attachment_analysis( document=document, item=item, claim=getattr(item, "claim", None), document_info=document_info, requirement_check=requirement_check, ) metadata["document_info"] = document_info metadata["requirement_check"] = requirement_check else: analysis = self._build_fallback_attachment_analysis(media_type=media_type, item=item) metadata["analysis"] = analysis self._attachment_storage.write_meta(file_path, metadata) def _build_claim_attachment_risk_flags( self, ordered_items: list[ExpenseClaimItem] ) -> list[dict[str, Any]]: derived_flags: list[dict[str, Any]] = [] for index, item in enumerate(ordered_items, start=1): file_path = self._attachment_storage.resolve_path(item.invoice_id) if file_path is None or not file_path.exists(): continue metadata = self._attachment_storage.read_meta(file_path) analysis = metadata.get("analysis") if not isinstance(analysis, dict): continue severity = str(analysis.get("severity") or "").strip().lower() if severity in {"", "pass", "low"}: continue summary = ( str(analysis.get("summary") or analysis.get("headline") or "").strip() or "附件存在待核对风险。" ) points = [ str(point or "").strip() for point in list(analysis.get("points") or []) if str(point or "").strip() ] message_detail = ";".join(points[:3]) if points else summary label = str( analysis.get("label") or ("高风险" if severity == "high" else "中风险") ).strip() derived_flags.append( with_risk_business_stage( { "source": "attachment_analysis", "item_id": item.id, "severity": severity, "label": label, "message": f"费用明细第 {index} 条:{message_detail}", "summary": summary, "points": points, }, "reimbursement", ) ) return derived_flags def _get_expense_rule_catalog(self) -> Any: cached = getattr(self, "_expense_rule_catalog", None) if cached is not None: return cached db = getattr(self, "db", None) if db is None: catalog = build_default_expense_rule_catalog() else: catalog = ExpenseRuleRuntimeService(db).load_catalog() setattr(self, "_expense_rule_catalog", catalog) return catalog def _get_expense_scene_policy(self, expense_type: str | None) -> Any | None: return self._get_expense_rule_catalog().get_scene_policy(expense_type) def _resolve_min_attachment_count(self, expense_type: str | None) -> int: policy = self._get_expense_scene_policy(expense_type) if policy is None: return 1 return max(0, int(policy.min_attachment_count or 0)) @staticmethod def _is_attachment_required_item_type(item_type: str | None) -> bool: normalized = str(item_type or "").strip().lower() return normalized not in SYSTEM_GENERATED_ITEM_TYPES and normalized not in OPTIONAL_ATTACHMENT_ITEM_TYPES def _resolve_claim_required_attachment_count(self, claim: ExpenseClaim) -> int: required_items = [ item for item in list(claim.items or []) if self._is_attachment_required_item_type(item.item_type) ] if not required_items: return 0 return min(self._resolve_min_attachment_count(claim.expense_type), len(required_items)) def _build_scene_reason_corpus(self, claim: ExpenseClaim) -> str: parts = [str(claim.reason or "").strip(), str(claim.location or "").strip()] for item in claim.items: parts.append(str(item.item_reason or "").strip()) parts.append(str(item.item_location or "").strip()) return "\n".join(part for part in parts if part) @staticmethod def _merge_claim_attachment_risk_flags( claim: ExpenseClaim, attachment_risk_flags: list[dict[str, Any]], ) -> list[Any]: preserved_flags = [ flag for flag in list(claim.risk_flags_json or []) if not (isinstance(flag, dict) and str(flag.get("source") or "").strip() == "attachment_analysis") ] return preserved_flags + attachment_risk_flags def _refresh_claim_platform_risk_preview_flags(self, claim: ExpenseClaim) -> None: if str(claim.expense_type or "").strip().lower().endswith("_application"): return evaluator = getattr(self, "evaluate_platform_risk_rules", None) if not callable(evaluator): return try: review = evaluator(claim, business_stage="reimbursement") except Exception: return platform_flags = list(review.get("flags") or []) if isinstance(review, dict) else [] claim.risk_flags_json = self._merge_claim_platform_risk_preview_flags( claim, platform_flags, ) @staticmethod def _merge_claim_platform_risk_preview_flags( claim: ExpenseClaim, platform_flags: list[dict[str, Any]], ) -> list[Any]: preserved_flags = [ flag for flag in list(claim.risk_flags_json or []) if not ( isinstance(flag, dict) and str(flag.get("source") or "").strip() == "submission_review" and str(flag.get("hit_source") or "").strip() == "rule_center" ) ] return preserved_flags + platform_flags @staticmethod def _format_submission_blocked_message(issues: list[str]) -> str: normalized_issues = [str(issue or "").strip() for issue in issues if str(issue or "").strip()] if not normalized_issues: return "自动检测未通过,但没有返回明确原因,请刷新草稿后重试。" return "自动检测暂未通过,原因如下:\n" + "\n".join( f"{index}. {issue}" for index, issue in enumerate(normalized_issues, start=1) ) def _validate_claim_for_submission(self, claim: ExpenseClaim) -> list[str]: issues: list[str] = [] claim_location_required = self._is_location_required_expense_type(claim.expense_type) claim_min_attachment_count = self._resolve_claim_required_attachment_count(claim) substantive_items = [ item for item in list(claim.items or []) if str(item.item_type or "").strip().lower() not in SYSTEM_GENERATED_ITEM_TYPES and not self._is_submission_placeholder_item(item) ] if self._is_missing_value(claim.employee_name): issues.append("申请人未完善") if self._is_missing_value(claim.department_name): issues.append("所属部门未完善") if self._is_missing_value(claim.expense_type): issues.append("报销类型未完善") if self._is_missing_value(claim.reason): issues.append("报销事由未完善") if claim_location_required and self._is_missing_value(claim.location): issues.append("业务地点未完善") if claim.amount is None or claim.amount <= Decimal("0.00"): issues.append("报销金额未完善") if claim.occurred_at is None: issues.append("发生时间未完善") if int(claim.invoice_count or 0) < claim_min_attachment_count: issues.append("票据附件数量不足") if not substantive_items: issues.append("费用明细不能为空") for index, item in enumerate(claim.items, start=1): prefix = f"费用明细第 {index} 条" is_system_generated = str(item.item_type or "").strip().lower() in SYSTEM_GENERATED_ITEM_TYPES if is_system_generated or self._is_submission_placeholder_item(item): continue item_location_required = self._is_location_required_expense_type(item.item_type or claim.expense_type) item_has_attachment = not self._is_missing_value(item.invoice_id) if not item_has_attachment and item.item_date is None: issues.append(f"{prefix}缺少日期") if self._is_missing_value(item.item_type): issues.append(f"{prefix}缺少费用项目") if not item_has_attachment and self._is_missing_value(item.item_reason): issues.append(f"{prefix}缺少说明") if not item_has_attachment and item_location_required and self._is_missing_value(item.item_location): issues.append(f"{prefix}缺少地点") if not item_has_attachment and (item.item_amount is None or item.item_amount <= Decimal("0.00")): issues.append(f"{prefix}缺少金额") if self._is_attachment_required_item_type(item.item_type) and not item_has_attachment: issues.append(f"{prefix}缺少票据标识") return issues def _is_submission_placeholder_item(self, item: ExpenseClaimItem) -> bool: if not self._is_missing_value(item.invoice_id): return False missing_reason = self._is_missing_value(item.item_reason) missing_location = self._is_missing_value(item.item_location) missing_amount = item.item_amount is None or item.item_amount <= Decimal("0.00") return missing_reason and missing_location and missing_amount def _is_location_required_expense_type(self, expense_type: str | None) -> bool: policy = self._get_expense_scene_policy(expense_type) if policy is None: return str(expense_type or "").strip().lower() in LOCATION_REQUIRED_EXPENSE_TYPES return bool(policy.location_required)