from __future__ import annotations from datetime import UTC, datetime from threading import Lock from sqlalchemy import select from sqlalchemy.orm import Session from app.api.deps import CurrentUserContext from app.db.base import Base from app.models.notification_state import NotificationState from app.schemas.notification_state import ( NotificationStateBatchPatch, NotificationStateListRead, NotificationStateRead, ) class NotificationStateService: _storage_ready_bind_ids: set[int] = set() _storage_ready_lock = Lock() def __init__(self, db: Session) -> None: self.db = db def ensure_storage_ready(self) -> None: bind = self.db.get_bind() bind_id = id(bind) if bind_id in self._storage_ready_bind_ids: return with self._storage_ready_lock: if bind_id in self._storage_ready_bind_ids: return Base.metadata.create_all(bind=bind, tables=[NotificationState.__table__]) self._storage_ready_bind_ids.add(bind_id) def list_states(self, current_user: CurrentUserContext) -> NotificationStateListRead: self.ensure_storage_ready() stmt = ( select(NotificationState) .where(NotificationState.user_id == self._user_key(current_user)) .order_by(NotificationState.updated_at.desc()) ) states = list(self.db.scalars(stmt).all()) return NotificationStateListRead( states=[NotificationStateRead.model_validate(item) for item in states] ) def patch_states( self, payload: NotificationStateBatchPatch, current_user: CurrentUserContext, ) -> NotificationStateListRead: self.ensure_storage_ready() user_id = self._user_key(current_user) patches = [item for item in payload.states if item.notification_id] if not patches: return self.list_states(current_user) ids = {item.notification_id for item in patches} existing_rows = list( self.db.scalars( select(NotificationState).where( NotificationState.user_id == user_id, NotificationState.notification_id.in_(ids), ) ).all() ) existing_by_id = {item.notification_id: item for item in existing_rows} now = datetime.now(UTC) for patch in patches: row = existing_by_id.get(patch.notification_id) if row is None: row = NotificationState( user_id=user_id, notification_id=patch.notification_id, context_json={}, ) self.db.add(row) existing_by_id[patch.notification_id] = row if patch.read and row.read_at is None: row.read_at = now if patch.hidden and row.hidden_at is None: row.hidden_at = now if patch.context_json: row.context_json = self._merge_context(row.context_json, patch.context_json) self.db.commit() return self.list_states(current_user) @staticmethod def _user_key(current_user: CurrentUserContext) -> str: return str(current_user.username or current_user.name or "anonymous").strip() or "anonymous" @staticmethod def _merge_context(current: dict | None, patch: dict) -> dict: base = current if isinstance(current, dict) else {} return {**base, **patch}