Initialize admin bootstrap settings during startup, persist username support in auth flows, and align frontend auth requests with local API behavior.
107 lines
4.2 KiB
Python
107 lines
4.2 KiB
Python
import re

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import get_db
from app.models.user import User
from app.schemas.auth import UserCreate, UserOut, Token
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
|
|
|
|
# Every endpoint in this module is mounted under /api/auth and grouped under
# the "认证" ("authentication") tag in the generated OpenAPI schema.
router = APIRouter(tags=["认证"], prefix="/api/auth")

# Extracts the bearer token from incoming requests for protected endpoints.
# tokenUrl names the password-login route (the /login endpoint in this module)
# that clients call to obtain a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
|
|
|
|
|
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the request's bearer token to an active ``User``.

    Intended for use as a FastAPI dependency on protected endpoints.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.
        db: Async database session.

    Returns:
        The active ``User`` whose id is stored in the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token cannot be decoded, has no ``sub``
            claim, or refers to a missing or inactive user. Every 401 carries
            a ``WWW-Authenticate: Bearer`` header.
    """
    # RFC 6750 §3: a 401 from a bearer-protected resource must include a
    # WWW-Authenticate challenge naming the scheme. A fresh dict is built per
    # call so no mutable state is shared between responses.
    challenge = {"WWW-Authenticate": "Bearer"}

    # decode_token signals an undecodable token by returning None
    # (presumably also for expired/badly-signed tokens — see auth_service).
    payload = decode_token(token)
    user_id = payload.get("sub") if payload is not None else None
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的认证令牌",
            headers=challenge,
        )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    # A disabled account is rejected exactly like a missing one, so the
    # response does not reveal which of the two applies.
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在或已禁用",
            headers=challenge,
        )
    return user
|
|
|
|
|
|
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account and return it.

    Surrounding whitespace is stripped from the username before it is
    validated and stored. The username and the email must each be unused.
    The pre-checks produce a specific message per field. An ``IntegrityError``
    at commit time (e.g. a concurrent registration slipping past the
    pre-checks) is rolled back and reported with a combined message.

    Raises:
        HTTPException: 400 if the username is blank, or if the username or
            email is already registered.
    """
    clean_username = user_data.username.strip()
    if not clean_username:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名不能为空")

    # Checked in order: username first, then email; the first conflict wins.
    uniqueness_checks = (
        (User.username == clean_username, "用户名已被注册"),
        (User.email == user_data.email, "邮箱已被注册"),
    )
    for condition, conflict_message in uniqueness_checks:
        existing = (await db.execute(select(User).where(condition))).scalar_one_or_none()
        if existing:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=conflict_message)

    new_user = User(
        username=clean_username,
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
        full_name=user_data.full_name,
    )
    db.add(new_user)

    try:
        await db.commit()
    except IntegrityError:
        # A unique constraint fired despite the pre-checks; undo the insert.
        await db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")

    # Reload server-populated columns before serializing through UserOut.
    await db.refresh(new_user)
    return new_user
|
|
|
|
|
|
# Canonical 8-4-4-4-12 hex UUID text, case-insensitive. It is compiled once
# at import time instead of re-importing `re` on every login request.
_UUID_RE = re.compile(
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
    re.IGNORECASE,
)


async def _first_user(db: AsyncSession, condition):
    """Return the ``User`` matching *condition*, or ``None`` if there is none."""
    result = await db.execute(select(User).where(condition))
    return result.scalar_one_or_none()


async def _unique_user_by_email_prefix(db: AsyncSession, prefix: str):
    """Return the single user whose email starts with ``prefix@``, else ``None``.

    ``None`` is returned both when nothing matches and when the prefix is
    ambiguous (two or more matches).
    """
    # Escape LIKE wildcards so the prefix is matched literally; the backslash
    # is escaped first so the escapes added afterwards are not doubled.
    escaped = prefix.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    result = await db.execute(
        # Two rows are enough to tell "unique" from "ambiguous".
        select(User).where(User.email.like(f"{escaped}@%", escape="\\")).limit(2)
    )
    matches = result.scalars().all()
    return matches[0] if len(matches) == 1 else None


async def _find_user_by_identifier(db: AsyncSession, identifier: str):
    """Resolve a login identifier to a ``User`` (or ``None``).

    Lookup order:
      1. user id, if the identifier is UUID-shaped;
      2. for identifiers containing '@': email, then username;
         otherwise: username, then email;
      3. for identifiers without '@': a unique email local-part prefix.
    """
    if not identifier:
        return None

    if _UUID_RE.fullmatch(identifier):
        user = await _first_user(db, User.id == identifier)
        if user is not None:
            return user

    is_email_like = "@" in identifier
    if is_email_like:
        # Registration does not forbid '@' in usernames. Resolving email first
        # stops someone who registers a username equal to another account's
        # email from shadowing that account's email login.
        lookups = (User.email == identifier, User.username == identifier)
    else:
        lookups = (User.username == identifier, User.email == identifier)
    for condition in lookups:
        user = await _first_user(db, condition)
        if user is not None:
            return user

    if is_email_like:
        return None
    return await _unique_user_by_email_prefix(db, identifier)


@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
    """Authenticate with an identifier and password, returning an access token.

    The form's ``username`` field may hold a user id, a username, an email
    address, or an email local part that matches exactly one account.

    Raises:
        HTTPException: 401 if no user matches or the password is wrong;
            403 if the matched user is disabled.
    """
    user = await _find_user_by_identifier(db, form_data.username.strip())

    # NOTE(review): verify_password only runs when a user was found, so response
    # timing can reveal whether an identifier exists. Consider verifying against
    # a fixed dummy hash on the not-found path.
    if user is None or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
    # The disabled check comes after the password check, so the 403 is only
    # shown to someone who knows the correct password.
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")

    access_token = create_access_token(data={"sub": user.id})
    return Token(access_token=access_token)
|
|
|
|
|
|
@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the currently authenticated user.

    All token validation and user loading happens in the ``get_current_user``
    dependency. The returned ORM object is serialized through ``UserOut``.
    """
    return current_user
|