from __future__ import annotations from dataclasses import asdict, dataclass from datetime import UTC, date, datetime, time, timedelta from decimal import Decimal from sqlalchemy import func, select, text from sqlalchemy.orm import Session from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransaction from app.models.financial_record import ExpenseClaim from app.models.risk_observation import RiskObservation from app.services.demo_company_simulation_catalog import ( BUDGETED_STATUSES, SIM_BUDGET_PREFIX, SIM_PROJECT_CODE, SIM_RESERVATION_PREFIX, SIM_RISK_PREFIX, SIM_TRANSACTION_PREFIX, build_simulation_reimbursement_no, target_budget_usage, ) @dataclass(frozen=True, slots=True) class SimulationRebalanceSummary: mode: str claims: int main_period_claims: int recent_claims: int period_start: str period_end: str max_daily_count: int budget_transactions: int budget_reservations: int risk_observations: int allocation_missing_count: int def to_dict(self) -> dict[str, object]: return asdict(self) class HalfYearExpenseSimulationRebalancer: """Rebalance existing simulation rows without deleting business records.""" def __init__( self, db: Session, *, start_date: date = date(2026, 1, 1), end_date: date = date(2026, 6, 2), recent_sample_days: int = 2, ) -> None: self.db = db self.start_date = start_date self.end_date = end_date self.main_period_end = date(end_date.year, end_date.month, 1) - timedelta(days=1) self.recent_sample_days = max(1, recent_sample_days) def preview(self) -> SimulationRebalanceSummary: return self._run(apply=False) def apply(self) -> SimulationRebalanceSummary: return self._run(apply=True) def _run(self, *, apply: bool) -> SimulationRebalanceSummary: claims = self._simulation_claims() plans = self._claim_plans(claims) allocation_map = self._allocation_map() allocation_missing_count = self._count_missing_allocations(plans, allocation_map) day_counts: dict[date, int] = {} for _claim, plan in plans: day_counts[plan["day"]] = day_counts.get(plan["day"], 0) + 1 if apply and plans: self._apply_claim_plans(plans, allocation_map) self._rebalance_allocation_amounts() self.db.flush() recent_count = sum(1 for _claim, plan in plans if plan["day"] >= date(2026, 6, 1)) return SimulationRebalanceSummary( mode="apply" if apply else "dry-run", claims=len(claims), main_period_claims=len(claims) - recent_count, recent_claims=recent_count, period_start=self.start_date.isoformat(), period_end=self.end_date.isoformat(), max_daily_count=max(day_counts.values()) if day_counts else 0, budget_transactions=self._sim_transaction_count(), budget_reservations=self._sim_reservation_count(), risk_observations=self._sim_risk_count(), allocation_missing_count=allocation_missing_count, ) def _simulation_claims(self) -> list[ExpenseClaim]: return list( self.db.scalars( select(ExpenseClaim) .where(ExpenseClaim.project_code == SIM_PROJECT_CODE) .order_by(ExpenseClaim.claim_no.asc(), ExpenseClaim.id.asc()) ).all() ) def _claim_plans(self, claims: list[ExpenseClaim]) -> list[tuple[ExpenseClaim, dict[str, object]]]: recent_count = self._recent_count(len(claims)) main_count = max(len(claims) - recent_count, 0) main_days = self._date_range(self.start_date, self.main_period_end) recent_days = self._date_range(date(2026, 6, 1), self.end_date) plans: list[tuple[ExpenseClaim, dict[str, object]]] = [] for index, claim in enumerate(claims): if index < main_count: day = self._spread_day(index, main_count, main_days) else: recent_index = index - main_count day = recent_days[recent_index % len(recent_days)] occurred_at = datetime.combine(day, time(hour=8 + (index % 9)), tzinfo=UTC) submitted_at = None if self._status(claim) != "draft": submitted_at = datetime.combine(day, time(hour=9 + (index % 7)), tzinfo=UTC) updated_at = self._updated_at(claim, occurred_at, submitted_at, index) final_claim_no = build_simulation_reimbursement_no(occurred_at, index + 1) period_key = f"{occurred_at.year}Q{((occurred_at.month - 1) // 3) + 1}" subject_code = "meal" if str(claim.expense_type or "") == "entertainment" else str(claim.expense_type or "") plans.append( ( claim, { "sequence": index + 1, "day": day, "occurred_at": occurred_at, "submitted_at": submitted_at, "updated_at": updated_at, "claim_no": final_claim_no, "period_key": period_key, "subject_code": subject_code, }, ) ) return plans def _apply_claim_plans( self, plans: list[tuple[ExpenseClaim, dict[str, object]]], allocation_map: dict[tuple[str | None, str, str], str], ) -> None: claim_ids = [claim.id for claim, _plan in plans] transactions_by_claim = self._transactions_by_claim_id(claim_ids) reservations_by_claim = self._reservations_by_claim_id(claim_ids) observations_by_claim = self._observations_by_claim_id(claim_ids) for claim, plan in plans: claim.claim_no = f"SIM-TEMP-{claim.id}" self.db.flush() for claim, plan in plans: claim_no = str(plan["claim_no"]) occurred_at = plan["occurred_at"] submitted_at = plan["submitted_at"] updated_at = plan["updated_at"] allocation_id = allocation_map.get( ( claim.department_id, str(plan["period_key"]), str(plan["subject_code"]), ) ) claim.claim_no = claim_no claim.occurred_at = occurred_at claim.submitted_at = submitted_at claim.created_at = occurred_at claim.updated_at = updated_at claim.reason = self._normalized_reason(claim.reason, occurred_at.date()) self.db.execute( text( """ update expense_claim_items set item_date = :item_date, updated_at = :updated_at where claim_id = :claim_id """ ), { "item_date": occurred_at.date(), "updated_at": updated_at, "claim_id": claim.id, }, ) for transaction in transactions_by_claim.get(claim.id, []): transaction.source_no = claim_no transaction.created_at = submitted_at or occurred_at if allocation_id: transaction.allocation_id = allocation_id for reservation in reservations_by_claim.get(claim.id, []): reservation.source_no = claim_no reservation.created_at = submitted_at or occurred_at reservation.updated_at = updated_at if allocation_id: reservation.allocation_id = allocation_id for observation in observations_by_claim.get(claim.id, []): observation.subject_key = claim_no observation.subject_label = claim_no observation.claim_no = claim_no observation.created_at = submitted_at or occurred_at observation.updated_at = updated_at def _allocation_map(self) -> dict[tuple[str | None, str, str], str]: rows = self.db.scalars( select(BudgetAllocation).where(BudgetAllocation.project_code == SIM_PROJECT_CODE) ).all() return { (row.department_id, row.period_key, row.subject_code): row.id for row in rows } def _count_missing_allocations( self, plans: list[tuple[ExpenseClaim, dict[str, object]]], allocation_map: dict[tuple[str | None, str, str], str], ) -> int: missing = { (claim.department_id, str(plan["period_key"]), str(plan["subject_code"])) for claim, plan in plans if self._status(claim) in BUDGETED_STATUSES and (claim.department_id, str(plan["period_key"]), str(plan["subject_code"])) not in allocation_map } return len(missing) def _rebalance_allocation_amounts(self) -> None: allocations = list( self.db.scalars( select(BudgetAllocation) .where(BudgetAllocation.budget_no.like(f"{SIM_BUDGET_PREFIX}%")) .order_by(BudgetAllocation.period_key.asc(), BudgetAllocation.subject_code.asc()) ).all() ) transactions = list( self.db.scalars( select(BudgetTransaction).where( BudgetTransaction.transaction_no.like(f"{SIM_TRANSACTION_PREFIX}%") ) ).all() ) used_by_allocation: dict[str, Decimal] = {} for transaction in transactions: used_by_allocation[transaction.allocation_id] = ( used_by_allocation.get(transaction.allocation_id, Decimal("0.00")) + Decimal(transaction.amount or 0) ) for index, allocation in enumerate(allocations): used = used_by_allocation.get(allocation.id, Decimal("0.00")) usage = target_budget_usage(allocation.period_key, allocation.subject_code, index) allocation.original_amount = max( (used / usage).quantize(Decimal("0.01")) if usage > 0 else used, Decimal("3000.00"), ) allocation.updated_by = "simulation_rebalance" allocation.updated_at = datetime.now(UTC) def _transactions_by_claim_id(self, claim_ids: list[str]) -> dict[str, list[BudgetTransaction]]: rows = self.db.scalars( select(BudgetTransaction) .where(BudgetTransaction.transaction_no.like(f"{SIM_TRANSACTION_PREFIX}%")) .where(BudgetTransaction.source_id.in_(claim_ids)) ).all() return self._group_by_source_id(rows) def _reservations_by_claim_id(self, claim_ids: list[str]) -> dict[str, list[BudgetReservation]]: rows = self.db.scalars( select(BudgetReservation) .where(BudgetReservation.reservation_no.like(f"{SIM_RESERVATION_PREFIX}%")) .where(BudgetReservation.source_id.in_(claim_ids)) ).all() return self._group_by_source_id(rows) def _observations_by_claim_id(self, claim_ids: list[str]) -> dict[str, list[RiskObservation]]: rows = self.db.scalars( select(RiskObservation) .where(RiskObservation.observation_key.like(f"{SIM_RISK_PREFIX}%")) .where(RiskObservation.claim_id.in_(claim_ids)) ).all() grouped: dict[str, list[RiskObservation]] = {} for row in rows: if row.claim_id: grouped.setdefault(row.claim_id, []).append(row) return grouped @staticmethod def _group_by_source_id(rows: object) -> dict[str, list[object]]: grouped: dict[str, list[object]] = {} for row in rows: grouped.setdefault(row.source_id, []).append(row) return grouped def _recent_count(self, total: int) -> int: if total <= 0: return 0 return min(24, max(12, total // 50)) @staticmethod def _date_range(start: date, end: date) -> list[date]: days = max((end - start).days, 0) return [start + timedelta(days=index) for index in range(days + 1)] @staticmethod def _spread_day(index: int, count: int, days: list[date]) -> date: if not days: raise ValueError("days cannot be empty") if count <= 1: return days[0] day_index = round(index * (len(days) - 1) / (count - 1)) jitter = ((index * 17) % 5) - 2 return days[max(0, min(len(days) - 1, day_index + jitter))] @staticmethod def _updated_at( claim: ExpenseClaim, occurred_at: datetime, submitted_at: datetime | None, index: int, ) -> datetime: base = submitted_at or occurred_at status = HalfYearExpenseSimulationRebalancer._status(claim) if status == "paid": return base + timedelta(days=2 + (index % 3), hours=index % 5) if status in {"approved", "pending_payment"}: return base + timedelta(days=1 + (index % 2), hours=index % 4) if status in {"returned", "rejected"}: return base + timedelta(hours=6 + (index % 8)) return base + timedelta(hours=2 + (index % 4)) @staticmethod def _normalized_reason(reason: str, day: date) -> str: text = str(reason or "").strip() for month in range(1, 7): text = text.replace(f"{month}月", f"{day.month}月") return text @staticmethod def _status(claim: ExpenseClaim) -> str: return str(claim.status or "").strip().lower() def _sim_transaction_count(self) -> int: return int( self.db.scalar( select(func.count()).select_from(BudgetTransaction).where( BudgetTransaction.transaction_no.like(f"{SIM_TRANSACTION_PREFIX}%") ) ) or 0 ) def _sim_reservation_count(self) -> int: return int( self.db.scalar( select(func.count()).select_from(BudgetReservation).where( BudgetReservation.reservation_no.like(f"{SIM_RESERVATION_PREFIX}%") ) ) or 0 ) def _sim_risk_count(self) -> int: return int( self.db.scalar( select(func.count()).select_from(RiskObservation).where( RiskObservation.observation_key.like(f"{SIM_RISK_PREFIX}%") ) ) or 0 )