# 2026-06-03 21:43:35 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from collections.abc import Generator
|
|
|
|
|
|
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
from sqlalchemy import create_engine
|
|
|
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
|
|
|
|
|
|
from app.api.deps import CurrentUserContext, get_db
|
|
|
|
|
from app.db.base import Base
|
|
|
|
|
from app.main import create_app
|
|
|
|
|
from app.schemas.notification_state import NotificationStateBatchPatch, NotificationStatePatch
|
|
|
|
|
from app.services.notification_states import NotificationStateService
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_session_factory() -> sessionmaker:
    """Create a fresh in-memory SQLite database with the full schema and return a session factory.

    Every call builds a brand-new engine, so each test gets an isolated database.
    ``StaticPool`` keeps a single shared connection, so every session from the
    returned factory sees the same in-memory database. ``check_same_thread=False``
    lets that connection be used from threads other than the one that created it.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # Create every table registered on Base before any session touches the engine.
    Base.metadata.create_all(bind=engine)
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)


def build_session() -> Session:
    """Return a new session bound to a fresh, isolated in-memory database."""
    return _build_session_factory()()


def build_client() -> TestClient:
    """Return a TestClient for a new app whose ``get_db`` dependency uses a fresh in-memory database.

    Every request gets its own session from the same factory, so data written by
    one request is visible to later requests made through the same client.
    """
    session_factory = _build_session_factory()
    app = create_app()

    def override_db() -> Generator[Session, None, None]:
        """Yield a per-request session and always close it afterwards."""
        db = session_factory()
        try:
            yield db
        finally:
            db.close()

    app.dependency_overrides[get_db] = override_db
    return TestClient(app)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_notification_state_service_persists_user_scoped_read_and_hidden_state() -> None:
    """Patching the same notification for two users keeps their states separate.

    Alice marks the notification read and hidden, with context. Bob only marks
    it read. Each ``patch_states`` call must return exactly that user's row, and
    Alice's hidden flag must not appear on Bob's row.
    """
    notification_id = "document:owned:EXP-001"

    with build_session() as db:
        service = NotificationStateService(db)
        user = CurrentUserContext(username="alice", name="Alice", role_codes=[], is_admin=False)
        other_user = CurrentUserContext(username="bob", name="Bob", role_codes=[], is_admin=False)

        saved = service.patch_states(
            NotificationStateBatchPatch(
                states=[
                    NotificationStatePatch(
                        notification_id=notification_id,
                        read=True,
                        hidden=True,
                        context_json={"kind": "document"},
                    )
                ]
            ),
            user,
        )
        other_saved = service.patch_states(
            NotificationStateBatchPatch(
                states=[
                    NotificationStatePatch(
                        notification_id=notification_id,
                        read=True,
                    )
                ]
            ),
            other_user,
        )

        # Assert while the session is still open, in case the returned states
        # need it for attribute access.
        assert len(saved.states) == 1
        alice_state = saved.states[0]
        assert alice_state.notification_id == notification_id
        assert alice_state.read_at is not None
        assert alice_state.hidden_at is not None
        assert alice_state.context_json["kind"] == "document"

        # Check the length before indexing, so a missing row fails as a clear
        # assertion instead of an IndexError.
        assert len(other_saved.states) == 1
        bob_state = other_saved.states[0]
        assert bob_state.notification_id == notification_id
        assert bob_state.read_at is not None
        # Alice hid the notification but Bob did not, so his row must not be hidden.
        assert bob_state.hidden_at is None
|
|
|
|
|
|
|
|
|
|
|
# 2026-06-24 10:42:24 +08:00
|
|
|
def test_notification_state_storage_ready_runs_once_per_database_bind(monkeypatch) -> None:
    """The service should reach ``Base.metadata.create_all`` only once for a given bind.

    ``build_session`` creates the schema before the spy is installed, so only
    calls made by the service itself are counted. Those calls span two
    ``list_states`` calls and one ``patch_states`` call.
    """
    with build_session() as db:
        # Build the service before installing the spy, as the original order requires.
        service = NotificationStateService(db)
        alice = CurrentUserContext(username="alice", name="Alice", role_codes=[], is_admin=False)

        seen_binds: list[object] = []
        real_create_all = Base.metadata.create_all

        def spy_create_all(*args, **kwargs):
            # Record the call, then delegate so the real schema work still runs.
            seen_binds.append(kwargs.get("bind"))
            return real_create_all(*args, **kwargs)

        monkeypatch.setattr(Base.metadata, "create_all", spy_create_all)

        for _ in range(2):
            service.list_states(alice)

        read_patch = NotificationStateBatchPatch(
            states=[NotificationStatePatch(notification_id="workbench:todo:EXP-002", read=True)]
        )
        service.patch_states(read_patch, alice)

        assert len(seen_binds) == 1
|
|
|
|
|
|
|
|
|
|
|
# 2026-06-03 21:43:35 +08:00
|
|
|
def test_notification_state_endpoint_reads_and_updates_current_user_state() -> None:
    """A state POSTed by one user can be read back by that user but not by another.

    The acting user is selected with the ``x-auth-username`` / ``x-auth-name``
    request headers. Alice and Bob share the same client, and so the same
    in-memory database.
    """
    client = build_client()
    headers = {"x-auth-username": "alice", "x-auth-name": "Alice"}

    post_response = client.post(
        "/api/v1/notification-states",
        json={
            "states": [
                {
                    "notification_id": "workbench:todo:EXP-002",
                    "read": True,
                    "hidden": False,
                    "context_json": {"kind": "workbench"},
                }
            ]
        },
        headers=headers,
    )
    get_response = client.get("/api/v1/notification-states", headers=headers)
    other_response = client.get(
        "/api/v1/notification-states",
        headers={"x-auth-username": "bob", "x-auth-name": "Bob"},
    )

    assert post_response.status_code == 200
    assert get_response.status_code == 200
    payload = get_response.json()
    # The database starts empty, so Alice should see exactly the one state she
    # posted. Check the length before indexing so a miss fails clearly.
    assert len(payload["states"]) == 1
    state = payload["states"][0]
    assert state["notification_id"] == "workbench:todo:EXP-002"
    assert state["read_at"] is not None
    assert state["hidden_at"] is None
    assert state["context_json"]["kind"] == "workbench"

    # Check the status before parsing the body, so an error response is
    # reported as such rather than as a KeyError.
    assert other_response.status_code == 200
    assert other_response.json()["states"] == []
|
# 2026-06-06 17:19:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_notification_state_endpoint_accepts_document_center_bulk_read_state() -> None:
    """A single POST carrying a large batch of document-center read states succeeds.

    Every submitted ``notification_id`` must come back exactly once and be
    marked read. Checking ids, not only the count, catches duplicated or
    dropped rows.
    """
    client = build_client()
    headers = {"x-auth-username": "alice", "x-auth-name": "Alice"}
    batch_size = 150
    states = [
        {
            "notification_id": f"document:owned:DOC-{index}",
            "read": True,
            "hidden": False,
            "context_json": {"kind": "document", "target_type": "documents-center"},
        }
        for index in range(batch_size)
    ]

    response = client.post(
        "/api/v1/notification-states",
        json={"states": states},
        headers=headers,
    )

    assert response.status_code == 200
    payload = response.json()
    assert len(payload["states"]) == batch_size
    # Compare as sets so the check does not depend on response ordering.
    returned_ids = {item["notification_id"] for item in payload["states"]}
    submitted_ids = {entry["notification_id"] for entry in states}
    assert returned_ids == submitted_ids
    assert all(item["read_at"] for item in payload["states"])
|