"""Request-scoped FastAPI dependencies: DB sessions and the caller's user context.

NOTE(review): identity and privileges are read verbatim from ``X-Auth-*``
request headers (including ``X-Auth-Is-Admin``). This is only safe if a
trusted gateway always sets/overwrites those headers; a client able to reach
this service directly could claim admin rights. Confirm the deployment
guarantees that.
"""

from collections.abc import Generator
from dataclasses import dataclass
from typing import Annotated

from fastapi import Depends, Header, HTTPException, status
from sqlalchemy.orm import Session

from app.db.session import get_session_factory

# Lower-cased usernames, display names or role codes that are always treated
# as platform administrators.
PLATFORM_ADMIN_IDENTITIES = {"admin", "superadmin"}
# Lower-cased, stripped values of the X-Auth-Is-Admin header meaning "true".
ADMIN_HEADER_TRUE_VALUES = {"1", "true", "yes", "on"}


def get_db() -> Generator[Session, None, None]:
    """Yield a new database session and close it once the request is done."""
    db = get_session_factory()()
    try:
        yield db
    finally:
        db.close()


@dataclass(slots=True)
class CurrentUserContext:
    """Identity and authorization attributes of the caller, built from headers."""

    username: str
    name: str
    # Normalized role codes (see _normalize_role_code), in header order.
    role_codes: list[str]
    is_admin: bool
    department_name: str = ""
    cost_center: str = ""
    position: str = ""
    grade: str = ""
    employee_no: str = ""
    manager_name: str = ""


def get_current_user(
    x_auth_username: Annotated[
        str | None,
        Header(description="当前登录用户名。知识库接口至少需要提供用户名或姓名。"),
    ] = None,
    x_auth_name: Annotated[
        str | None,
        Header(description="当前登录人展示姓名。未传时默认回退到用户名。"),
    ] = None,
    x_auth_role_codes: Annotated[
        str | None,
        Header(description="角色编码列表,多个角色使用英文逗号分隔,例如 `manager,finance`。"),
    ] = None,
    x_auth_is_admin: Annotated[
        str | None,
        Header(description="是否管理员,支持 `true/false/1/0`。"),
    ] = None,
    x_auth_department: Annotated[
        str | None,
        Header(description="当前登录人的所属部门。"),
    ] = None,
    x_auth_cost_center: Annotated[
        str | None,
        Header(description="当前登录人的成本中心。"),
    ] = None,
    x_auth_position: Annotated[
        str | None,
        Header(description="当前登录人的岗位。"),
    ] = None,
    x_auth_grade: Annotated[
        str | None,
        Header(description="当前登录人的职级。"),
    ] = None,
    x_auth_employee_no: Annotated[
        str | None,
        Header(description="当前登录人的员工编号。"),
    ] = None,
    x_auth_manager_name: Annotated[
        str | None,
        Header(description="当前登录人的直属领导。"),
    ] = None,
) -> CurrentUserContext:
    """Build the caller's :class:`CurrentUserContext` from ``X-Auth-*`` headers.

    Username and display name fall back to each other; at least one of them
    must be non-blank. Role codes are split on commas, normalized, and blank
    entries are dropped. All other attributes are stripped ("" when absent).

    Raises:
        HTTPException: 401 when neither a username nor a display name is given.
    """
    # Normalize each entry exactly once; empty results are discarded.
    role_codes = [
        code
        for item in (x_auth_role_codes or "").split(",")
        if (code := _normalize_role_code(item))
    ]
    username = (x_auth_username or "").strip()
    # A missing, empty or whitespace-only display name falls back to the username.
    name = (x_auth_name or "").strip() or username

    if not username and not name:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="请先登录后再访问知识库。",
        )

    is_admin = _resolve_platform_admin_flag(
        username=username,
        name=name,
        role_codes=role_codes,
        header_value=x_auth_is_admin,
    )
    return CurrentUserContext(
        # Only a display name was supplied: reuse it as the username.
        username=username or name,
        name=name,
        role_codes=role_codes,
        is_admin=is_admin,
        department_name=(x_auth_department or "").strip(),
        cost_center=(x_auth_cost_center or "").strip(),
        position=(x_auth_position or "").strip(),
        grade=(x_auth_grade or "").strip(),
        employee_no=(x_auth_employee_no or "").strip(),
        manager_name=(x_auth_manager_name or "").strip(),
    )


