feat(workbench): persist topbar notification state

This commit is contained in:
caoxiaozhu
2026-06-03 21:43:35 +08:00
parent b9826a1985
commit 75d5c178e1
15 changed files with 799 additions and 59 deletions

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.api.deps import CurrentUserContext, get_current_user, get_db
from app.schemas.notification_state import NotificationStateBatchPatch, NotificationStateListRead
from app.services.notification_states import NotificationStateService
router = APIRouter(prefix="/notification-states")
DbSession = Annotated[Session, Depends(get_db)]
CurrentUser = Annotated[CurrentUserContext, Depends(get_current_user)]
@router.get(
"",
response_model=NotificationStateListRead,
summary="读取当前用户通知状态",
description="读取当前登录用户的小铃铛通知已读和隐藏状态,用于跨设备保持一致。",
)
def list_notification_states(db: DbSession, current_user: CurrentUser) -> NotificationStateListRead:
return NotificationStateService(db).list_states(current_user)
@router.post(
"",
response_model=NotificationStateListRead,
summary="批量保存当前用户通知状态",
description="批量保存当前登录用户的小铃铛通知已读和隐藏状态。",
)
def patch_notification_states(
payload: NotificationStateBatchPatch,
db: DbSession,
current_user: CurrentUser,
) -> NotificationStateListRead:
return NotificationStateService(db).patch_states(payload, current_user)

View File

@@ -14,6 +14,7 @@ from app.api.v1.endpoints.employees import router as employees_router
from app.api.v1.endpoints.employee_profiles import router as employee_profiles_router
from app.api.v1.endpoints.health import router as health_router
from app.api.v1.endpoints.knowledge import router as knowledge_router
from app.api.v1.endpoints.notification_states import router as notification_states_router
from app.api.v1.endpoints.ocr import router as ocr_router
from app.api.v1.endpoints.ontology import router as ontology_router
from app.api.v1.endpoints.orchestrator import router as orchestrator_router
@@ -36,6 +37,7 @@ router.include_router(agent_traces_router, tags=["agent-traces"])
router.include_router(analytics_router, tags=["analytics"])
router.include_router(audit_logs_router, tags=["audit-logs"])
router.include_router(knowledge_router, tags=["knowledge"])
router.include_router(notification_states_router, tags=["notification-states"])
router.include_router(ocr_router, tags=["ocr"])
router.include_router(ontology_router, tags=["ontology"])
router.include_router(orchestrator_router, tags=["orchestrator"])

View File

