feat: add auth module with login and access control
This commit is contained in:
21
server/src/app/api/v1/endpoints/auth.py
Normal file
21
server/src/app/api/v1/endpoints/auth.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.schemas.auth import LoginRequest, LoginResponse
|
||||
from app.services.auth import AuthService
|
||||
|
||||
router = APIRouter(prefix="/auth")
|
||||
DbSession = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
def login(payload: LoginRequest, db: DbSession) -> LoginResponse:
|
||||
try:
|
||||
return AuthService(db).login(payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1.endpoints.auth import router as auth_router
|
||||
from app.api.v1.endpoints.bootstrap import router as bootstrap_router
|
||||
from app.api.v1.endpoints.employees import router as employees_router
|
||||
from app.api.v1.endpoints.health import router as health_router
|
||||
@@ -8,5 +9,6 @@ from app.api.v1.endpoints.reimbursements import router as reimbursements_router
|
||||
router = APIRouter()
|
||||
router.include_router(health_router, tags=["health"])
|
||||
router.include_router(bootstrap_router, tags=["bootstrap"])
|
||||
router.include_router(auth_router, tags=["auth"])
|
||||
router.include_router(employees_router, prefix="/employees", tags=["employees"])
|
||||
router.include_router(reimbursements_router, prefix="/reimbursements", tags=["reimbursements"])
|
||||
|
||||
53
server/src/app/core/admin_secret.py
Normal file
53
server/src/app/core/admin_secret.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import SERVER_DIR
|
||||
|
||||
ADMIN_SECRET_FILE = SERVER_DIR / ".secrets" / "admin.json"
|
||||
|
||||
|
||||
def read_admin_secret() -> dict[str, object] | None:
|
||||
if not ADMIN_SECRET_FILE.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = json.loads(ADMIN_SECRET_FILE.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
if (
|
||||
payload
|
||||
and payload.get("algorithm") == "scrypt"
|
||||
and isinstance(payload.get("username"), str)
|
||||
and isinstance(payload.get("salt"), str)
|
||||
and isinstance(payload.get("derived_key"), str)
|
||||
):
|
||||
return payload
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def verify_admin_secret(password: str, record: dict[str, object]) -> bool:
|
||||
try:
|
||||
salt = bytes.fromhex(str(record["salt"]))
|
||||
stored_key = bytes.fromhex(str(record["derived_key"]))
|
||||
key_length = int(record.get("key_length", 64))
|
||||
n_value = int(record.get("N", 16384))
|
||||
r_value = int(record.get("r", 8))
|
||||
p_value = int(record.get("p", 1))
|
||||
except (KeyError, TypeError, ValueError):
|
||||
return False
|
||||
|
||||
derived_key = hashlib.scrypt(
|
||||
password.encode("utf-8"),
|
||||
salt=salt,
|
||||
n=n_value,
|
||||
r=r_value,
|
||||
p=p_value,
|
||||
dklen=key_length,
|
||||
)
|
||||
return secrets.compare_digest(derived_key, stored_key)
|
||||
24
server/src/app/schemas/auth.py
Normal file
24
server/src/app/schemas/auth.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str = Field(min_length=1, max_length=255)
|
||||
password: str = Field(min_length=1, max_length=128)
|
||||
|
||||
|
||||
class AuthUserRead(BaseModel):
|
||||
username: str
|
||||
name: str
|
||||
role: str
|
||||
roleCodes: list[str] = Field(default_factory=list)
|
||||
email: EmailStr | str
|
||||
avatar: str
|
||||
isAdmin: bool = False
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
ok: bool = True
|
||||
detail: str = "登录成功。"
|
||||
user: AuthUserRead
|
||||
144
server/src/app/services/auth.py
Normal file
144
server/src/app/services/auth.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.core.admin_secret import read_admin_secret, verify_admin_secret
|
||||
from app.core.config import get_settings
|
||||
from app.core.logging import get_logger
|
||||
from app.core.security import verify_password
|
||||
from app.models.employee import Employee
|
||||
from app.schemas.auth import AuthUserRead, LoginRequest, LoginResponse
|
||||
from app.services.employee import EmployeeService
|
||||
from app.services.employee_seed import ROLE_DISPLAY_ORDER
|
||||
|
||||
logger = get_logger("app.services.auth")
|
||||
|
||||
ROLE_LABELS = {
|
||||
"manager": "管理员",
|
||||
"finance": "财务人员",
|
||||
"executive": "高级管理人员",
|
||||
"approver": "审批负责人",
|
||||
"auditor": "审计观察员",
|
||||
"user": "使用者",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AuthenticatedUser:
|
||||
username: str
|
||||
name: str
|
||||
role: str
|
||||
role_codes: list[str]
|
||||
email: str
|
||||
avatar: str
|
||||
is_admin: bool = False
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.settings = get_settings()
|
||||
|
||||
def login(self, payload: LoginRequest) -> LoginResponse:
|
||||
identifier = payload.username.strip()
|
||||
password = payload.password
|
||||
|
||||
admin_user = self._authenticate_admin(identifier, password)
|
||||
if admin_user is not None:
|
||||
logger.info("Admin login succeeded identifier=%s", identifier)
|
||||
return LoginResponse(user=self._serialize_user(admin_user))
|
||||
|
||||
employee_user = self._authenticate_employee(identifier, password)
|
||||
if employee_user is not None:
|
||||
logger.info("Employee login succeeded identifier=%s role_codes=%s", identifier, ",".join(employee_user.role_codes))
|
||||
return LoginResponse(user=self._serialize_user(employee_user))
|
||||
|
||||
logger.warning("Login failed identifier=%s", identifier)
|
||||
raise ValueError("账号或密码错误。")
|
||||
|
||||
def _authenticate_admin(self, identifier: str, password: str) -> AuthenticatedUser | None:
|
||||
record = read_admin_secret()
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
admin_username = str(record.get("username", "")).strip()
|
||||
admin_email = str(self.settings.admin_email or "").strip()
|
||||
normalized_identifier = identifier.casefold()
|
||||
|
||||
allowed_identifiers = {
|
||||
value.casefold()
|
||||
for value in [admin_username, admin_email]
|
||||
if value
|
||||
}
|
||||
|
||||
if normalized_identifier not in allowed_identifiers:
|
||||
return None
|
||||
|
||||
if not verify_admin_secret(password, record):
|
||||
return None
|
||||
|
||||
display_name = admin_username or admin_email or "系统管理员"
|
||||
return AuthenticatedUser(
|
||||
username=admin_username or admin_email,
|
||||
name=display_name,
|
||||
role="管理员",
|
||||
role_codes=["manager"],
|
||||
email=admin_email or f"{admin_username}@local",
|
||||
avatar=display_name[:1].upper(),
|
||||
is_admin=True,
|
||||
)
|
||||
|
||||
def _authenticate_employee(self, identifier: str, password: str) -> AuthenticatedUser | None:
|
||||
if not self.settings.setup_completed:
|
||||
return None
|
||||
|
||||
EmployeeService(self.db).ensure_directory_ready()
|
||||
|
||||
stmt = (
|
||||
select(Employee)
|
||||
.options(selectinload(Employee.roles))
|
||||
.where(func.lower(Employee.email) == identifier.lower())
|
||||
)
|
||||
employee = self.db.execute(stmt).scalars().first()
|
||||
|
||||
if employee is None or not employee.password_hash:
|
||||
return None
|
||||
|
||||
if employee.employment_status == "停用":
|
||||
logger.warning("Disabled employee login blocked identifier=%s", identifier)
|
||||
return None
|
||||
|
||||
if not verify_password(password, employee.password_hash):
|
||||
return None
|
||||
|
||||
sorted_roles = sorted(
|
||||
list(employee.roles),
|
||||
key=lambda item: (ROLE_DISPLAY_ORDER.get(item.role_code, 999), item.name),
|
||||
)
|
||||
role_codes = [role.role_code for role in sorted_roles]
|
||||
primary_role_code = role_codes[0] if role_codes else "user"
|
||||
|
||||
return AuthenticatedUser(
|
||||
username=employee.email,
|
||||
name=employee.name,
|
||||
role=ROLE_LABELS.get(primary_role_code, "使用者"),
|
||||
role_codes=role_codes or ["user"],
|
||||
email=employee.email,
|
||||
avatar=(employee.name or "?")[:1].upper(),
|
||||
is_admin=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_user(user: AuthenticatedUser) -> AuthUserRead:
|
||||
return AuthUserRead(
|
||||
username=user.username,
|
||||
name=user.name,
|
||||
role=user.role,
|
||||
roleCodes=user.role_codes,
|
||||
email=user.email,
|
||||
avatar=user.avatar,
|
||||
isAdmin=user.is_admin,
|
||||
)
|
||||
@@ -36,6 +36,7 @@ from app.services.employee_seed import (
|
||||
)
|
||||
|
||||
logger = get_logger("app.services.employee")
|
||||
DEFAULT_EMPLOYEE_PASSWORD = "123456"
|
||||
|
||||
STATUS_TONE_MAP = {
|
||||
"在职": "success",
|
||||
@@ -150,6 +151,7 @@ class EmployeeService:
|
||||
employment_status=payload.employment_status,
|
||||
sync_state=payload.sync_state,
|
||||
spotlight=payload.spotlight,
|
||||
password_hash=hash_password(DEFAULT_EMPLOYEE_PASSWORD),
|
||||
last_sync_at=datetime.now(),
|
||||
)
|
||||
|
||||
@@ -432,6 +434,9 @@ class EmployeeService:
|
||||
if employee.manager_id is None and manager_employee_no:
|
||||
employee.manager = employees_by_no.get(manager_employee_no)
|
||||
|
||||
if not employee.password_hash:
|
||||
employee.password_hash = hash_password(DEFAULT_EMPLOYEE_PASSWORD)
|
||||
|
||||
if not employee.roles:
|
||||
employee.roles = self._sorted_roles(
|
||||
[
|
||||
|
||||
72
server/tests/test_auth_service.py
Normal file
72
server/tests/test_auth_service.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.db.base import Base
|
||||
from app.schemas.auth import LoginRequest
|
||||
from app.services.auth import AuthService
|
||||
from app.services.employee import EmployeeService
|
||||
|
||||
|
||||
def build_session() -> Session:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
return session_factory()
|
||||
|
||||
|
||||
def test_employee_can_login_with_seed_default_password() -> None:
|
||||
with build_session() as db:
|
||||
employee = EmployeeService(db).list_employees()[0]
|
||||
result = AuthService(db).login(
|
||||
LoginRequest(username=employee.email, password="123456")
|
||||
)
|
||||
|
||||
assert result.ok is True
|
||||
assert result.user.username == employee.email
|
||||
assert result.user.name == employee.name
|
||||
assert result.user.roleCodes
|
||||
assert result.user.isAdmin is False
|
||||
|
||||
|
||||
def test_admin_can_login_with_secret(monkeypatch) -> None:
|
||||
with build_session() as db:
|
||||
monkeypatch.setattr(
|
||||
"app.services.auth.read_admin_secret",
|
||||
lambda: {
|
||||
"username": "superadmin",
|
||||
"algorithm": "scrypt",
|
||||
"salt": "00",
|
||||
"derived_key": "00",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr("app.services.auth.verify_admin_secret", lambda password, record: password == "admin123")
|
||||
|
||||
result = AuthService(db).login(
|
||||
LoginRequest(username="superadmin", password="admin123")
|
||||
)
|
||||
|
||||
assert result.ok is True
|
||||
assert result.user.username == "superadmin"
|
||||
assert result.user.isAdmin is True
|
||||
assert result.user.roleCodes == ["manager"]
|
||||
|
||||
|
||||
def test_disabled_employee_cannot_login() -> None:
|
||||
with build_session() as db:
|
||||
service = EmployeeService(db)
|
||||
employee = service.list_employees()[0]
|
||||
service.disable_employee(employee.id)
|
||||
|
||||
try:
|
||||
AuthService(db).login(LoginRequest(username=employee.email, password="123456"))
|
||||
except ValueError as exc:
|
||||
assert "账号或密码错误" in str(exc)
|
||||
else:
|
||||
raise AssertionError("disabled employee login should be rejected")
|
||||
Reference in New Issue
Block a user