# Source: X-Financial/server/src/app/services/notification_states.py
# (Python, 102 lines, 3.5 KiB)

from __future__ import annotations
from datetime import UTC, datetime
from threading import Lock
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.api.deps import CurrentUserContext
from app.db.base import Base
from app.models.notification_state import NotificationState
from app.schemas.notification_state import (
NotificationStateBatchPatch,
NotificationStateListRead,
NotificationStateRead,
)
class NotificationStateService:
_storage_ready_bind_ids: set[int] = set()
_storage_ready_lock = Lock()
def __init__(self, db: Session) -> None:
self.db = db
def ensure_storage_ready(self) -> None:
bind = self.db.get_bind()
bind_id = id(bind)
if bind_id in self._storage_ready_bind_ids:
return
with self._storage_ready_lock:
if bind_id in self._storage_ready_bind_ids:
return
Base.metadata.create_all(bind=bind, tables=[NotificationState.__table__])
self._storage_ready_bind_ids.add(bind_id)
def list_states(self, current_user: CurrentUserContext) -> NotificationStateListRead:
self.ensure_storage_ready()
stmt = (
select(NotificationState)
.where(NotificationState.user_id == self._user_key(current_user))
.order_by(NotificationState.updated_at.desc())
)
states = list(self.db.scalars(stmt).all())
return NotificationStateListRead(
states=[NotificationStateRead.model_validate(item) for item in states]
)
def patch_states(
self,
payload: NotificationStateBatchPatch,
current_user: CurrentUserContext,
) -> NotificationStateListRead:
self.ensure_storage_ready()
user_id = self._user_key(current_user)
patches = [item for item in payload.states if item.notification_id]
if not patches:
return self.list_states(current_user)
ids = {item.notification_id for item in patches}
existing_rows = list(
self.db.scalars(
select(NotificationState).where(
NotificationState.user_id == user_id,
NotificationState.notification_id.in_(ids),
)
).all()
)
existing_by_id = {item.notification_id: item for item in existing_rows}
now = datetime.now(UTC)
for patch in patches:
row = existing_by_id.get(patch.notification_id)
if row is None:
row = NotificationState(
user_id=user_id,
notification_id=patch.notification_id,
context_json={},
)
self.db.add(row)
existing_by_id[patch.notification_id] = row
if patch.read and row.read_at is None:
row.read_at = now
if patch.hidden and row.hidden_at is None:
row.hidden_at = now
if patch.context_json:
row.context_json = self._merge_context(row.context_json, patch.context_json)
self.db.commit()
return self.list_states(current_user)
@staticmethod
def _user_key(current_user: CurrentUserContext) -> str:
return str(current_user.username or current_user.name or "anonymous").strip() or "anonymous"
@staticmethod
def _merge_context(current: dict | None, patch: dict) -> dict:
base = current if isinstance(current, dict) else {}
return {**base, **patch}