Files
X-Financial/server/tests/test_hermes_employee_profile_baselines.py

145 lines
4.9 KiB
Python
Raw Normal View History

from __future__ import annotations
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.db.base import Base
from app.models.employee import Employee
from app.models.employee_behavior_profile import EmployeeBehaviorProfileSnapshot
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.models.organization import OrganizationUnit
from app.services.digital_employee_dashboard import DigitalEmployeeDashboardService
from app.services.employee_profile_scan_task import EmployeeProfileScanTaskService
from app.services.hermes_employee_profile_scanner import HermesEmployeeProfileScannerService
def test_hermes_employee_profile_scan_returns_profile_baseline_summary() -> None:
    """Scan the seeded claims and check the returned baseline summary.

    Verifies the target employee count, the minimum number of persisted
    snapshots, the per-dimension bucket counts, and that a supplier bucket
    keyed ``s-hotel`` is present.
    """
    make_session = _build_session_factory()
    with make_session() as session:
        _seed_scan_data(session)
        scanner = HermesEmployeeProfileScannerService(session)
        scan_result = scanner.scan_employee_profiles(log_id=None)

        assert scan_result["target_employee_count"] == 3
        persisted_snapshots = session.query(EmployeeBehaviorProfileSnapshot).count()
        assert persisted_snapshots >= 12

        baseline = scan_result["baseline_summary"]
        dimension_counts = baseline["dimension_counts"]
        # Seed data: 3 employees, 1 department, 2 suppliers, 2 expense types.
        expected_counts = {
            "employee": 3,
            "department": 1,
            "supplier": 2,
            "expense_type": 2,
        }
        for dimension, expected in expected_counts.items():
            assert dimension_counts[dimension] == expected

        hotel_bucket_found = any(
            entry["dimension"] == "supplier" and entry["key"] == "s-hotel"
            for entry in baseline["buckets"]
        )
        assert hotel_bucket_found
def test_employee_profile_scan_task_records_digital_employee_run() -> None:
    """Refresh profiles through the task service and inspect the dashboard.

    Checks the task result's type and summary counts, the number of snapshot
    rows in the database, and that the 7-day dashboard reports the snapshots
    and lists the profile scan as its first task-distribution entry.
    """
    make_session = _build_session_factory()
    with make_session() as session:
        _seed_scan_data(session)
        task_service = EmployeeProfileScanTaskService(session)
        outcome = task_service.refresh_profiles()
        scan_summary = outcome["summary"]

        assert outcome["task_type"] == "employee_behavior_profile_scan"
        assert scan_summary["target_employee_count"] == 3
        assert scan_summary["snapshot_count"] >= 12
        persisted_snapshots = session.query(EmployeeBehaviorProfileSnapshot).count()
        assert persisted_snapshots >= 12

        dashboard_service = DigitalEmployeeDashboardService(session)
        dashboard = dashboard_service.build_dashboard(days=7)
        assert dashboard.totals["profileSnapshots"] >= 12
        first_task = dashboard.task_distribution[0]
        assert first_task["taskType"] == "employee_behavior_profile_scan"
def _build_session_factory() -> sessionmaker[Session]:
    """Return a session factory bound to a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the factory
    is returned.
    """
    sqlite_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=sqlite_engine)
    factory = sessionmaker(bind=sqlite_engine, autoflush=False, autocommit=False)
    return factory
def _seed_scan_data(db: Session) -> None:
    """Insert one department, three employees and three expense claims.

    Two claims are ``travel`` expenses against supplier ``s-hotel`` and one is
    a ``meal`` expense against supplier ``s-meal``. The session is committed
    before returning.
    """
    department = OrganizationUnit(
        id="dept-sales",
        unit_code="SALES",
        name="市场部",
        unit_type="department",
    )
    staff = [
        Employee(
            id=f"emp-{number}",
            employee_no=f"E10{number}",
            name=f"员工{number}",
            email=f"emp{number}@example.com",
            position="客户经理",
            grade="P5",
            organization_unit=department,
        )
        for number in range(1, 4)
    ]
    db.add(department)
    db.add_all(staff)

    seeded_at = datetime.now(UTC)
    # (claim id, employee, expense type, amount, supplier id, supplier name)
    claim_specs = (
        ("c1", staff[0], "travel", "600", "s-hotel", "Hotel A"),
        ("c2", staff[1], "travel", "900", "s-hotel", "Hotel A"),
        ("c3", staff[2], "meal", "300", "s-meal", "Meal B"),
    )
    db.add_all([_claim(*spec, seeded_at) for spec in claim_specs])
    db.commit()
def _claim(
    claim_id: str,
    employee: Employee,
    expense_type: str,
    amount: str,
    supplier_id: str,
    supplier_name: str,
    now: datetime,
) -> ExpenseClaim:
    """Build a submitted expense claim with a single line item.

    Args:
        claim_id: Identifier used for the claim and embedded in the claim
            number, item id and invoice id.
        employee: Employee whose ``id`` and ``name`` are copied onto the claim.
        expense_type: Expense type for the claim and its line item.
        amount: Decimal string used as both the claim and the item amount.
        supplier_id: Supplier identifier stored in ``risk_flags_json``.
        supplier_name: Supplier name stored in ``risk_flags_json``.
        now: Reference time; the claim occurs and is submitted five days
            before it.

    Returns:
        An unsaved ``ExpenseClaim`` carrying one ``ExpenseClaimItem``.
    """
    total = Decimal(amount)
    five_days_earlier = now - timedelta(days=5)
    supplier_flag = {
        "supplier_id": supplier_id,
        "supplier_name": supplier_name,
    }
    # NOTE(review): item_date uses the local calendar date while the claim
    # timestamps are derived from ``now`` — confirm this mismatch is intended.
    line_item = ExpenseClaimItem(
        id=f"item-{claim_id}",
        claim_id=claim_id,
        item_date=date.today(),
        item_type=expense_type,
        item_reason="客户拜访",
        item_location="北京",
        item_amount=total,
        invoice_id=f"invoice-{claim_id}",
    )
    return ExpenseClaim(
        id=claim_id,
        claim_no=f"EXP-{claim_id}",
        employee_id=employee.id,
        employee_name=employee.name,
        department_id="dept-sales",
        department_name="市场部",
        project_code="PRJ-001",
        expense_type=expense_type,
        reason="客户拜访",
        location="北京",
        amount=total,
        currency="CNY",
        invoice_count=1,
        occurred_at=five_days_earlier,
        submitted_at=five_days_earlier,
        status="submitted",
        approval_stage="直属领导审批",
        risk_flags_json=[supplier_flag],
        items=[line_item],
    )