def _normalize_role_code(value: str | None) -> str:
    """Return *value* stripped and lower-cased, mapping the legacy ``auditor`` code.

    ``auditor`` is treated as an alias of ``budget_monitor``; ``None`` yields "".
    """
    role_code = str(value or "").strip().lower()
    if role_code == "auditor":
        return "budget_monitor"
    return role_code


def _current_user_role_codes(current_user: CurrentUserContext) -> set[str]:
    """Return the user's role codes as a set of normalized, non-empty codes."""
    return {
        code
        for item in current_user.role_codes
        if (code := _normalize_role_code(item))
    }


def _resolve_platform_admin_flag(
    *,
    username: str,
    name: str,
    role_codes: list[str],
    header_value: str | None,
) -> bool:
    """Decide whether the caller is a platform admin.

    True when the admin header holds a truthy value (see
    ``ADMIN_HEADER_TRUE_VALUES``), or when the username, display name or any
    role code matches ``PLATFORM_ADMIN_IDENTITIES`` (case-insensitively).
    """
    # NOTE(review): every input here comes from client-controlled headers;
    # see the module docstring about trusting the gateway.
    if str(header_value or "").strip().lower() in ADMIN_HEADER_TRUE_VALUES:
        return True
    identities = {
        str(username or "").strip().lower(),
        str(name or "").strip().lower(),
    }
    normalized_role_codes = {_normalize_role_code(item) for item in role_codes}
    return bool(identities & PLATFORM_ADMIN_IDENTITIES) or bool(
        normalized_role_codes & PLATFORM_ADMIN_IDENTITIES
    )


def require_admin_user(
    current_user: Annotated[CurrentUserContext, Depends(get_current_user)],
) -> CurrentUserContext:
    """Allow platform admins and users with the ``manager`` role; otherwise 403."""
    if current_user.is_admin or "manager" in _current_user_role_codes(current_user):
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="只有管理员可以上传、删除或修改知识库文件。",
    )


def require_platform_admin_user(
    current_user: Annotated[CurrentUserContext, Depends(get_current_user)],
) -> CurrentUserContext:
    """Allow platform admins only; otherwise 403."""
    if current_user.is_admin:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="只有 admin 管理员可以执行该操作。",
    )


def require_rule_editor_user(
    current_user: Annotated[CurrentUserContext, Depends(get_current_user)],
) -> CurrentUserContext:
    """Allow platform admins and ``manager`` or ``finance`` roles; otherwise 403."""
    role_codes = _current_user_role_codes(current_user)
    if current_user.is_admin or "manager" in role_codes or "finance" in role_codes:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="只有财务人员或高级财务人员可以编辑规则草稿。",
    )


def require_rule_reviewer_user(
    current_user: Annotated[CurrentUserContext, Depends(get_current_user)],
) -> CurrentUserContext:
    """Allow platform admins and the ``manager`` role; otherwise 403."""
    role_codes = _current_user_role_codes(current_user)
    if current_user.is_admin or "manager" in role_codes:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="只有高级财务人员或 admin 管理员可以执行该操作。",
    )


def require_budget_viewer_user(
    current_user: Annotated[CurrentUserContext, Depends(get_current_user)],
) -> CurrentUserContext:
    """Allow platform admins and ``budget_monitor``/``executive`` roles; otherwise 403."""
    role_codes = _current_user_role_codes(current_user)
    if current_user.is_admin or role_codes & {"budget_monitor", "executive"}:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="只有预算监控员或高级财务人员可以查看预算中心。",
    )


def require_budget_editor_user(
    current_user: Annotated[CurrentUserContext, Depends(get_current_user)],
) -> CurrentUserContext:
    """Allow platform admins and the ``executive`` role; otherwise 403."""
    role_codes = _current_user_role_codes(current_user)
    if current_user.is_admin or "executive" in role_codes:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="只有 admin 管理员或高级财务人员可以维护预算额度。",
    )


def is_budget_scope_limited_user(current_user: CurrentUserContext) -> bool:
    """Return True for a non-admin ``budget_monitor`` who lacks the ``executive`` role."""
    if current_user.is_admin:
        return False
    role_codes = _current_user_role_codes(current_user)
    return "budget_monitor" in role_codes and "executive" not in role_codes