"""Tests for per-user notification read/hidden state, via the service and the HTTP API."""

from __future__ import annotations

from collections.abc import Generator

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import CurrentUserContext, get_db
from app.db.base import Base
from app.main import create_app
from app.schemas.notification_state import NotificationStateBatchPatch, NotificationStatePatch
from app.services.notification_states import NotificationStateService


def _build_session_factory() -> sessionmaker:
    """Create a fresh in-memory SQLite database with all ``Base`` tables; return a session factory.

    ``StaticPool`` keeps a single connection, so every session from the factory
    sees the same in-memory database. ``check_same_thread=False`` lets that
    connection be used from threads other than the one that created it.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)


def build_session() -> Session:
    """Return a new ORM session bound to a fresh in-memory database with all tables created."""
    return _build_session_factory()()


def build_client() -> TestClient:
    """Return a TestClient for a new app whose ``get_db`` dependency uses a fresh in-memory DB.

    Each request receives its own session from one shared factory. That session
    is closed once the request's dependency teardown runs.
    """
    session_factory = _build_session_factory()
    app = create_app()

    def override_db() -> Generator[Session, None, None]:
        db = session_factory()
        try:
            yield db
        finally:
            db.close()

    app.dependency_overrides[get_db] = override_db
    return TestClient(app)


def _user(username: str, name: str) -> CurrentUserContext:
    """Build a non-admin user context with no role codes."""
    return CurrentUserContext(username=username, name=name, role_codes=[], is_admin=False)


def _auth_headers(username: str, name: str) -> dict[str, str]:
    """Build the request headers the API reads the current user's identity from."""
    return {"x-auth-username": username, "x-auth-name": name}


def test_notification_state_service_persists_user_scoped_read_and_hidden_state() -> None:
    """Read/hidden flags are stored per user; another user's patch does not inherit them."""
    notification_id = "document:owned:EXP-001"
    with build_session() as db:
        service = NotificationStateService(db)
        alice = _user("alice", "Alice")
        bob = _user("bob", "Bob")

        saved = service.patch_states(
            NotificationStateBatchPatch(
                states=[
                    NotificationStatePatch(
                        notification_id=notification_id,
                        read=True,
                        hidden=True,
                        context_json={"kind": "document"},
                    )
                ]
            ),
            alice,
        )
        # Bob marks the same notification read but never hides it, so Alice's
        # hidden flag must not leak into his state.
        other_saved = service.patch_states(
            NotificationStateBatchPatch(
                states=[NotificationStatePatch(notification_id=notification_id, read=True)]
            ),
            bob,
        )

        assert len(saved.states) == 1
        assert saved.states[0].notification_id == notification_id
        assert saved.states[0].read_at is not None
        assert saved.states[0].hidden_at is not None
        assert saved.states[0].context_json["kind"] == "document"

        assert len(other_saved.states) == 1
        assert other_saved.states[0].read_at is not None
        assert other_saved.states[0].hidden_at is None


def test_notification_state_storage_ready_runs_once_per_database_bind(monkeypatch) -> None:
    """Repeated list/patch calls on one bind trigger ``Base.metadata.create_all`` only once."""
    with build_session() as db:
        service = NotificationStateService(db)
        user = _user("alice", "Alice")
        calls: list[object] = []
        # Capture the real method before patching so the wrapper can delegate to it.
        original_create_all = Base.metadata.create_all

        def track_create_all(*args, **kwargs):
            # Record the bind each call targets, then perform the real table creation.
            calls.append(kwargs.get("bind"))
            return original_create_all(*args, **kwargs)

        # Patch only after build_session() has created the tables, so only the
        # service's own create_all calls are counted.
        monkeypatch.setattr(Base.metadata, "create_all", track_create_all)

        service.list_states(user)
        service.list_states(user)
        service.patch_states(
            NotificationStateBatchPatch(
                states=[NotificationStatePatch(notification_id="workbench:todo:EXP-002", read=True)]
            ),
            user,
        )

        assert len(calls) == 1


def test_notification_state_endpoint_reads_and_updates_current_user_state() -> None:
    """POST stores state for the header-identified user; GET returns it only to that user."""
    client = build_client()
    headers = _auth_headers("alice", "Alice")

    post_response = client.post(
        "/api/v1/notification-states",
        json={
            "states": [
                {
                    "notification_id": "workbench:todo:EXP-002",
                    "read": True,
                    "hidden": False,
                    "context_json": {"kind": "workbench"},
                }
            ]
        },
        headers=headers,
    )
    get_response = client.get("/api/v1/notification-states", headers=headers)
    other_response = client.get(
        "/api/v1/notification-states",
        headers=_auth_headers("bob", "Bob"),
    )

    assert post_response.status_code == 200
    assert get_response.status_code == 200
    payload = get_response.json()
    assert payload["states"][0]["notification_id"] == "workbench:todo:EXP-002"
    assert payload["states"][0]["read_at"] is not None
    assert payload["states"][0]["hidden_at"] is None
    assert payload["states"][0]["context_json"]["kind"] == "workbench"

    # Check the status first so a failed request reports clearly instead of
    # failing inside the JSON lookup below.
    assert other_response.status_code == 200
    assert other_response.json()["states"] == []


def test_notification_state_endpoint_accepts_document_center_bulk_read_state() -> None:
    """A single POST carrying 150 read states is accepted and every state gets a read time."""
    client = build_client()
    headers = _auth_headers("alice", "Alice")
    states = [
        {
            "notification_id": f"document:owned:DOC-{index}",
            "read": True,
            "hidden": False,
            "context_json": {"kind": "document", "target_type": "documents-center"},
        }
        for index in range(150)
    ]

    response = client.post(
        "/api/v1/notification-states",
        json={"states": states},
        headers=headers,
    )

    assert response.status_code == 200
    payload = response.json()
    assert len(payload["states"]) == 150
    assert all(item["read_at"] for item in payload["states"])