fix(auth): 登录目录就绪幂等化与并发控制

- employee/settings/user_session_metrics 的 ensure_*_ready 改为按 bind 缓存 + 锁,
  避免每次登录重复建表与并发场景下的竞态
- auth 登录链路先查员工再降级触发目录就绪,并吞掉查询期 SQLAlchemy 异常
- 默认管理员账号由 superadmin 迁移为 admin,兼容历史账号回填
- 补充登录降级与设置持久化相关测试
This commit is contained in:
caoxiaozhu
2026-06-18 22:11:53 +08:00
parent 59ba76c74a
commit 3f17619e0c
7 changed files with 155 additions and 19 deletions

View File

@@ -20,7 +20,7 @@ class SystemSetting(Base):
copyright_text: Mapped[str] = mapped_column(String(255), default="") copyright_text: Mapped[str] = mapped_column(String(255), default="")
theme_skin: Mapped[str] = mapped_column(String(64), default="sky") theme_skin: Mapped[str] = mapped_column(String(64), default="sky")
admin_account: Mapped[str] = mapped_column(String(120), default="superadmin") admin_account: Mapped[str] = mapped_column(String(120), default="admin")
admin_email: Mapped[str] = mapped_column(String(255), default="") admin_email: Mapped[str] = mapped_column(String(255), default="")
session_timeout: Mapped[int] = mapped_column(Integer, default=30) session_timeout: Mapped[int] = mapped_column(Integer, default=30)
conversation_retention_days: Mapped[int] = mapped_column(Integer, default=3) conversation_retention_days: Mapped[int] = mapped_column(Integer, default=3)

View File

