Files
JARVIS/backend/app/routers/auth.py

84 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models.user import User
from app.schemas.auth import UserCreate, UserOut, Token
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
from app.config import settings
# Router for the authentication endpoints; every route declared on it is served under /api/auth.
router = APIRouter(prefix="/api/auth", tags=["认证"])
# Dependency that pulls the bearer token from the request; tokenUrl is the path of the login route below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the request's bearer token to an active user.

    Args:
        token: Bearer token extracted from the request by ``oauth2_scheme``.
        db: Async database session supplied by ``get_db``.

    Returns:
        The ``User`` whose ``id`` equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when ``decode_token`` returns ``None``, when the
            payload carries no ``sub`` claim, or when no matching user exists
            or the user is not active.
    """
    # A 401 for a bearer-protected resource should carry the matching
    # challenge so that clients know which authentication scheme to use.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    payload = decode_token(token)
    # An undecodable token and a missing "sub" claim are reported identically.
    user_id = payload.get("sub") if payload is not None else None
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的认证令牌",
            headers=bearer_challenge,
        )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在或已禁用",
            headers=bearer_challenge,
        )
    return user
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account from the submitted registration data.

    Args:
        user_data: Registration payload; its ``email``, ``password`` and
            ``full_name`` fields are read.
        db: Async database session supplied by ``get_db``.

    Returns:
        The newly persisted ``User``, refreshed from the database after commit.

    Raises:
        HTTPException: 400 when a user with the same e-mail already exists.
    """
    # Reject the request if the e-mail is already registered.
    email_lookup = select(User).where(User.email == user_data.email)
    existing = (await db.execute(email_lookup)).scalar_one_or_none()
    if existing:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册")

    # Build the new row; only the hash of the password is stored.
    password_hash = get_password_hash(user_data.password)
    new_user = User(
        email=user_data.email,
        hashed_password=password_hash,
        full_name=user_data.full_name,
    )
    db.add(new_user)
    await db.commit()
    # Reload the instance from the database after the commit.
    await db.refresh(new_user)
    return new_user
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
    """Authenticate a user and issue an access token.

    The submitted ``username`` field is treated as a generic identifier and is
    tried, in order, as: a user id in UUID form, a full e-mail address, and
    finally the part of an e-mail address before the ``@``.

    Args:
        form_data: OAuth2 password form; ``username`` and ``password`` are read.
        db: Async database session supplied by ``get_db``.

    Returns:
        A ``Token`` wrapping the access token created for the user's id.

    Raises:
        HTTPException: 401 when no single user matches the identifier or the
            password does not verify; 403 when the matched user is not active.
    """
    import re  # kept function-local so this block does not depend on the module's import list

    identifier = form_data.username.strip()
    user = None

    # 1. Try the identifier as a UUID user id.
    if re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', identifier, re.I):
        result = await db.execute(select(User).where(User.id == identifier))
        user = result.scalar_one_or_none()

    # 2. Try the identifier as a full e-mail address.
    if not user:
        result = await db.execute(select(User).where(User.email == identifier))
        user = result.scalar_one_or_none()

    # 3. Try the identifier as the local part of an e-mail (the text before "@").
    if not user and '@' not in identifier:
        # autoescape=True makes "%" and "_" in the user-supplied text match
        # literally instead of acting as LIKE wildcards.
        prefix_filter = User.email.startswith(f"{identifier}@", autoescape=True)
        # Fetch at most two rows: one is enough to log in, two prove ambiguity.
        result = await db.execute(select(User).where(prefix_filter).limit(2))
        candidates = result.scalars().all()
        # A prefix shared by several accounts is ambiguous; treat it as "no
        # match" rather than letting MultipleResultsFound surface as a 500.
        if len(candidates) == 1:
            user = candidates[0]

    if not user or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名、邮箱或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # Checked only after the password verifies, so the disabled state is not
    # revealed to callers who do not know the password.
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")

    access_token = create_access_token(data={"sub": user.id})
    return Token(access_token=access_token)
@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the user resolved by the ``get_current_user`` dependency.

    Args:
        current_user: User produced by ``get_current_user`` for this request.

    Returns:
        The same ``User`` object, serialized through ``UserOut``.
    """
    return current_user