"""Authentication routes: user registration, login and current-user lookup."""

import re
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings  # noqa: F401  (not used in this module; import kept as-is)
from app.database import get_db
from app.models.user import User
from app.schemas.auth import Token, UserCreate, UserOut
from app.services.auth_service import (
    create_access_token,
    decode_token,
    get_password_hash,
    verify_password,
)

router = APIRouter(prefix="/api/auth", tags=["认证"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")

# Canonical hyphenated UUID, matched case-insensitively. Compiled once at
# import time instead of on every login request.
_UUID_PATTERN = re.compile(
    r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$',
    re.I,
)


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries the ``WWW-Authenticate: Bearer`` challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the bearer token to an active ``User``.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.
        db: Async database session.

    Returns:
        The user whose id equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when ``decode_token`` returns ``None``, when the
            payload has no ``sub`` claim, or when the user does not exist or
            is not active.
    """
    payload = decode_token(token)
    if payload is None:
        raise _unauthorized("无效的认证令牌")

    user_id = payload.get("sub")
    if user_id is None:
        raise _unauthorized("无效的认证令牌")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise _unauthorized("用户不存在或已禁用")
    return user


@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account.

    Args:
        user_data: Registration payload (email, password, full name).
        db: Async database session.

    Returns:
        The newly persisted ``User``, serialized through ``UserOut``.

    Raises:
        HTTPException: 400 when the email is already registered.
    """
    # Reject the request if the email is already taken.
    existing = await db.execute(select(User).where(User.email == user_data.email))
    if existing.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="邮箱已被注册",
        )

    # Create the user; only the password hash is stored.
    user = User(
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
        full_name=user_data.full_name,
    )
    db.add(user)
    await db.commit()
    await db.refresh(user)  # reload the row after commit before returning it
    return user


async def _find_login_user(db: AsyncSession, identifier: str) -> Optional[User]:
    """Look up a user by UUID, then by full email, then by email local part.

    Returns ``None`` when nothing matches or when the local-part lookup is
    ambiguous (more than one address shares the prefix).
    """
    user = None

    # 1. Try the identifier as a user id when it looks like a UUID.
    if _UUID_PATTERN.match(identifier):
        result = await db.execute(select(User).where(User.id == identifier))
        user = result.scalar_one_or_none()

    # 2. Try the identifier as a full email address.
    if user is None:
        result = await db.execute(select(User).where(User.email == identifier))
        user = result.scalar_one_or_none()

    # 3. Try the identifier as the part of the email before "@".
    if user is None and identifier and "@" not in identifier:
        # autoescape=True makes "%" and "_" typed by the client literal
        # characters instead of LIKE wildcards.
        stmt = (
            select(User)
            .where(User.email.startswith(f"{identifier}@", autoescape=True))
            .limit(2)  # two rows are enough to detect ambiguity
        )
        candidates = (await db.execute(stmt)).scalars().all()
        # Several addresses can share one local part (john@a.com, john@b.com).
        # Accept only an unambiguous match rather than failing with a 500.
        if len(candidates) == 1:
            user = candidates[0]

    return user


@router.post("/login", response_model=Token)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: AsyncSession = Depends(get_db),
):
    """Authenticate a user and issue an access token.

    The ``username`` form field may hold a user UUID, a full email address,
    or the part of the email before ``@``.

    Args:
        form_data: OAuth2 password form (``username`` and ``password``).
        db: Async database session.

    Returns:
        A ``Token`` whose access token has the user's id as its ``sub`` claim.

    Raises:
        HTTPException: 401 when no single user matches or the password is
            wrong; 403 when the matched user is not active.
    """
    identifier = form_data.username.strip()
    user = await _find_login_user(db, identifier)

    if user is None or not verify_password(form_data.password, user.hashed_password):
        raise _unauthorized("用户名、邮箱或密码错误")
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用",
        )

    access_token = create_access_token(data={"sub": user.id})
    return Token(access_token=access_token)


@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the currently authenticated user."""
    return current_user