from __future__ import annotations import uuid from datetime import UTC, datetime from decimal import Decimal from typing import Any from sqlalchemy import select from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransaction from app.models.financial_record import ExpenseClaim from app.models.organization import OrganizationUnit from app.schemas.budget import BudgetAllocationRead, BudgetTransactionRead from app.services.budget_types import ( BUDGET_SUBJECT_LABELS, BudgetBalance, DEFAULT_SUBJECT_AMOUNTS, SUBJECT_CODE_ALIASES, SUPPORTED_BUDGET_SUBJECT_CODES, ) from app.services.expense_claim_constants import EXPENSE_TYPE_LABELS from app.services.expense_type_keywords import resolve_expense_type_code_from_text class BudgetSupportMixin: def serialize_allocation(self, allocation: BudgetAllocation) -> BudgetAllocationRead: return BudgetAllocationRead( id=allocation.id, budget_no=allocation.budget_no, fiscal_year=allocation.fiscal_year, period_type=allocation.period_type, period_key=allocation.period_key, department_id=allocation.department_id, department_name=allocation.department_name, cost_center=allocation.cost_center, project_code=allocation.project_code, subject_code=allocation.subject_code, subject_name=allocation.subject_name, original_amount=self._money(allocation.original_amount), adjusted_amount=self._money(allocation.adjusted_amount), status=allocation.status, warning_threshold=self._percent(allocation.warning_threshold), control_action=allocation.control_action, description=allocation.description, balance=self.get_balance(allocation).to_read(), created_at=allocation.created_at, updated_at=allocation.updated_at, ) def get_balance(self, allocation: BudgetAllocation) -> BudgetBalance: reservations = self.db.scalars( select(BudgetReservation).where( BudgetReservation.allocation_id == allocation.id, BudgetReservation.source_status == "active", ) ).all() transactions = self.db.scalars( select(BudgetTransaction).where(BudgetTransaction.allocation_id == allocation.id) ).all() reserved_amount = sum((self._money(item.amount) for item in reservations), Decimal("0.00")) consumed_amount = Decimal("0.00") for transaction in transactions: transaction_type = str(transaction.transaction_type or "").strip().lower() amount = self._money(transaction.amount) if transaction_type == "consume": consumed_amount += amount elif transaction_type == "rollback": consumed_amount -= amount total_amount = self._money(allocation.original_amount) + self._money(allocation.adjusted_amount) available_amount = total_amount - reserved_amount - consumed_amount usage_amount = reserved_amount + consumed_amount usage_rate = Decimal("0.00") if total_amount > Decimal("0.00"): usage_rate = ((usage_amount / total_amount) * Decimal("100")).quantize(Decimal("0.01")) return BudgetBalance( total_amount=total_amount, reserved_amount=reserved_amount, consumed_amount=consumed_amount, available_amount=available_amount, usage_rate=usage_rate, ) def list_transactions(self, allocation_id: str) -> list[BudgetTransactionRead]: self.ensure_budget_ready() rows = self.db.scalars( select(BudgetTransaction) .where(BudgetTransaction.allocation_id == allocation_id) .order_by(BudgetTransaction.created_at.desc()) ).all() return [BudgetTransactionRead.model_validate(row) for row in rows] def get_allocation_row(self, allocation_id: str) -> BudgetAllocation | None: self.ensure_budget_ready() return self.db.get(BudgetAllocation, allocation_id) def _review_allocation_amount( self, allocation: BudgetAllocation, amount: Decimal, ) -> dict[str, list[Any]]: balance = self.get_balance(allocation) flags: list[dict[str, Any]] = [] blocking_reasons: list[str] = [] if str(allocation.status or "").strip().lower() == "frozen": message = f"预算 {allocation.budget_no} 已冻结,不能继续占用。" flags.append( self._build_operation_flag( allocation, event_type="budget_frozen", label="预算已冻结", message=message, severity="high", amount=amount, ) ) blocking_reasons.append(message) return {"flags": flags, "blocking_reasons": blocking_reasons} if amount > balance.available_amount: over_amount = amount - balance.available_amount message = ( f"预算 {allocation.budget_no} 可用余额 {balance.available_amount} 元," f"当前单据金额 {amount} 元,超出 {over_amount} 元。" ) flags.append( self._build_operation_flag( allocation, event_type="budget_insufficient", label="预算余额不足", message=message, severity="high", amount=amount, extra={"available_amount": str(balance.available_amount), "over_budget_amount": str(over_amount)}, ) ) blocking_reasons.append(message) return {"flags": flags, "blocking_reasons": blocking_reasons} after_usage = balance.reserved_amount + balance.consumed_amount + amount usage_rate = Decimal("0.00") if balance.total_amount > Decimal("0.00"): usage_rate = ((after_usage / balance.total_amount) * Decimal("100")).quantize(Decimal("0.01")) if usage_rate >= self._percent(allocation.warning_threshold): flags.append( self._build_operation_flag( allocation, event_type="budget_warning", label="预算接近预警线", message=( f"预算 {allocation.budget_no} 本次占用后使用率预计达到 {usage_rate}%," f"已达到预警线 {allocation.warning_threshold}%。" ), severity="medium", amount=amount, extra={"usage_rate": str(usage_rate)}, ) ) return {"flags": flags, "blocking_reasons": blocking_reasons} def build_claim_budget_context(self, claim: ExpenseClaim) -> dict[str, Any]: self.ensure_budget_ready() amount = self._money(claim.amount or Decimal("0.00")) fiscal_year, period_key = self._period_from_claim(claim) subject_code = self._subject_code_from_claim(claim) if not self._is_supported_budget_subject(subject_code): return { "matched": False, "budget_applicable": False, "skip_reason": "demo_budget_subject_not_enabled", "claim_amount": str(amount), "fiscal_year": fiscal_year, "period_key": period_key, "subject_code": subject_code, "department_id": claim.department_id, "department_name": claim.department_name, "cost_center": self._resolve_claim_cost_center(claim), } allocation = self._find_allocation_for_claim(claim) if allocation is None: return { "matched": False, "budget_applicable": True, "claim_amount": str(amount), "fiscal_year": fiscal_year, "period_key": period_key, "subject_code": subject_code, "department_id": claim.department_id, "department_name": claim.department_name, "cost_center": self._resolve_claim_cost_center(claim), } balance = self.get_balance(allocation) over_budget_amount = max(amount - balance.available_amount, Decimal("0.00")) return { "matched": True, "budget_applicable": True, "allocation_id": allocation.id, "budget_no": allocation.budget_no, "claim_amount": str(amount), "total_amount": str(balance.total_amount), "reserved_amount": str(balance.reserved_amount), "consumed_amount": str(balance.consumed_amount), "available_amount": str(balance.available_amount), "usage_rate": str(balance.usage_rate), "over_budget_amount": str(over_budget_amount), "warning_threshold": str(allocation.warning_threshold), "control_action": allocation.control_action, "fiscal_year": allocation.fiscal_year, "period_key": allocation.period_key, "subject_code": allocation.subject_code, "subject_name": allocation.subject_name, "department_id": allocation.department_id, "department_name": allocation.department_name, "cost_center": allocation.cost_center, "project_code": allocation.project_code, } def _find_allocation_for_claim(self, claim: ExpenseClaim) -> BudgetAllocation | None: fiscal_year, period_key = self._period_from_claim(claim) return self._find_allocation_for_dimension( fiscal_year=fiscal_year, period_key=period_key, department_id=claim.department_id, department_name=claim.department_name, cost_center=self._resolve_claim_cost_center(claim), project_code=claim.project_code, subject_code=self._subject_code_from_claim(claim), ) def _find_allocation_for_dimension( self, *, fiscal_year: int | None, period_key: str | None, department_id: str | None, department_name: str | None, cost_center: str | None, project_code: str | None, subject_code: str, ) -> BudgetAllocation | None: now = datetime.now(UTC) year = fiscal_year or now.year key = self._normalize_period_key(year, period_key or self._quarter_key(year, now.month)) normalized_subject = self._normalize_subject_code(subject_code) candidates = list( self.db.scalars( select(BudgetAllocation) .where(BudgetAllocation.fiscal_year == year) .where(BudgetAllocation.period_key == key) .where(BudgetAllocation.subject_code == normalized_subject) .where(BudgetAllocation.status.in_(["active", "published"])) .order_by(BudgetAllocation.project_code.desc().nullslast()) ).all() ) if not candidates: return None normalized_department_id = self._blank_to_none(department_id) normalized_department_name = str(department_name or "").strip() normalized_cost_center = self._blank_to_none(cost_center) normalized_project_code = self._blank_to_none(project_code) for item in candidates: if normalized_project_code and item.project_code and item.project_code != normalized_project_code: continue if normalized_department_id and item.department_id == normalized_department_id: return item if normalized_cost_center and item.cost_center == normalized_cost_center: return item if normalized_department_name and item.department_name == normalized_department_name: return item return None def _find_exact_allocation( self, *, fiscal_year: int, period_key: str, department_id: str | None, department_name: str, cost_center: str | None, project_code: str | None, subject_code: str, ) -> BudgetAllocation | None: rows = self.db.scalars( select(BudgetAllocation) .where(BudgetAllocation.fiscal_year == fiscal_year) .where(BudgetAllocation.period_key == period_key) .where(BudgetAllocation.subject_code == subject_code) ).all() normalized_department_id = self._blank_to_none(department_id) normalized_department_name = department_name.strip() normalized_cost_center = self._blank_to_none(cost_center) normalized_project_code = self._blank_to_none(project_code) for row in rows: if row.project_code != normalized_project_code: continue if normalized_department_id and row.department_id == normalized_department_id: return row if normalized_cost_center and row.cost_center == normalized_cost_center: return row if row.department_name == normalized_department_name: return row return None def _find_active_reservation(self, *, source_type: str, source_id: str) -> BudgetReservation | None: return self.db.scalar( select(BudgetReservation) .where(BudgetReservation.source_type == source_type) .where(BudgetReservation.source_id == source_id) .where(BudgetReservation.source_status == "active") .order_by(BudgetReservation.created_at.desc()) .limit(1) ) def _find_active_reservations(self, *, source_type: str, source_id: str) -> list[BudgetReservation]: return list( self.db.scalars( select(BudgetReservation) .where(BudgetReservation.source_type == source_type) .where(BudgetReservation.source_id == source_id) .where(BudgetReservation.source_status == "active") ).all() ) def _seed_default_allocations(self) -> None: units = list( self.db.scalars( select(OrganizationUnit).where(OrganizationUnit.unit_type == "department") ).all() ) if not units: return year = datetime.now(UTC).year for unit in units: for quarter in range(1, 5): period_key = f"{year}Q{quarter}" for subject_code, amount in DEFAULT_SUBJECT_AMOUNTS.items(): allocation = BudgetAllocation( budget_no=self._make_no("BUD"), fiscal_year=year, period_type="quarter", period_key=period_key, department_id=unit.id, department_name=unit.name, cost_center=unit.cost_center, project_code=None, subject_code=subject_code, subject_name=self._subject_label(subject_code), original_amount=amount, adjusted_amount=Decimal("0.00"), status="active", warning_threshold=Decimal("80.00"), control_action="block", description="系统初始化预算池额度", created_by="system", updated_by="system", ) self.db.add(allocation) self.db.flush() self._record_transaction( allocation=allocation, transaction_type="init", amount=amount, before_available=Decimal("0.00"), after_available=amount, source_type="budget_seed", source_id=allocation.id, source_no=allocation.budget_no, operator="system", reason="系统初始化预算池额度", ) self.db.flush() def _create_fallback_allocation_for_claim(self, claim: ExpenseClaim) -> BudgetAllocation: fiscal_year, period_key = self._period_from_claim(claim) subject_code = self._subject_code_from_claim(claim) allocation = BudgetAllocation( budget_no=self._make_no("BUD"), fiscal_year=fiscal_year, period_type="quarter", period_key=period_key, department_id=claim.department_id, department_name=str(claim.department_name or "未归属部门").strip() or "未归属部门", cost_center=self._resolve_claim_cost_center(claim), project_code=claim.project_code, subject_code=subject_code, subject_name=self._subject_label(subject_code), original_amount=DEFAULT_SUBJECT_AMOUNTS.get(subject_code, Decimal("100000.00")), adjusted_amount=Decimal("0.00"), status="active", warning_threshold=Decimal("80.00"), control_action="block", description="测试或演示环境自动补齐预算池额度", created_by="system", updated_by="system", ) self.db.add(allocation) self.db.flush() self._record_transaction( allocation=allocation, transaction_type="init", amount=allocation.original_amount, before_available=Decimal("0.00"), after_available=allocation.original_amount, source_type="budget_seed", source_id=allocation.id, source_no=allocation.budget_no, operator="system", reason="自动补齐预算池额度", ) self.db.flush() return allocation def _budget_table_empty(self) -> bool: return self.db.scalar(select(BudgetAllocation.id).limit(1)) is None def _record_transaction( self, *, allocation: BudgetAllocation, transaction_type: str, amount: Decimal, before_available: Decimal, after_available: Decimal, source_type: str, source_id: str, source_no: str, operator: str | None, reason: str | None, reservation: BudgetReservation | None = None, context_json: dict[str, Any] | None = None, ) -> BudgetTransaction: transaction = BudgetTransaction( transaction_no=self._make_no("BTX"), allocation_id=allocation.id, reservation_id=reservation.id if reservation is not None else None, source_type=source_type, source_id=source_id, source_no=source_no, transaction_type=transaction_type, amount=self._money(amount), before_available_amount=self._money(before_available), after_available_amount=self._money(after_available), operator=operator, reason=reason, context_json=context_json or {}, ) self.db.add(transaction) return transaction @staticmethod def _build_budget_flag( *, event_type: str, severity: str, label: str, message: str, amount: Decimal, extra: dict[str, Any] | None = None, ) -> dict[str, Any]: payload = { "source": "budget_control", "event_type": event_type, "severity": severity, "label": label, "message": message, "amount": str(amount), "created_at": datetime.now(UTC).isoformat(), } payload.update(extra or {}) return payload def _build_operation_flag( self, allocation: BudgetAllocation, *, event_type: str, label: str, message: str, amount: Decimal, severity: str = "info", reservation_id: str | None = None, transaction_id: str | None = None, extra: dict[str, Any] | None = None, ) -> dict[str, Any]: balance = self.get_balance(allocation) payload = self._build_budget_flag( event_type=event_type, severity=severity, label=label, message=message, amount=amount, extra={ "allocation_id": allocation.id, "budget_no": allocation.budget_no, "subject_code": allocation.subject_code, "subject_name": allocation.subject_name, "available_amount": str(balance.available_amount), "reserved_amount": str(balance.reserved_amount), "consumed_amount": str(balance.consumed_amount), **(extra or {}), }, ) if reservation_id: payload["reservation_id"] = reservation_id if transaction_id: payload["transaction_id"] = transaction_id return payload @staticmethod def _money(value: Any) -> Decimal: return Decimal(str(value or "0")).quantize(Decimal("0.01")) @staticmethod def _percent(value: Any) -> Decimal: return Decimal(str(value or "0")).quantize(Decimal("0.01")) @staticmethod def _blank_to_none(value: str | None) -> str | None: text = str(value or "").strip() return text or None @staticmethod def _make_no(prefix: str) -> str: return f"{prefix}-{datetime.now(UTC).strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8].upper()}" @staticmethod def _normalize_period_type(value: str | None) -> str: text = str(value or "").strip().lower() return text if text in {"month", "quarter", "year"} else "quarter" @staticmethod def _normalize_period_key(year: int, value: str | None) -> str: text = str(value or "").strip().upper().replace("年", "").replace("第", "").replace("季度", "") if text.startswith(str(year)) and "Q" in text: return text if text in {"Q1", "Q2", "Q3", "Q4"}: return f"{year}{text}" return text or f"{year}Q1" @staticmethod def _quarter_key(year: int, month: int) -> str: quarter = ((max(1, min(month, 12)) - 1) // 3) + 1 return f"{year}Q{quarter}" def _period_from_claim(self, claim: ExpenseClaim) -> tuple[int, str]: occurred_at = claim.occurred_at or claim.submitted_at or datetime.now(UTC) return occurred_at.year, self._quarter_key(occurred_at.year, occurred_at.month) def _subject_code_from_claim(self, claim: ExpenseClaim) -> str: expense_type = str(claim.expense_type or "").strip().lower() if expense_type.endswith("_application"): expense_type = expense_type.removesuffix("_application") expense_type = SUBJECT_CODE_ALIASES.get(expense_type, expense_type) if expense_type in DEFAULT_SUBJECT_AMOUNTS or expense_type in EXPENSE_TYPE_LABELS: return expense_type resolved = resolve_expense_type_code_from_text(expense_type) if resolved: return SUBJECT_CODE_ALIASES.get(resolved, resolved) return resolved or expense_type or "other" @staticmethod def _normalize_subject_code(value: str | None) -> str: text = str(value or "").strip().lower() if text.endswith("_application"): text = text.removesuffix("_application") text = SUBJECT_CODE_ALIASES.get(text, text) resolved = resolve_expense_type_code_from_text(text) if resolved: return SUBJECT_CODE_ALIASES.get(resolved, resolved) return text or "other" @staticmethod def _is_supported_budget_subject(subject_code: str | None) -> bool: return str(subject_code or "").strip().lower() in SUPPORTED_BUDGET_SUBJECT_CODES def _claim_uses_budget_control(self, claim: ExpenseClaim) -> bool: return self._is_supported_budget_subject(self._subject_code_from_claim(claim)) @staticmethod def _subject_label(code: str) -> str: return BUDGET_SUBJECT_LABELS.get(code, EXPENSE_TYPE_LABELS.get(code, code)) @staticmethod def _normalize_control_action(value: str | None) -> str: text = str(value or "").strip().lower() if text in {"block", "control", "管控", "强控"}: return "block" if text in {"warn", "warning", "提醒", "预警"}: return "warn" if text in {"allow", "normal", "正常", "放行"}: return "allow" return "block" def _resolve_claim_cost_center(self, claim: ExpenseClaim) -> str | None: employee = getattr(claim, "employee", None) if employee is not None: cost_center = self._blank_to_none(getattr(employee, "cost_center", None)) if cost_center: return cost_center organization_unit = getattr(employee, "organization_unit", None) if organization_unit is not None: cost_center = self._blank_to_none(getattr(organization_unit, "cost_center", None)) if cost_center: return cost_center return None def _claim_context(self, claim: ExpenseClaim) -> dict[str, Any]: fiscal_year, period_key = self._period_from_claim(claim) return { "claim_id": claim.id, "claim_no": claim.claim_no, "employee_id": claim.employee_id, "employee_name": claim.employee_name, "department_id": claim.department_id, "department_name": claim.department_name, "cost_center": self._resolve_claim_cost_center(claim), "project_code": claim.project_code, "expense_type": claim.expense_type, "subject_code": self._subject_code_from_claim(claim), "fiscal_year": fiscal_year, "period_key": period_key, }