feat(auth): add admin bootstrap and username login
Initialize admin bootstrap settings during startup, persist username support in auth flows, and align frontend auth requests with local API behavior.
This commit is contained in:
@@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import UserCreate, UserOut, Token
|
||||
@@ -32,18 +33,30 @@ async def get_current_user(
|
||||
|
||||
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
# 检查邮箱是否已存在
|
||||
username = user_data.username.strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名不能为空")
|
||||
|
||||
result = await db.execute(select(User).where(User.username == username))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已被注册")
|
||||
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册")
|
||||
# 创建用户
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
email=user_data.email,
|
||||
hashed_password=get_password_hash(user_data.password),
|
||||
full_name=user_data.full_name,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
try:
|
||||
await db.commit()
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
@@ -51,24 +64,34 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
|
||||
identifier = form_data.username.strip()
|
||||
# 支持:邮箱 / UUID / 用户名前缀
|
||||
user = None
|
||||
|
||||
# 1. 尝试 UUID
|
||||
import re
|
||||
if re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', identifier, re.I):
|
||||
result = await db.execute(select(User).where(User.id == identifier))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
# 2. 尝试邮箱
|
||||
if not user:
|
||||
result = await db.execute(select(User).where(User.username == identifier))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
result = await db.execute(select(User).where(User.email == identifier))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
# 3. 尝试用户名前缀(email@ 前面的部分)
|
||||
if not user and '@' not in identifier:
|
||||
result = await db.execute(select(User).where(User.email.like(f"{identifier}@%")))
|
||||
user = result.scalar_one_or_none()
|
||||
escaped_identifier = (
|
||||
identifier
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
.replace('_', '\\_')
|
||||
)
|
||||
result = await db.execute(
|
||||
select(User).where(User.email.like(f"{escaped_identifier}@%", escape='\\'))
|
||||
)
|
||||
prefix_matches = result.scalars().all()
|
||||
if len(prefix_matches) == 1:
|
||||
user = prefix_matches[0]
|
||||
|
||||
if not user or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
|
||||
|
||||
Reference in New Issue
Block a user