@@ -23,6 +23,7 @@ from app.models.financial_record import (
)
from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
from app.models.hermes_report import HermesRiskReport
from app.models.notification_state import NotificationState
from app.models.organization import OrganizationUnit
from app.models.reimbursement import ReimbursementRequest
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
@@ -60,6 +61,7 @@ __all__ = [
"HermesTaskConfig",
"HermesTaskExecutionLog",
"HermesRiskReport",
"NotificationState",
"OrganizationUnit",
"ReimbursementRequest",
"RiskObservation",

View File

@@ -16,6 +16,7 @@ from app.models.financial_record import (
)
from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
from app.models.hermes_report import HermesRiskReport
from app.models.notification_state import NotificationState
from app.models.organization import OrganizationUnit
from app.models.reimbursement import ReimbursementRequest
from app.models.risk_observation import RiskObservation, RiskObservationFeedback
@@ -51,6 +52,7 @@ __all__ = [
"HermesTaskConfig",
"HermesTaskExecutionLog",
"HermesRiskReport",
"NotificationState",
"OrganizationUnit",
"ReimbursementRequest",
"RiskObservation",

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import DateTime, Index, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.types import JSON
from app.db.base_class import Base
class NotificationState(Base):
__tablename__ = "notification_states"
__table_args__ = (
UniqueConstraint("user_id", "notification_id", name="uq_notification_states_user_notification"),
Index("ix_notification_states_user_updated", "user_id", "updated_at"),
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
user_id: Mapped[str] = mapped_column(String(100), index=True)
notification_id: Mapped[str] = mapped_column(String(180), index=True)
read_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
hidden_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
context_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_validator
def _normalize_text(value: Any) -> str:
return str(value or "").strip()
class NotificationStatePatch(BaseModel):
notification_id: str = Field(min_length=1, max_length=180)
read: bool = False
hidden: bool = False
context_json: dict[str, Any] = Field(default_factory=dict)
@field_validator("notification_id", mode="before")
@classmethod
def normalize_notification_id(cls, value: Any) -> str:
return _normalize_text(value)
@field_validator("context_json", mode="before")
@classmethod
def normalize_context(cls, value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
class NotificationStateBatchPatch(BaseModel):
states: list[NotificationStatePatch] = Field(default_factory=list, max_length=100)
class NotificationStateRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
notification_id: str
read_at: datetime | None
hidden_at: datetime | None
context_json: dict[str, Any]
updated_at: datetime
class NotificationStateListRead(BaseModel):
states: list[NotificationStateRead] = Field(default_factory=list)

View File

@@ -0,0 +1,88 @@
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.api.deps import CurrentUserContext
from app.db.base import Base
from app.models.notification_state import NotificationState
from app.schemas.notification_state import (
NotificationStateBatchPatch,
NotificationStateListRead,
NotificationStateRead,
)
class NotificationStateService:
def __init__(self, db: Session) -> None:
self.db = db
def ensure_storage_ready(self) -> None:
Base.metadata.create_all(bind=self.db.get_bind(), tables=[NotificationState.__table__])
def list_states(self, current_user: CurrentUserContext) -> NotificationStateListRead:
self.ensure_storage_ready()
stmt = (
select(NotificationState)
.where(NotificationState.user_id == self._user_key(current_user))
.order_by(NotificationState.updated_at.desc())
)
states = list(self.db.scalars(stmt).all())
return NotificationStateListRead(
states=[NotificationStateRead.model_validate(item) for item in states]
)
def patch_states(
self,
payload: NotificationStateBatchPatch,
current_user: CurrentUserContext,
) -> NotificationStateListRead:
self.ensure_storage_ready()
user_id = self._user_key(current_user)
patches = [item for item in payload.states if item.notification_id]
if not patches:
return self.list_states(current_user)
ids = {item.notification_id for item in patches}
existing_rows = list(
self.db.scalars(
select(NotificationState).where(
NotificationState.user_id == user_id,
NotificationState.notification_id.in_(ids),
)
).all()
)
existing_by_id = {item.notification_id: item for item in existing_rows}
now = datetime.now(UTC)
for patch in patches:
row = existing_by_id.get(patch.notification_id)
if row is None:
row = NotificationState(
user_id=user_id,
notification_id=patch.notification_id,
context_json={},
)
self.db.add(row)
existing_by_id[patch.notification_id] = row
if patch.read and row.read_at is None:
row.read_at = now
if patch.hidden and row.hidden_at is None:
row.hidden_at = now
if patch.context_json:
row.context_json = self._merge_context(row.context_json, patch.context_json)
self.db.commit()
return self.list_states(current_user)
@staticmethod
def _user_key(current_user: CurrentUserContext) -> str:
return str(current_user.username or current_user.name or "anonymous").strip() or "anonymous"
@staticmethod
def _merge_context(current: dict | None, patch: dict) -> dict:
base = current if isinstance(current, dict) else {}
return {**base, **patch}

View File

@@ -0,0 +1,119 @@
from __future__ import annotations
from collections.abc import Generator
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import CurrentUserContext, get_db
from app.db.base import Base
from app.main import create_app
from app.schemas.notification_state import NotificationStateBatchPatch, NotificationStatePatch
from app.services.notification_states import NotificationStateService
def build_session() -> Session:
engine = create_engine(
"sqlite+pysqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
return session_factory()
def build_client() -> TestClient:
engine = create_engine(
"sqlite+pysqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine)
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
app = create_app()
def override_db() -> Generator[Session, None, None]:
db = session_factory()
try:
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_db
return TestClient(app)
def test_notification_state_service_persists_user_scoped_read_and_hidden_state() -> None:
with build_session() as db:
service = NotificationStateService(db)
user = CurrentUserContext(username="alice", name="Alice", role_codes=[], is_admin=False)
other_user = CurrentUserContext(username="bob", name="Bob", role_codes=[], is_admin=False)
saved = service.patch_states(
NotificationStateBatchPatch(
states=[
NotificationStatePatch(
notification_id="document:owned:EXP-001",
read=True,
hidden=True,
context_json={"kind": "document"},
)
]
),
user,
)
other_saved = service.patch_states(
NotificationStateBatchPatch(
states=[
NotificationStatePatch(
notification_id="document:owned:EXP-001",
read=True,
)
]
),
other_user,
)
assert len(saved.states) == 1
assert saved.states[0].notification_id == "document:owned:EXP-001"
assert saved.states[0].read_at is not None
assert saved.states[0].hidden_at is not None
assert saved.states[0].context_json["kind"] == "document"
assert other_saved.states[0].hidden_at is None
def test_notification_state_endpoint_reads_and_updates_current_user_state() -> None:
client = build_client()
headers = {"x-auth-username": "alice", "x-auth-name": "Alice"}
post_response = client.post(
"/api/v1/notification-states",
json={
"states": [
{
"notification_id": "workbench:todo:EXP-002",
"read": True,
"hidden": False,
"context_json": {"kind": "workbench"},
}
]
},
headers=headers,
)
get_response = client.get("/api/v1/notification-states", headers=headers)
other_response = client.get(
"/api/v1/notification-states",
headers={"x-auth-username": "bob", "x-auth-name": "Bob"},
)
assert post_response.status_code == 200
assert get_response.status_code == 200
payload = get_response.json()
assert payload["states"][0]["notification_id"] == "workbench:todo:EXP-002"
assert payload["states"][0]["read_at"] is not None
assert payload["states"][0]["hidden_at"] is None
assert payload["states"][0]["context_json"]["kind"] == "workbench"
assert other_response.json()["states"] == []