@@ -5,6 +5,7 @@ from datetime import UTC, datetime, timedelta
from typing import Any from typing import Any
from sqlalchemy import func, or_, select from sqlalchemy import func, or_, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
from app.core.config import get_settings from app.core.config import get_settings
@@ -127,8 +128,15 @@ class AuthService:
if not self.settings.setup_completed: if not self.settings.setup_completed:
return None return None
EmployeeService(self.db).ensure_directory_ready() try:
employee = self._find_employee_by_email(identifier) employee = self._find_employee_by_email(identifier)
except SQLAlchemyError:
self.db.rollback()
employee = None
if employee is None:
EmployeeService(self.db).ensure_directory_ready()
employee = self._find_employee_by_email(identifier)
if employee is None or not employee.password_hash: if employee is None or not employee.password_hash:
return None return None

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from collections import Counter from collections import Counter
from datetime import UTC, date, datetime from datetime import UTC, date, datetime
import threading
from typing import Any from typing import Any
from sqlalchemy import select from sqlalchemy import select
@@ -81,11 +82,31 @@ def prepare_employee_directory() -> None:
class EmployeeService: class EmployeeService:
_directory_ready_lock = threading.Lock()
_directory_ready_keys: set[tuple[str, int]] = set()
def __init__(self, db: Session) -> None: def __init__(self, db: Session) -> None:
self.db = db self.db = db
self.repository = EmployeeRepository(db) self.repository = EmployeeRepository(db)
@staticmethod
def _bind_cache_key(db: Session) -> tuple[str, int]:
bind = db.get_bind()
return (bind.url.render_as_string(hide_password=True), id(bind.pool))
def ensure_directory_ready(self) -> None: def ensure_directory_ready(self) -> None:
cache_key = self._bind_cache_key(self.db)
if cache_key in self._directory_ready_keys:
return
with self._directory_ready_lock:
if cache_key in self._directory_ready_keys:
return
self._ensure_directory_ready_uncached()
self._directory_ready_keys.add(cache_key)
def _ensure_directory_ready_uncached(self) -> None:
try: try:
Base.metadata.create_all(bind=self.db.get_bind()) Base.metadata.create_all(bind=self.db.get_bind())
ensure_employee_schema(self.db) ensure_employee_schema(self.db)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import threading
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
@@ -28,9 +29,13 @@ from app.services.hermes_sync import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_ADMIN_ACCOUNT = "admin"
@dataclass(frozen=True, slots=True) DEFAULT_ADMIN_PASSWORD = "admin"
class ModelSlotConfig: LEGACY_DEFAULT_ADMIN_ACCOUNTS = {"", "superadmin"}
@dataclass(frozen=True, slots=True)
class ModelSlotConfig:
provider_attr: str provider_attr: str
model_attr: str model_attr: str
endpoint_attr: str endpoint_attr: str
@@ -106,14 +111,27 @@ class OnlyOfficeRuntimeConfig:
class SettingsService: class SettingsService:
def __init__(self, db: Session) -> None: _schema_ready_lock = threading.Lock()
self.db = db _schema_ready_keys: set[tuple[str, int]] = set()
self.repository = SettingsRepository(db)
self.runtime_settings = get_settings() def __init__(self, db: Session) -> None:
self.db = db
self.repository = SettingsRepository(db)
self.runtime_settings = get_settings()
@staticmethod
def _bind_cache_key(db: Session) -> tuple[str, int]:
bind = db.get_bind()
return (bind.url.render_as_string(hide_password=True), id(bind.pool))
def ensure_settings_ready(self) -> tuple[SystemSetting, SystemSettingSecret]: def ensure_settings_ready(self) -> tuple[SystemSetting, SystemSettingSecret]:
Base.metadata.create_all(bind=self.db.get_bind()) cache_key = self._bind_cache_key(self.db)
self._ensure_settings_schema() if cache_key not in self._schema_ready_keys:
with self._schema_ready_lock:
if cache_key not in self._schema_ready_keys:
Base.metadata.create_all(bind=self.db.get_bind())
self._ensure_settings_schema()
self._schema_ready_keys.add(cache_key)
settings_row = self.repository.get_settings() settings_row = self.repository.get_settings()
secrets_row = self.repository.get_secrets() secrets_row = self.repository.get_secrets()
@@ -130,12 +148,17 @@ class SettingsService:
self.db.add(secrets_row) self.db.add(secrets_row)
should_commit = True should_commit = True
if legacy_admin is not None and not secrets_row.admin_password_hash: if legacy_admin is not None and not secrets_row.admin_password_hash:
secrets_row.admin_password_hash = legacy_admin_secret_to_password_hash(legacy_admin) secrets_row.admin_password_hash = legacy_admin_secret_to_password_hash(legacy_admin)
admin_username = str(legacy_admin.get("username", "")).strip() admin_username = str(legacy_admin.get("username", "")).strip()
if admin_username and str(settings_row.admin_account or "").strip() in {"", "superadmin"}: if admin_username and str(settings_row.admin_account or "").strip() in LEGACY_DEFAULT_ADMIN_ACCOUNTS:
settings_row.admin_account = admin_username settings_row.admin_account = admin_username
should_commit = True should_commit = True
elif legacy_admin is None and not secrets_row.admin_password_hash:
secrets_row.admin_password_hash = hash_password(DEFAULT_ADMIN_PASSWORD)
if str(settings_row.admin_account or "").strip() in LEGACY_DEFAULT_ADMIN_ACCOUNTS:
settings_row.admin_account = DEFAULT_ADMIN_ACCOUNT
should_commit = True
if self._sync_onlyoffice_defaults(settings_row, secrets_row): if self._sync_onlyoffice_defaults(settings_row, secrets_row):
should_commit = True should_commit = True
@@ -454,7 +477,7 @@ class SettingsService:
company_code = str(self.runtime_settings.company_code or "XF-001").strip() or "XF-001" company_code = str(self.runtime_settings.company_code or "XF-001").strip() or "XF-001"
admin_email = str(self.runtime_settings.admin_email or "").strip() admin_email = str(self.runtime_settings.admin_email or "").strip()
legacy_admin = read_admin_secret() or {} legacy_admin = read_admin_secret() or {}
admin_account = str(legacy_admin.get("username", "")).strip() or "superadmin" admin_account = str(legacy_admin.get("username", "")).strip() or DEFAULT_ADMIN_ACCOUNT
return SystemSetting( return SystemSetting(
id=SETTINGS_ROW_ID, id=SETTINGS_ROW_ID,

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
import threading
from typing import Any from typing import Any
from sqlalchemy import or_, select from sqlalchemy import or_, select
@@ -14,10 +15,30 @@ MAX_SESSION_DURATION_MS = 24 * 60 * 60 * 1000
class UserSessionMetricService: class UserSessionMetricService:
_storage_ready_lock = threading.Lock()
_storage_ready_keys: set[tuple[str, int]] = set()
def __init__(self, db: Session) -> None: def __init__(self, db: Session) -> None:
self.db = db self.db = db
@staticmethod
def _bind_cache_key(db: Session) -> tuple[str, int]:
bind = db.get_bind()
return (bind.url.render_as_string(hide_password=True), id(bind.pool))
def ensure_storage_ready(self) -> None: def ensure_storage_ready(self) -> None:
cache_key = self._bind_cache_key(self.db)
if cache_key in self._storage_ready_keys:
return
with self._storage_ready_lock:
if cache_key in self._storage_ready_keys:
return
self._ensure_storage_ready_uncached()
self._storage_ready_keys.add(cache_key)
def _ensure_storage_ready_uncached(self) -> None:
Base.metadata.create_all(bind=self.db.get_bind(), tables=[UserSessionMetric.__table__]) Base.metadata.create_all(bind=self.db.get_bind(), tables=[UserSessionMetric.__table__])
def start_session( def start_session(

View File

@@ -7,7 +7,7 @@ from sqlalchemy.pool import StaticPool
from app.db.base import Base from app.db.base import Base
from app.schemas.auth import LoginRequest from app.schemas.auth import LoginRequest
from app.schemas.settings import SettingsWrite from app.schemas.settings import SettingsWrite
from app.services.auth import AuthService from app.services.auth import AuthService, AuthenticatedUser
from app.services.employee import EmployeeService from app.services.employee import EmployeeService
from app.services.settings import SettingsService from app.services.settings import SettingsService
@@ -97,3 +97,49 @@ def test_reenabled_employee_can_login_again() -> None:
assert result.ok is True assert result.ok is True
assert result.user.username == employee.email assert result.user.username == employee.email
def test_employee_login_skips_directory_bootstrap_when_employee_exists(monkeypatch) -> None:
with build_session() as db:
service = AuthService(db)
calls: list[str] = []
class ExistingEmployee:
email = "demo@example.com"
password_hash = "hash"
employment_status = "在职"
def fail_if_bootstrapped(self) -> None:
calls.append("ensure_directory_ready")
raise AssertionError("existing employee login should not run directory bootstrap")
monkeypatch.setattr(AuthService, "_find_employee_by_email", lambda self, _: ExistingEmployee())
monkeypatch.setattr("app.services.auth.verify_password", lambda password, password_hash: True)
monkeypatch.setattr(
AuthService,
"_build_employee_user",
lambda self, employee: AuthenticatedUser(
username=employee.email,
name="Demo",
role="使用者",
department="",
position="",
grade="",
employee_no="",
manager_name="",
location="",
cost_center="",
finance_owner_name="",
risk_profile={},
role_codes=["user"],
email=employee.email,
avatar="D",
),
)
monkeypatch.setattr(EmployeeService, "ensure_directory_ready", fail_if_bootstrapped)
user = service._authenticate_employee("demo@example.com", "123456")
assert user is not None
assert user.username == "demo@example.com"
assert calls == []

View File

@@ -186,6 +186,23 @@ def test_legacy_setup_admin_password_is_migrated_to_database(monkeypatch) -> Non
assert service.verify_admin_login("setup-admin", password) is not None assert service.verify_admin_login("setup-admin", password) is not None
def test_default_admin_credentials_are_written_to_database(monkeypatch) -> None:
temp_dir = build_temp_secret_dir()
monkeypatch.setattr(admin_secret, "ADMIN_SECRET_FILE", temp_dir / "missing-admin.json")
monkeypatch.setattr(secret_box, "SECRET_KEY_FILE", temp_dir / "settings.key")
monkeypatch.setattr(Base.metadata, "create_all", lambda *args, **kwargs: None)
monkeypatch.setenv("HERMES_HOME", str(temp_dir / ".hermes"))
with build_session(temp_dir / "settings.db") as db:
service = SettingsService(db)
settings_row, secrets_row = service.ensure_settings_ready()
assert settings_row.admin_account == "admin"
assert secrets_row.admin_password_hash
assert service.verify_admin_login("admin", "admin") is not None
assert service.verify_admin_login("superadmin", "admin") is None
def test_settings_service_syncs_models_to_hermes_config(monkeypatch) -> None: def test_settings_service_syncs_models_to_hermes_config(monkeypatch) -> None:
temp_dir = build_temp_secret_dir() temp_dir = build_temp_secret_dir()
hermes_home = temp_dir / ".hermes" hermes_home = temp_dir / ".hermes"