144 lines
5.2 KiB
Python
144 lines
5.2 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import os
|
||
|
|
import threading
|
||
|
|
from datetime import datetime, time, timedelta
|
||
|
|
from zoneinfo import ZoneInfo
|
||
|
|
|
||
|
|
from sqlalchemy import select
|
||
|
|
|
||
|
|
from app.core.agent_enums import AgentRunSource, AgentRunStatus
|
||
|
|
from app.core.logging import get_logger
|
||
|
|
from app.db.session import get_session_factory
|
||
|
|
from app.models.agent_run import AgentRun
|
||
|
|
from app.services.digital_employee_finance_report_task import (
|
||
|
|
FINANCE_REPORT_TASK_TYPE,
|
||
|
|
DigitalEmployeeFinanceReportTaskService,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Module-level logger shared by the scheduler defined in this file.
logger = get_logger("app.services.finance_report_scheduler")
|
||
|
|
|
||
|
|
|
||
|
|
class FinanceReportScheduler:
    """Background scheduler that generates finance reports once per day.

    A daemon thread sleeps until the configured wall-clock report time in the
    configured timezone, then generates every report type that is due and has
    not already succeeded on that calendar day.

    Configuration is read from the environment when an instance is created:

    * ``X_FINANCIAL_SCHEDULER_TZ`` -- IANA timezone name
      (default ``Asia/Shanghai``).
    * ``X_FINANCIAL_FINANCE_REPORT_TIME`` -- ``HH:MM`` wall-clock time
      (default ``08:30``).
    * ``X_FINANCIAL_FINANCE_REPORT_INITIAL_DELAY_SECONDS`` -- delay before the
      first scheduling pass (default ``36``, minimum ``1``).

    Invalid values are logged and replaced by the defaults rather than
    raising, because this module constructs an instance at import time.
    """

    _DEFAULT_TIMEZONE = "Asia/Shanghai"
    _DEFAULT_REPORT_TIME = "08:30"
    _DEFAULT_INITIAL_DELAY_SECONDS = 36
    _JOIN_TIMEOUT_SECONDS = 3

    def __init__(self) -> None:
        """Read the configuration from the environment; no thread is started."""
        timezone_name = str(
            os.environ.get("X_FINANCIAL_SCHEDULER_TZ") or self._DEFAULT_TIMEZONE
        ).strip()
        report_time = str(
            os.environ.get("X_FINANCIAL_FINANCE_REPORT_TIME") or self._DEFAULT_REPORT_TIME
        ).strip()
        raw_delay = os.environ.get("X_FINANCIAL_FINANCE_REPORT_INITIAL_DELAY_SECONDS")

        zone = self._load_timezone(timezone_name or self._DEFAULT_TIMEZONE)
        if zone is None:
            logger.warning(
                "Invalid X_FINANCIAL_SCHEDULER_TZ=%r, falling back to %s",
                timezone_name,
                self._DEFAULT_TIMEZONE,
            )
            # The default itself is allowed to raise: without any usable
            # timezone database the scheduler cannot work at all.
            zone = ZoneInfo(self._DEFAULT_TIMEZONE)

        initial_delay = self._parse_int(raw_delay)
        if initial_delay is None:
            if str(raw_delay or "").strip():
                logger.warning(
                    "Invalid X_FINANCIAL_FINANCE_REPORT_INITIAL_DELAY_SECONDS=%r, "
                    "falling back to %s",
                    raw_delay,
                    self._DEFAULT_INITIAL_DELAY_SECONDS,
                )
            initial_delay = self._DEFAULT_INITIAL_DELAY_SECONDS

        self._timezone = zone
        self._report_time = self._parse_time(report_time)
        self._initial_delay_seconds = max(1, initial_delay)
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        self._lock = threading.Lock()

    def start(self) -> None:
        """Start the scheduler thread; a no-op if it is already running."""
        with self._lock:
            if self._thread is not None and self._thread.is_alive():
                return
            self._stop_event.clear()
            self._thread = threading.Thread(
                target=self._run_loop,
                name="finance-report-scheduler",
                daemon=True,
            )
            self._thread.start()
        logger.info(
            "Finance report scheduler started timezone=%s report_time=%s",
            self._timezone.key,
            self._report_time.strftime("%H:%M"),
        )

    def shutdown(self) -> None:
        """Signal the scheduler thread to stop and wait briefly for it to exit."""
        with self._lock:
            thread = self._thread
            self._thread = None
            self._stop_event.set()
        if thread is not None and thread.is_alive():
            thread.join(timeout=self._JOIN_TIMEOUT_SECONDS)
            if thread.is_alive():
                # A report may still be generating; do not claim a clean stop.
                logger.warning(
                    "Finance report scheduler did not stop within %s seconds",
                    self._JOIN_TIMEOUT_SECONDS,
                )
                return
        logger.info("Finance report scheduler stopped")

    def _run_loop(self) -> None:
        """Thread body: wait for each report time, then run the due reports."""
        if self._stop_event.wait(self._initial_delay_seconds):
            return
        while not self._stop_event.is_set():
            if self._stop_event.wait(self._seconds_until_next_report_time()):
                break
            try:
                self._run_due_reports()
            except Exception:
                # Top-level boundary of the daemon thread: an unexpected error
                # must not end scheduling for the rest of the process lifetime.
                logger.exception("Finance report scheduler iteration failed")

    def _run_due_reports(self) -> None:
        """Generate every report type that is due at the current time.

        "weekly" is attempted on every run; "quarterly" is added during the
        first seven days of January, April, July and October; "annual" is
        added during the first seven days of January.

        NOTE(review): the duplicate check in ``_already_generated`` only
        covers the current calendar day, so a report type may be generated
        again on each day it is due -- confirm this is intended.
        """
        now = datetime.now(self._timezone)
        in_first_week = now.day <= 7
        due_types = ["weekly"]
        if in_first_week and now.month in {1, 4, 7, 10}:
            due_types.append("quarterly")
        if in_first_week and now.month == 1:
            due_types.append("annual")
        for report_type in due_types:
            self._run_report_once(report_type=report_type, now=now)

    def _run_report_once(self, *, report_type: str, now: datetime) -> None:
        """Generate one report in its own database session.

        The report is skipped when a successful run of the same type already
        exists for the day of ``now``. Failures are rolled back and logged;
        they are not re-raised.
        """
        db = get_session_factory()()
        try:
            if self._already_generated(db, report_type=report_type, now=now):
                return
            result = DigitalEmployeeFinanceReportTaskService(db).generate_report(
                report_type=report_type,  # type: ignore[arg-type]
                source=AgentRunSource.SCHEDULE.value,
            )
            db.commit()
            # Read the status defensively: the work is already committed, so
            # an unexpected result shape must not be reported as a failure.
            delivery = result.get("delivery") if isinstance(result, dict) else None
            status = delivery.get("status") if isinstance(delivery, dict) else None
            logger.info(
                "Finance report generated type=%s status=%s",
                report_type,
                status,
            )
        except Exception:
            db.rollback()
            logger.exception("Scheduled finance report failed type=%s", report_type)
        finally:
            db.close()

    def _already_generated(self, db, *, report_type: str, now: datetime) -> bool:
        """Return True if a successful run of ``report_type`` exists today.

        "Today" is the calendar day of ``now`` in the scheduler timezone,
        converted to a half-open UTC interval for the query.
        """
        day_start = datetime.combine(
            now.date(),
            time.min,
            tzinfo=self._timezone,
        ).astimezone(ZoneInfo("UTC"))
        day_end = day_start + timedelta(days=1)
        # assumes AgentRun.started_at is comparable with timezone-aware UTC
        # datetimes -- TODO confirm against the model definition
        stmt = (
            select(AgentRun)
            .where(AgentRun.started_at >= day_start)
            .where(AgentRun.started_at < day_end)
            .where(AgentRun.status == AgentRunStatus.SUCCEEDED.value)
        )
        for run in db.scalars(stmt).all():
            route_json = run.route_json or {}
            if (
                str(route_json.get("task_type") or "") == FINANCE_REPORT_TASK_TYPE
                and str(route_json.get("report_type") or "") == report_type
            ):
                return True
        return False

    def _seconds_until_next_report_time(self, now: datetime | None = None) -> float:
        """Return the seconds (at least 1.0) until the next report time.

        Args:
            now: Aware reference instant; defaults to the current time. Any
                timezone is accepted and converted to the scheduler timezone.

        The difference is taken between POSIX timestamps. Subtracting two
        aware datetimes that share the same ``tzinfo`` ignores UTC offsets,
        which would be an hour off across a daylight-saving transition.
        """
        if now is None:
            now = datetime.now(self._timezone)
        else:
            now = now.astimezone(self._timezone)
        target = datetime.combine(now.date(), self._report_time, tzinfo=self._timezone)
        if target <= now:
            target += timedelta(days=1)
        return max(1.0, target.timestamp() - now.timestamp())

    @staticmethod
    def _load_timezone(name: str) -> ZoneInfo | None:
        """Return the named timezone, or None if it cannot be loaded."""
        try:
            return ZoneInfo(name)
        except (KeyError, ValueError, OSError):
            # ZoneInfoNotFoundError is a KeyError; malformed keys raise
            # ValueError, and some keys surface as OSError on lookup.
            return None

    @staticmethod
    def _parse_int(raw_value: str | None) -> int | None:
        """Parse ``raw_value`` as an int; None when it is empty or invalid."""
        text = str(raw_value or "").strip()
        if not text:
            return None
        try:
            return int(text)
        except ValueError:
            return None

    @staticmethod
    def _parse_time(raw_value: str) -> time:
        """Parse ``HH:MM`` into a ``time``, clamping out-of-range fields.

        Anything that is not two integers separated by a colon yields the
        default of 08:30.
        """
        try:
            hour_text, minute_text = str(raw_value or "").split(":", 1)
            return time(
                hour=max(0, min(int(hour_text), 23)),
                minute=max(0, min(int(minute_text), 59)),
            )
        except ValueError:
            # Raised both by a failed unpack (no colon) and by int().
            return time(hour=8, minute=30)
|
||
|
|
|
||
|
|
|
||
|
|
# Shared module-level scheduler instance. Constructing it only reads the
# configuration; the worker thread is created later by ``start()``.
finance_report_scheduler = FinanceReportScheduler()
|