# Integration tests for the half-year demo-company expense simulation.
#
# Each test builds an isolated in-memory SQLite database, seeds a tiny company
# (two departments, three employees), then drives the simulation seeder and/or
# rebalancer and checks the rows they write.
from __future__ import annotations

from datetime import UTC, date, datetime
from typing import TYPE_CHECKING

from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.db.base import Base
from app.models.budget import BudgetAllocation, BudgetReservation, BudgetTransaction
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.models.organization import OrganizationUnit
from app.models.risk_observation import RiskObservation
from app.services.budget import BudgetService
from app.services.demo_company_simulation_catalog import SIM_PROJECT_CODE
from app.services.demo_company_simulation_rebalance import HalfYearExpenseSimulationRebalancer
from app.services.demo_company_simulation_seed import (
    SIM_EMPLOYEE_PREFIX,
    HalfYearExpenseSimulationSeeder,
    SimulationConfig,
)

if TYPE_CHECKING:
    from sqlalchemy import ColumnElement

# Every table the simulation seeder is expected to populate; used to prove that a
# repeated apply() inserts nothing.
_SIMULATED_MODELS = (
    Employee,
    ExpenseClaim,
    ExpenseClaimItem,
    BudgetAllocation,
    BudgetTransaction,
    BudgetReservation,
    RiskObservation,
)


