feat(workbench): persist topbar notification state
This commit is contained in:
38
server/src/app/api/v1/endpoints/notification_states.py
Normal file
38
server/src/app/api/v1/endpoints/notification_states.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import CurrentUserContext, get_current_user, get_db
|
||||
from app.schemas.notification_state import NotificationStateBatchPatch, NotificationStateListRead
|
||||
from app.services.notification_states import NotificationStateService
|
||||
|
||||
router = APIRouter(prefix="/notification-states")
|
||||
DbSession = Annotated[Session, Depends(get_db)]
|
||||
CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)]
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=NotificationStateListRead,
|
||||
summary="读取当前用户通知状态",
|
||||
description="读取当前登录用户的小铃铛通知已读和隐藏状态,用于跨设备保持一致。",
|
||||
)
|
||||
def list_notification_states(db: DbSession, current_user: CurrentUser) -> NotificationStateListRead:
|
||||
return NotificationStateService(db).list_states(current_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=NotificationStateListRead,
|
||||
summary="批量保存当前用户通知状态",
|
||||
description="批量保存当前登录用户的小铃铛通知已读和隐藏状态。",
|
||||
)
|
||||
def patch_notification_states(
|
||||
payload: NotificationStateBatchPatch,
|
||||
db: DbSession,
|
||||
current_user: CurrentUser,
|
||||
) -> NotificationStateListRead:
|
||||
return NotificationStateService(db).patch_states(payload, current_user)
|
||||
@@ -14,6 +14,7 @@ from app.api.v1.endpoints.employees import router as employees_router
|
||||
from app.api.v1.endpoints.employee_profiles import router as employee_profiles_router
|
||||
from app.api.v1.endpoints.health import router as health_router
|
||||
from app.api.v1.endpoints.knowledge import router as knowledge_router
|
||||
from app.api.v1.endpoints.notification_states import router as notification_states_router
|
||||
from app.api.v1.endpoints.ocr import router as ocr_router
|
||||
from app.api.v1.endpoints.ontology import router as ontology_router
|
||||
from app.api.v1.endpoints.orchestrator import router as orchestrator_router
|
||||
@@ -36,6 +37,7 @@ router.include_router(agent_traces_router, tags=["agent-traces"])
|
||||
router.include_router(analytics_router, tags=["analytics"])
|
||||
router.include_router(audit_logs_router, tags=["audit-logs"])
|
||||
router.include_router(knowledge_router, tags=["knowledge"])
|
||||
router.include_router(notification_states_router, tags=["notification-states"])
|
||||
router.include_router(ocr_router, tags=["ocr"])
|
||||
router.include_router(ontology_router, tags=["ontology"])
|
||||
router.include_router(orchestrator_router, tags=["orchestrator"])
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.models.financial_record import (
|
||||
)
|
||||
from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
|
||||
from app.models.hermes_report import HermesRiskReport
|
||||
from app.models.notification_state import NotificationState
|
||||
from app.models.organization import OrganizationUnit
|
||||
from app.models.reimbursement import ReimbursementRequest
|
||||
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
|
||||
@@ -60,6 +61,7 @@ __all__ = [
|
||||
"HermesTaskConfig",
|
||||
"HermesTaskExecutionLog",
|
||||
"HermesRiskReport",
|
||||
"NotificationState",
|
||||
"OrganizationUnit",
|
||||
"ReimbursementRequest",
|
||||
"RiskObservation",
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.models.financial_record import (
|
||||
)
|
||||
from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
|
||||
from app.models.hermes_report import HermesRiskReport
|
||||
from app.models.notification_state import NotificationState
|
||||
from app.models.organization import OrganizationUnit
|
||||
from app.models.reimbursement import ReimbursementRequest
|
||||
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
|
||||
@@ -51,6 +52,7 @@ __all__ = [
|
||||
"HermesTaskConfig",
|
||||
"HermesTaskExecutionLog",
|
||||
"HermesRiskReport",
|
||||
"NotificationState",
|
||||
"OrganizationUnit",
|
||||
"ReimbursementRequest",
|
||||
"RiskObservation",
|
||||
|
||||
32
server/src/app/models/notification_state.py
Normal file
32
server/src/app/models/notification_state.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import DateTime, Index, String, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.types import JSON
|
||||
|
||||
from app.db.base_class import Base
|
||||
|
||||
|
||||
class NotificationState(Base):
|
||||
__tablename__ = "notification_states"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "notification_id", name="uq_notification_states_user_notification"),
|
||||
Index("ix_notification_states_user_updated", "user_id", "updated_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id: Mapped[str] = mapped_column(String(100), index=True)
|
||||
notification_id: Mapped[str] = mapped_column(String(180), index=True)
|
||||
read_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
hidden_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
context_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
45
server/src/app/schemas/notification_state.py
Normal file
45
server/src/app/schemas/notification_state.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
def _normalize_text(value: Any) -> str:
|
||||
return str(value or "").strip()
|
||||
|
||||
|
||||
class NotificationStatePatch(BaseModel):
|
||||
notification_id: str = Field(min_length=1, max_length=180)
|
||||
read: bool = False
|
||||
hidden: bool = False
|
||||
context_json: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("notification_id", mode="before")
|
||||
@classmethod
|
||||
def normalize_notification_id(cls, value: Any) -> str:
|
||||
return _normalize_text(value)
|
||||
|
||||
@field_validator("context_json", mode="before")
|
||||
@classmethod
|
||||
def normalize_context(cls, value: Any) -> dict[str, Any]:
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
class NotificationStateBatchPatch(BaseModel):
|
||||
states: list[NotificationStatePatch] = Field(default_factory=list, max_length=100)
|
||||
|
||||
|
||||
class NotificationStateRead(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
notification_id: str
|
||||
read_at: datetime | None
|
||||
hidden_at: datetime | None
|
||||
context_json: dict[str, Any]
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class NotificationStateListRead(BaseModel):
|
||||
states: list[NotificationStateRead] = Field(default_factory=list)
|
||||
88
server/src/app/services/notification_states.py
Normal file
88
server/src/app/services/notification_states.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import CurrentUserContext
|
||||
from app.db.base import Base
|
||||
from app.models.notification_state import NotificationState
|
||||
from app.schemas.notification_state import (
|
||||
NotificationStateBatchPatch,
|
||||
NotificationStateListRead,
|
||||
NotificationStateRead,
|
||||
)
|
||||
|
||||
|
||||
class NotificationStateService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def ensure_storage_ready(self) -> None:
|
||||
Base.metadata.create_all(bind=self.db.get_bind(), tables=[NotificationState.__table__])
|
||||
|
||||
def list_states(self, current_user: CurrentUserContext) -> NotificationStateListRead:
|
||||
self.ensure_storage_ready()
|
||||
stmt = (
|
||||
select(NotificationState)
|
||||
.where(NotificationState.user_id == self._user_key(current_user))
|
||||
.order_by(NotificationState.updated_at.desc())
|
||||
)
|
||||
states = list(self.db.scalars(stmt).all())
|
||||
return NotificationStateListRead(
|
||||
states=[NotificationStateRead.model_validate(item) for item in states]
|
||||
)
|
||||
|
||||
def patch_states(
|
||||
self,
|
||||
payload: NotificationStateBatchPatch,
|
||||
current_user: CurrentUserContext,
|
||||
) -> NotificationStateListRead:
|
||||
self.ensure_storage_ready()
|
||||
user_id = self._user_key(current_user)
|
||||
patches = [item for item in payload.states if item.notification_id]
|
||||
if not patches:
|
||||
return self.list_states(current_user)
|
||||
|
||||
ids = {item.notification_id for item in patches}
|
||||
existing_rows = list(
|
||||
self.db.scalars(
|
||||
select(NotificationState).where(
|
||||
NotificationState.user_id == user_id,
|
||||
NotificationState.notification_id.in_(ids),
|
||||
)
|
||||
).all()
|
||||
)
|
||||
existing_by_id = {item.notification_id: item for item in existing_rows}
|
||||
now = datetime.now(UTC)
|
||||
|
||||
for patch in patches:
|
||||
row = existing_by_id.get(patch.notification_id)
|
||||
if row is None:
|
||||
row = NotificationState(
|
||||
user_id=user_id,
|
||||
notification_id=patch.notification_id,
|
||||
context_json={},
|
||||
)
|
||||
self.db.add(row)
|
||||
existing_by_id[patch.notification_id] = row
|
||||
|
||||
if patch.read and row.read_at is None:
|
||||
row.read_at = now
|
||||
if patch.hidden and row.hidden_at is None:
|
||||
row.hidden_at = now
|
||||
if patch.context_json:
|
||||
row.context_json = self._merge_context(row.context_json, patch.context_json)
|
||||
|
||||
self.db.commit()
|
||||
return self.list_states(current_user)
|
||||
|
||||
@staticmethod
|
||||
def _user_key(current_user: CurrentUserContext) -> str:
|
||||
return str(current_user.username or current_user.name or "anonymous").strip() or "anonymous"
|
||||
|
||||
@staticmethod
|
||||
def _merge_context(current: dict | None, patch: dict) -> dict:
|
||||
base = current if isinstance(current, dict) else {}
|
||||
return {**base, **patch}
|
||||
119
server/tests/test_notification_states.py
Normal file
119
server/tests/test_notification_states.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.api.deps import CurrentUserContext, get_db
|
||||
from app.db.base import Base
|
||||
from app.main import create_app
|
||||
from app.schemas.notification_state import NotificationStateBatchPatch, NotificationStatePatch
|
||||
from app.services.notification_states import NotificationStateService
|
||||
|
||||
|
||||
def build_session() -> Session:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
return session_factory()
|
||||
|
||||
|
||||
def build_client() -> TestClient:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
app = create_app()
|
||||
|
||||
def override_db() -> Generator[Session, None, None]:
|
||||
db = session_factory()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_db
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_notification_state_service_persists_user_scoped_read_and_hidden_state() -> None:
|
||||
with build_session() as db:
|
||||
service = NotificationStateService(db)
|
||||
user = CurrentUserContext(username="alice", name="Alice", role_codes=[], is_admin=False)
|
||||
other_user = CurrentUserContext(username="bob", name="Bob", role_codes=[], is_admin=False)
|
||||
|
||||
saved = service.patch_states(
|
||||
NotificationStateBatchPatch(
|
||||
states=[
|
||||
NotificationStatePatch(
|
||||
notification_id="document:owned:EXP-001",
|
||||
read=True,
|
||||
hidden=True,
|
||||
context_json={"kind": "document"},
|
||||
)
|
||||
]
|
||||
),
|
||||
user,
|
||||
)
|
||||
other_saved = service.patch_states(
|
||||
NotificationStateBatchPatch(
|
||||
states=[
|
||||
NotificationStatePatch(
|
||||
notification_id="document:owned:EXP-001",
|
||||
read=True,
|
||||
)
|
||||
]
|
||||
),
|
||||
other_user,
|
||||
)
|
||||
|
||||
assert len(saved.states) == 1
|
||||
assert saved.states[0].notification_id == "document:owned:EXP-001"
|
||||
assert saved.states[0].read_at is not None
|
||||
assert saved.states[0].hidden_at is not None
|
||||
assert saved.states[0].context_json["kind"] == "document"
|
||||
assert other_saved.states[0].hidden_at is None
|
||||
|
||||
|
||||
def test_notification_state_endpoint_reads_and_updates_current_user_state() -> None:
|
||||
client = build_client()
|
||||
headers = {"x-auth-username": "alice", "x-auth-name": "Alice"}
|
||||
|
||||
post_response = client.post(
|
||||
"/api/v1/notification-states",
|
||||
json={
|
||||
"states": [
|
||||
{
|
||||
"notification_id": "workbench:todo:EXP-002",
|
||||
"read": True,
|
||||
"hidden": False,
|
||||
"context_json": {"kind": "workbench"},
|
||||
}
|
||||
]
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
get_response = client.get("/api/v1/notification-states", headers=headers)
|
||||
other_response = client.get(
|
||||
"/api/v1/notification-states",
|
||||
headers={"x-auth-username": "bob", "x-auth-name": "Bob"},
|
||||
)
|
||||
|
||||
assert post_response.status_code == 200
|
||||
assert get_response.status_code == 200
|
||||
payload = get_response.json()
|
||||
assert payload["states"][0]["notification_id"] == "workbench:todo:EXP-002"
|
||||
assert payload["states"][0]["read_at"] is not None
|
||||
assert payload["states"][0]["hidden_at"] is None
|
||||
assert payload["states"][0]["context_json"]["kind"] == "workbench"
|
||||
assert other_response.json()["states"] == []
|
||||
Reference in New Issue
Block a user