from __future__ import annotations import json import re from datetime import UTC, datetime, timedelta from decimal import Decimal, InvalidOperation from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import selectinload from app.api.deps import CurrentUserContext from app.core.agent_enums import AgentAssetStatus, AgentAssetType from app.models.employee import Employee from app.models.financial_record import ExpenseClaim from app.schemas.agent_asset import AgentAssetListItem from app.schemas.reimbursement import TravelReimbursementCalculatorRequest from app.schemas.user_agent import ( UserAgentCitation, UserAgentDraftPayload, UserAgentExpenseQueryRecord, UserAgentQueryPayload, UserAgentQueryStatusGroup, UserAgentReviewAction, UserAgentReviewClaimGroup, UserAgentReviewDocumentCard, UserAgentReviewDocumentField, UserAgentReviewEditField, UserAgentReviewPayload, UserAgentReviewRiskBrief, UserAgentReviewSlotCard, UserAgentRequest, UserAgentSuggestedAction, ) from app.services.agent_assets import AgentAssetService from app.services.expense_claims import ExpenseClaimService from app.services.expense_rule_runtime import ExpenseRuleRuntimeService, RuntimeTravelPolicy, resolve_document_type_label from app.services.ontology_field_registry import normalize_ontology_form_values from app.services.risk_ontology_bridge import resolve_rule_codes_for_risk_check from app.services.travel_reimbursement_calculator import TravelReimbursementCalculatorService from app.services.user_agent_constants import * class UserAgentReviewTravelReceiptMixin: def _is_travel_review_context( self, payload: UserAgentRequest, document_cards: list[UserAgentReviewDocumentCard], claim_groups: list[UserAgentReviewClaimGroup], ) -> bool: entity_expense_type = self._collect_entity_values(payload).get("expense_type_code", "") review_form_values = self._resolve_review_form_values(payload) form_expense_type = str(review_form_values.get("expense_type") or "").strip() message_context = " ".join( [ str(payload.message or ""), str(payload.context_json.get("user_input_text") or ""), str(payload.context_json.get("expense_type") or ""), form_expense_type, ] ) if entity_expense_type in {"travel", "hotel", "transport"}: return True if any(group.group_code == "travel" or group.expense_type in {"travel", "hotel", "transport"} for group in claim_groups): return True if any(card.suggested_expense_type in {"travel", "hotel", "transport"} for card in document_cards): return True return any(keyword in message_context for keyword in ("差旅", "出差", "机票", "火车", "高铁", "酒店", "住宿")) def _build_travel_receipt_state( self, payload: UserAgentRequest, *, document_cards: list[UserAgentReviewDocumentCard], claim_groups: list[UserAgentReviewClaimGroup], ) -> dict[str, Any]: empty_state: dict[str, Any] = { "is_travel_context": False, "has_long_distance_ticket": False, "ticket_type_label": "", "ticket_amount": Decimal("0.00"), "destination": "", "days": 1, "has_hotel_invoice": False, "has_local_transport": False, "required_missing_labels": [], "optional_missing_labels": [], "blocks_next_step": False, } if not document_cards or not self._is_travel_review_context(payload, document_cards, claim_groups): return empty_state long_distance_cards = [card for card in document_cards if self._is_long_distance_travel_card(card)] if not long_distance_cards: return { **empty_state, "is_travel_context": True, } has_hotel_invoice = any(self._is_review_hotel_card(card) for card in document_cards) required_missing_labels = [] if has_hotel_invoice else ["酒店的报销票据待上传(必须)"] ticket_amount = sum( (self._extract_amount_decimal_from_card(card) or Decimal("0.00")) for card in long_distance_cards ).quantize(Decimal("0.01")) return { **empty_state, "is_travel_context": True, "has_long_distance_ticket": True, "ticket_type_label": self._resolve_travel_ticket_type_label(long_distance_cards), "ticket_amount": ticket_amount, "destination": self._resolve_travel_receipt_destination(payload, long_distance_cards), "days": self._resolve_travel_receipt_days(payload, long_distance_cards), "has_hotel_invoice": has_hotel_invoice, "has_local_transport": any(self._is_local_transport_receipt_card(card) for card in document_cards), "required_missing_labels": required_missing_labels, "optional_missing_labels": [], "blocks_next_step": bool(required_missing_labels), } @staticmethod def _is_long_distance_travel_card(card: UserAgentReviewDocumentCard) -> bool: document_type = str(card.document_type or "").strip().lower() return document_type in {"train_ticket", "flight_itinerary"} @staticmethod def _is_local_transport_receipt_card(card: UserAgentReviewDocumentCard) -> bool: document_type = str(card.document_type or "").strip().lower() suggested_type = str(card.suggested_expense_type or "").strip().lower() return document_type in {"taxi_receipt", "parking_toll_receipt", "transport_receipt"} or ( suggested_type == "transport" and document_type not in {"train_ticket", "flight_itinerary"} ) @staticmethod def _resolve_travel_ticket_type_label(cards: list[UserAgentReviewDocumentCard]) -> str: labels: list[str] = [] for card in cards: document_type = str(card.document_type or "").strip().lower() if document_type == "train_ticket" and "火车" not in labels: labels.append("火车") if document_type == "flight_itinerary" and "飞机" not in labels: labels.append("飞机") return "/".join(labels) if labels else "交通" def _resolve_travel_receipt_destination( self, payload: UserAgentRequest, long_distance_cards: list[UserAgentReviewDocumentCard], ) -> str: for card in long_distance_cards: for field in card.fields: if str(field.label or "").strip() not in {"行程", "路线"}: continue destination = self._extract_travel_destination_from_route(field.value) if destination: return self._normalize_travel_destination(destination) card_text = self._build_review_document_card_text(card) route_match = TRAVEL_ROUTE_PATTERN.search(card_text) if route_match: return self._normalize_travel_destination(route_match.group(2)) location = self._resolve_location_value(payload) if location: return self._normalize_travel_destination(location) return "" @staticmethod def _extract_travel_destination_from_route(value: str) -> str: route_text = str(value or "").strip() if not route_text: return "" route_match = TRAVEL_ROUTE_PATTERN.search(route_text) if route_match: return route_match.group(2).strip() parts = [ item.strip() for item in re.split(r"\s*(?:至|到|→|->|-|—|~|~)\s*", route_text) if item.strip() ] return parts[-1] if len(parts) >= 2 else "" def _normalize_travel_destination(self, value: str) -> str: candidate = re.sub( r"(?:火车站|高铁站|动车站|车站|站|机场|航站楼)$", "", str(value or "").strip(), ) if not candidate: return "" try: policy = ExpenseRuleRuntimeService(self.db).load_catalog().travel_policy except Exception: policy = None if policy is not None: policy_city = self._extract_policy_city_from_text(candidate, policy) if policy_city: return policy_city return candidate def _resolve_travel_receipt_days( self, payload: UserAgentRequest, long_distance_cards: list[UserAgentReviewDocumentCard], ) -> int: dates: list[datetime] = [] for card in long_distance_cards: card_text = self._build_review_document_card_text(card) dates.extend(self._extract_dates_from_text(card_text)) if dates: return max(1, (max(dates).date() - min(dates).date()).days + 1) start_date = self._parse_date_text(payload.ontology.time_range.start_date or "") end_date = self._parse_date_text(payload.ontology.time_range.end_date or "") if start_date and end_date: return max(1, (end_date.date() - start_date.date()).days + 1) return 1 @staticmethod def _extract_dates_from_text(text: str) -> list[datetime]: dates: list[datetime] = [] for match in DATE_TEXT_PATTERN.finditer(str(text or "")): parsed = UserAgentReviewTravelReceiptMixin._parse_date_text(match.group(1)) if parsed is not None: dates.append(parsed) return dates @staticmethod def _parse_date_text(value: str) -> datetime | None: raw_value = str(value or "").strip() if not raw_value: return None normalized = ( raw_value.replace("年", "-") .replace("月", "-") .replace("/", "-") .replace("日", "") .strip() ) parts = [part for part in normalized.split("-") if part] if len(parts) != 3: return None try: year, month, day = (int(part) for part in parts) return datetime(year, month, day) except ValueError: return None def _build_travel_receipt_briefs( self, travel_receipt_state: dict[str, Any], ) -> list[UserAgentReviewRiskBrief]: if not travel_receipt_state.get("has_long_distance_ticket"): return [] required_labels = [ str(item).strip() for item in travel_receipt_state.get("required_missing_labels", []) if str(item).strip() ] if not required_labels: return [] required_text = ";".join(required_labels) return [ UserAgentReviewRiskBrief( title="差旅票据待补充", level="warning", content=required_text, detail=( "系统已识别到长途交通票据,会按差旅报销口径核对住宿、交通等票据完整性。" + f"当前必须补充:{required_text}。" ), suggestion="请先补充酒店住宿发票或住宿清单;在补齐前只能保存为草稿。", ) ] def _resolve_review_travel_allowance_standard( self, policy: RuntimeTravelPolicy, *, declared_city: str, card_text: str, ) -> tuple[str, Decimal] | None: meal_limits = getattr(policy, "allowance_limits", {}).get("meal", {}) if not meal_limits: return None region_label = self._resolve_review_travel_allowance_region( " ".join([declared_city or "", card_text or ""]) ) amount = meal_limits.get(region_label) if amount is None and region_label != "其他地区": amount = meal_limits.get("其他地区") region_label = "其他地区" if amount is None: return None return region_label, Decimal(amount).quantize(Decimal("0.01")) @staticmethod def _resolve_review_travel_allowance_region(text: str) -> str: normalized = re.sub(r"\s+", "", str(text or "")) if not normalized: return "其他地区" if any(keyword in normalized for keyword in ("境外", "国外", "海外")): return "国外" if any(keyword in normalized for keyword in ("香港", "澳门", "台湾", "港澳台")): return "港澳台" if "乌鲁木齐" in normalized: return "新疆-乌鲁木齐" if "新疆" in normalized: return "新疆-其他" if any(keyword in normalized for keyword in ("西藏", "拉萨")): return "西藏" if any(keyword in normalized for keyword in ("北京", "上海", "天津", "重庆", "深圳", "珠海", "汕头", "厦门")): return "直辖市/特区" return "其他地区" def _resolve_review_amount_scene_code( self, card: UserAgentReviewDocumentCard, payload: UserAgentRequest, ) -> str: document_type = str(card.document_type or "").strip().lower() suggested_type = str(card.suggested_expense_type or "").strip().lower() if document_type in {"taxi_receipt", "parking_toll_receipt", "transport_receipt"}: return "transport" if document_type == "meal_receipt": entity_values = self._collect_entity_values(payload) if suggested_type == "entertainment" or entity_values.get("expense_type_code") == "entertainment": return "entertainment" return "meal" if document_type == "hotel_invoice" or suggested_type == "hotel": return "hotel" if suggested_type in { "travel", "transport", "meal", "entertainment", "office", "meeting", "training", "communication", "welfare", "other", }: return suggested_type return self._collect_entity_values(payload).get("expense_type_code") or "other" @staticmethod def _resolve_review_scene_amount_limit(scene_policy: Any | None) -> Any | None: if scene_policy is None: return None return getattr(scene_policy, "item_amount_limit", None) or getattr(scene_policy, "claim_amount_limit", None) @staticmethod def _resolve_scene_standard_amount(limit_config: Any | None) -> Decimal | None: if limit_config is None: return None warn_amount = getattr(limit_config, "warn_amount", None) block_amount = getattr(limit_config, "block_amount", None) amount = warn_amount if warn_amount is not None else block_amount if amount is None: return None try: return Decimal(amount).quantize(Decimal("0.01")) except (InvalidOperation, ValueError): return None @staticmethod def _evaluate_review_scene_amount( *, amount: Decimal, limit_config: Any, reason_text: str, ) -> tuple[str, Decimal] | None: block_amount = getattr(limit_config, "block_amount", None) warn_amount = getattr(limit_config, "warn_amount", None) exception_keywords = list(getattr(limit_config, "exception_keywords", []) or []) has_exception = UserAgentReviewTravelReceiptMixin._text_contains_any(reason_text, exception_keywords) if block_amount is not None and amount > Decimal(block_amount): return ("high", Decimal(block_amount).quantize(Decimal("0.01"))) if warn_amount is not None and amount > Decimal(warn_amount): return ("high", Decimal(warn_amount).quantize(Decimal("0.01"))) return None def _resolve_review_employee_grade(self, payload: UserAgentRequest, *, employee: Employee | None) -> str: if employee is not None and employee.grade: return str(employee.grade).strip() review_form_values = self._resolve_review_form_values(payload) for source in ( review_form_values, payload.context_json, payload.tool_payload, ): for key in ("employee_grade", "grade", "user_grade", "position_grade"): value = str(source.get(key) or "").strip() if isinstance(source, dict) else "" if value: return value return "" def _build_review_reason_corpus(self, payload: UserAgentRequest) -> str: review_form_values = normalize_ontology_form_values(self._resolve_review_form_values(payload)) parts = [ str(payload.message or ""), str(payload.context_json.get("user_input_text") or ""), str(review_form_values.get("reason") or ""), str(review_form_values.get("location") or ""), ] return "\n".join(part.strip() for part in parts if part and part.strip()) def _resolve_declared_travel_city(self, payload: UserAgentRequest, policy: RuntimeTravelPolicy) -> str: review_form_values = normalize_ontology_form_values(self._resolve_review_form_values(payload)) candidates = [ str(review_form_values.get("location") or ""), self._resolve_location_value(payload), str(payload.message or ""), ] for candidate in candidates: city = self._extract_policy_city_from_text(candidate, policy) if city: return city return "" @staticmethod def _build_review_document_card_text(card: UserAgentReviewDocumentCard) -> str: field_text = " ".join(f"{field.label}:{field.value}" for field in card.fields) return " ".join( [ str(card.filename or ""), str(card.document_type or ""), str(card.scene_label or ""), str(card.summary or ""), field_text, ] ).strip() @staticmethod def _is_review_hotel_card(card: UserAgentReviewDocumentCard) -> bool: document_type = str(card.document_type or "").strip().lower() suggested_type = str(card.suggested_expense_type or "").strip().lower() scene_label = str(card.scene_label or "").strip() return document_type == "hotel_invoice" or suggested_type == "hotel" or "住宿" in scene_label @staticmethod def _extract_amount_decimal_from_card(card: UserAgentReviewDocumentCard) -> Decimal | None: for field in card.fields: if field.label != "金额": continue normalized = str(field.value or "").replace("元", "").replace("¥", "").replace("¥", "").replace(",", "").strip() try: amount = Decimal(normalized).quantize(Decimal("0.01")) except (InvalidOperation, ValueError): continue if amount > Decimal("0.00"): return amount return None @staticmethod def _extract_review_hotel_night_count(card: UserAgentReviewDocumentCard) -> int: text = f"{card.summary or ''} {' '.join(f'{field.label}:{field.value}' for field in card.fields)}" match = TRAVEL_REVIEW_HOTEL_NIGHT_PATTERN.search(text) if not match: return 1 try: return max(1, int(match.group(1))) except (TypeError, ValueError): return 1 @staticmethod def _extract_policy_city_from_text(text: str, policy: RuntimeTravelPolicy) -> str: normalized = str(text or "").strip() if not normalized: return "" city_names = set(policy.city_tiers.keys()) city_names.update(getattr(policy, "hotel_city_limits", {}).keys()) for city in sorted(city_names, key=lambda item: len(item), reverse=True): if city in normalized: return city return "" @staticmethod def _format_travel_city_tier(city_tier: str) -> str: return { "tier_1": "一线城市", "tier_2": "重点城市", "tier_3": "其他城市", }.get(str(city_tier or "").strip(), "当前城市") @staticmethod def _resolve_review_hotel_cap( policy: RuntimeTravelPolicy, *, grade_band: str, city: str, city_tier: str, ) -> Decimal: normalized_city = str(city or "").strip() if normalized_city and getattr(policy, "hotel_city_limits", None): city_limits = policy.hotel_city_limits.get(normalized_city, {}) city_cap = city_limits.get(grade_band) if city_cap is not None: return Decimal(city_cap).quantize(Decimal("0.01")) return Decimal(policy.hotel_limits.get(grade_band, {}).get(city_tier, Decimal("0.00"))).quantize( Decimal("0.01") ) def _detect_review_transport_class( self, card: UserAgentReviewDocumentCard, policy: RuntimeTravelPolicy, ) -> tuple[str, str, int] | None: document_type = str(card.document_type or "").strip().lower() text = re.sub(r"\s+", "", self._build_review_document_card_text(card)) if not text: return None if document_type == "flight_itinerary" or any(keyword in text for keyword in ("机票", "航班", "登机牌")): for config in policy.flight_classes: label = str(config.keyword or "").strip() if label and label in text: return "flight", label, int(config.level) if document_type == "train_ticket" or any(keyword in text for keyword in ("火车", "高铁", "动车", "铁路")): for config in policy.train_classes: label = str(config.keyword or "").strip() if label and label in text: return "train", label, int(config.level) return None @staticmethod def _text_contains_any(text: str, keywords: list[str] | tuple[str, ...]) -> bool: compact = re.sub(r"\s+", "", str(text or "")) return bool(compact) and any(str(keyword or "").strip() and str(keyword).strip() in compact for keyword in keywords) @staticmethod def _resolve_submission_blocked_reasons(payload: UserAgentRequest) -> list[str]: raw_reasons = payload.tool_payload.get("submission_blocked_reasons") submission_blocked = bool(payload.tool_payload.get("submission_blocked")) if raw_reasons is None and submission_blocked: raw_reasons = payload.tool_payload.get("missing_fields") if raw_reasons is None and not submission_blocked: return [] reasons: list[str] = [] if isinstance(raw_reasons, list): reasons.extend(str(item or "").strip() for item in raw_reasons) elif isinstance(raw_reasons, str): reasons.extend( item.strip() for item in re.split(r"[;;\n]+", raw_reasons) if item.strip() ) if not reasons and submission_blocked: message = str(payload.tool_payload.get("message") or "").strip() for prefix in ( "提交前请先补全信息:", "自动检测暂未通过,原因如下:", "自动检测未通过,原因如下:", "自动检测暂未通过:", "自动检测未通过:", "AI预审暂未通过,原因如下:", "AI预审未通过,原因如下:", "AI预审暂未通过:", "AI预审未通过:", ): if message.startswith(prefix): message = message[len(prefix):].strip() break if message: reasons.extend( item.strip() for item in re.split(r"[;;\n]+", message) if item.strip() and not item.strip().startswith("AI预审暂未通过") and not item.strip().startswith("自动检测暂未通过") ) return list(dict.fromkeys(reason for reason in reasons if reason))