def build_session() -> Session:
    """Return a session on a fresh in-memory SQLite database with all tables created.

    ``StaticPool`` hands out the same single connection on every checkout, so the
    in-memory database persists across the commits a test performs.
    ``check_same_thread=False`` lets that connection be used from any thread.

    NOTE(review): the engine is never disposed explicitly; it is only released
    when garbage-collected.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
    return session_factory()


def seed_company(db: Session) -> None:
    """Insert two departments and three pre-existing employees, then commit.

    Employees alternate between the tech department (even index) and the market
    department (odd index) and take that department's cost center.
    """
    tech = OrganizationUnit(
        id="dept-tech",
        unit_code="TECH-DEPT",
        name="技术部",
        unit_type="department",
        cost_center="CC-6100",
        location="北京",
    )
    market = OrganizationUnit(
        id="dept-market",
        unit_code="MARKET-DEPT",
        name="市场部",
        unit_type="department",
        cost_center="CC-4100",
        location="上海",
    )
    db.add_all([tech, market])
    for index in range(3):
        unit = tech if index % 2 == 0 else market
        db.add(
            Employee(
                id=f"emp-existing-{index}",
                employee_no=f"E-EXISTING-{index}",
                name=f"现有员工{index}",
                email=f"existing-{index}@xf.com",
                grade="P5",
                position="主管",
                organization_unit=unit,
                cost_center=unit.cost_center,
            )
        )
    db.commit()


def _count(db: Session, model: type[Base], *criteria: ColumnElement[bool]) -> int:
    """Return how many ``model`` rows match all ``criteria`` (every row if none given)."""
    stmt = select(func.count()).select_from(model)
    if criteria:
        stmt = stmt.where(*criteria)
    return db.scalar(stmt) or 0


def test_half_year_simulation_preview_and_apply_are_idempotent() -> None:
    with build_session() as db:
        seed_company(db)
        config = SimulationConfig(target_employees=8, start_date=date(2026, 1, 1), months=6, seed=7)

        preview = HalfYearExpenseSimulationSeeder(db, config).preview()
        assert preview.mode == "dry-run"
        assert preview.current_employee_count == 3
        assert preview.employees_to_create == 5
        assert preview.claims_to_create >= 24
        assert preview.budget_allocations_to_create > 0
        assert preview.budget_transactions_to_create > 0
        # A dry run must leave the database untouched.
        assert _count(db, Employee) == 3
        assert _count(db, ExpenseClaim) == 0

        applied = HalfYearExpenseSimulationSeeder(db, config).apply()
        db.commit()
        assert applied.mode == "apply"
        assert applied.employees_to_create == 5
        assert _count(db, Employee) == 8
        # Every counter the seeder reports must match the rows it actually wrote
        # (these tables were empty before the first apply).
        reported_counts = {
            ExpenseClaim: applied.claims_to_create,
            ExpenseClaimItem: applied.claim_items_to_create,
            BudgetAllocation: applied.budget_allocations_to_create,
            BudgetTransaction: applied.budget_transactions_to_create,
            BudgetReservation: applied.budget_reservations_to_create,
            RiskObservation: applied.risk_observations_to_create,
        }
        for model, expected in reported_counts.items():
            assert _count(db, model) == expected, model.__name__

        counts_after_first_apply = {model: _count(db, model) for model in _SIMULATED_MODELS}

        repeated = HalfYearExpenseSimulationSeeder(db, config).apply()
        db.commit()
        assert repeated.employees_to_create == 0
        assert repeated.claims_to_create == 0
        assert repeated.claim_items_to_create == 0
        assert repeated.budget_allocations_to_create == 0
        assert repeated.budget_transactions_to_create == 0
        assert repeated.budget_reservations_to_create == 0
        assert repeated.risk_observations_to_create == 0
        # Idempotency means the second apply inserts no rows, not merely reports zero.
        assert {model: _count(db, model) for model in _SIMULATED_MODELS} == counts_after_first_apply


def test_half_year_simulation_feeds_budget_summary() -> None:
    with build_session() as db:
        seed_company(db)
        config = SimulationConfig(
            target_employees=10,
            start_date=date(2026, 1, 1),
            months=6,
            seed=11,
        )
        HalfYearExpenseSimulationSeeder(db, config).apply()
        db.commit()

        summary = BudgetService(db).get_summary(fiscal_year=2026, period_key="2026Q2")
        sim_claim_count = _count(db, ExpenseClaim, ExpenseClaim.project_code == SIM_PROJECT_CODE)
        sim_employee_count = _count(
            db, Employee, Employee.employee_no.like(f"{SIM_EMPLOYEE_PREFIX}%")
        )

        assert sim_claim_count >= 30
        # Target of 10 minus the 3 employees seed_company() already created.
        assert sim_employee_count == 7
        assert summary.trend
        assert {item.period_key for item in summary.trend} == {"2026Q1", "2026Q2"}
        assert summary.warning_count + summary.over_budget_count > 0


def test_half_year_simulation_excludes_admin_and_visible_month_has_real_volume() -> None:
    with build_session() as db:
        seed_company(db)
        # An admin account with no department; the simulation must never file claims for it.
        db.add(
            Employee(
                id="emp-admin",
                employee_no="ADMIN",
                name="admin",
                email="admin@xf.com",
                grade="P8",
                position="admin",
            )
        )
        db.commit()
        config = SimulationConfig(
            target_employees=100,
            start_date=date(2026, 1, 1),
            months=6,
            seed=20260602,
        )
        HalfYearExpenseSimulationSeeder(db, config).apply()
        db.commit()

        in_sim_project = ExpenseClaim.project_code == SIM_PROJECT_CODE
        # Half-open window [2026-06-01, 2026-06-03) checked for "visible" volume.
        window_start = datetime(2026, 6, 1, tzinfo=UTC)
        window_end = datetime(2026, 6, 3, tzinfo=UTC)
        in_window = (
            ExpenseClaim.occurred_at >= window_start,
            ExpenseClaim.occurred_at < window_end,
        )

        admin_claim_count = _count(db, ExpenseClaim, ExpenseClaim.employee_name == "admin")
        visible_claim_count = _count(db, ExpenseClaim, in_sim_project, *in_window)
        total_claim_count = _count(db, ExpenseClaim, in_sim_project)
        daily_counts = db.scalars(
            select(func.count())
            .select_from(ExpenseClaim)
            .where(in_sim_project, *in_window)
            .group_by(func.date(ExpenseClaim.occurred_at))
        ).all()
        max_daily_count = max(daily_counts, default=0)
        earliest_claim_day, latest_claim_day = db.execute(
            select(
                func.min(ExpenseClaim.occurred_at),
                func.max(ExpenseClaim.occurred_at),
            ).where(in_sim_project)
        ).one()

        assert admin_claim_count == 0
        assert 400 <= total_claim_count <= 500
        assert 12 <= visible_claim_count <= 30
        assert max_daily_count <= 16
        assert earliest_claim_day is not None
        assert latest_claim_day is not None
        assert earliest_claim_day.date() >= date(2026, 1, 1)
        assert latest_claim_day.date() <= date(2026, 6, 2)


def test_half_year_simulation_rebalance_spreads_existing_rows_without_deleting() -> None:
    with build_session() as db:
        seed_company(db)
        config = SimulationConfig(
            target_employees=100,
            start_date=date(2026, 1, 1),
            months=6,
            seed=20260602,
        )
        HalfYearExpenseSimulationSeeder(db, config).apply()
        db.commit()

        in_sim_project = ExpenseClaim.project_code == SIM_PROJECT_CODE

        # Pile every simulated claim onto a single day so the rebalancer has work to do.
        collapsed_occurred_at = datetime(2026, 6, 1, 10, tzinfo=UTC)
        collapsed_submitted_at = datetime(2026, 6, 1, 11, tzinfo=UTC)
        claims = db.scalars(
            select(ExpenseClaim).where(in_sim_project).order_by(ExpenseClaim.claim_no.asc())
        ).all()
        for claim in claims:
            claim.occurred_at = collapsed_occurred_at
            claim.submitted_at = collapsed_submitted_at
            claim.created_at = claim.occurred_at
            claim.updated_at = claim.submitted_at
            for item in claim.items:
                item.item_date = date(2026, 6, 1)
        db.commit()

        before_count = _count(db, ExpenseClaim, in_sim_project)
        preview = HalfYearExpenseSimulationRebalancer(db).preview()
        applied = HalfYearExpenseSimulationRebalancer(db).apply()
        db.commit()
        after_count = _count(db, ExpenseClaim, in_sim_project)

        daily_counts = db.scalars(
            select(func.count())
            .select_from(ExpenseClaim)
            .where(in_sim_project)
            .group_by(func.date(ExpenseClaim.occurred_at))
        ).all()
        month_keys = {
            (claim.occurred_at.year, claim.occurred_at.month)
            for claim in db.scalars(select(ExpenseClaim).where(in_sim_project)).all()
        }
        sample_claim = db.scalar(
            select(ExpenseClaim)
            .where(in_sim_project)
            .where(ExpenseClaim.status != "draft")
            .order_by(ExpenseClaim.claim_no.asc())
            .limit(1)
        )
        # Fail with a clear message rather than an AttributeError on `.id` below.
        assert sample_claim is not None
        sample_transaction = db.scalar(
            select(BudgetTransaction)
            .where(BudgetTransaction.source_id == sample_claim.id)
            .limit(1)
        )
        sample_observation = db.scalar(
            select(RiskObservation)
            .where(RiskObservation.claim_id == sample_claim.id)
            .limit(1)
        )

        # Guard against a vacuous pass: the seeder must have produced claims to rebalance.
        assert before_count > 0
        assert before_count == after_count
        assert preview.claims == applied.claims == after_count
        assert applied.recent_claims <= 24
        assert daily_counts
        assert max(daily_counts) <= 16
        assert {(2026, month) for month in range(1, 7)}.issubset(month_keys)
        # Dependent rows must follow their claim's new submission date.
        if sample_transaction is not None:
            assert sample_transaction.source_no == sample_claim.claim_no
            assert sample_transaction.created_at.date() == sample_claim.submitted_at.date()
        if sample_observation is not None:
            assert sample_observation.claim_no == sample_claim.claim_no
            assert sample_observation.created_at.date() == sample_claim.submitted_at.date()