"""Authentication routes: user registration, password login, and current-user lookup."""

import re
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from app.database import get_db
from app.models.user import User
from app.schemas.auth import UserCreate, UserOut, Token
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
from app.config import settings

router = APIRouter(prefix="/api/auth", tags=["认证"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")

# Canonical 8-4-4-4-12 hex UUID string, case-insensitive. Compiled once at
# import time instead of being re-parsed on every login request.
_UUID_RE = re.compile(
    r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$',
    re.IGNORECASE,
)


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries the Bearer authentication challenge.

    A 401 response is expected to include a ``WWW-Authenticate`` header
    naming the scheme the client should use; a fresh dict is created each
    time so no caller can mutate a shared header mapping.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the request's bearer token to an active ``User``.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or refers to a user that does not exist or is inactive.
    """
    payload = decode_token(token)
    # A None payload from decode_token is treated as an invalid token.
    if payload is None:
        raise _unauthorized("无效的认证令牌")
    # login() stores the user's id in the "sub" claim.
    user_id = payload.get("sub")
    if user_id is None:
        raise _unauthorized("无效的认证令牌")
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise _unauthorized("用户不存在或已禁用")
    return user


@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account.

    The username is stripped of surrounding whitespace before it is checked
    and stored, and the password is stored only as a hash.

    Raises:
        HTTPException: 400 when the username is blank or the username/email
            is already registered.
    """
    username = user_data.username.strip()
    if not username:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名不能为空")

    # Pre-checks give field-specific error messages; the IntegrityError
    # handler below still covers races between these checks and the commit.
    result = await db.execute(select(User).where(User.username == username))
    if result.scalar_one_or_none():
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已被注册")
    result = await db.execute(select(User).where(User.email == user_data.email))
    if result.scalar_one_or_none():
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册")

    user = User(
        username=username,
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
        full_name=user_data.full_name,
    )
    db.add(user)
    try:
        await db.commit()
    except IntegrityError as exc:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册"
        ) from exc
    await db.refresh(user)
    return user


async def _find_user_by_identifier(db: AsyncSession, identifier: str) -> Optional[User]:
    """Look up a user by id, username, email, or unique email local part.

    Strategies are tried in that order and the first hit wins:

    1. ``User.id`` when *identifier* looks like a UUID;
    2. exact ``User.username``;
    3. exact ``User.email``;
    4. when *identifier* has no ``@``, the single user whose email starts
       with ``identifier@``. An ambiguous prefix (two or more matches)
       resolves to no user.

    Returns None when nothing matches or *identifier* is empty.
    """
    if not identifier:
        return None

    if _UUID_RE.match(identifier):
        result = await db.execute(select(User).where(User.id == identifier))
        user = result.scalar_one_or_none()
        if user is not None:
            return user

    for column in (User.username, User.email):
        result = await db.execute(select(User).where(column == identifier))
        user = result.scalar_one_or_none()
        if user is not None:
            return user

    if '@' in identifier:
        return None

    # Escape LIKE metacharacters so the identifier is matched literally.
    # The backslash is escaped first so the escapes added for % and _ are
    # not themselves doubled.
    escaped_identifier = (
        identifier
        .replace('\\', '\\\\')
        .replace('%', '\\%')
        .replace('_', '\\_')
    )
    # Two rows suffice to distinguish "unique" from "ambiguous"; there is no
    # need to load every user that shares the prefix.
    result = await db.execute(
        select(User)
        .where(User.email.like(f"{escaped_identifier}@%", escape='\\'))
        .limit(2)
    )
    prefix_matches = result.scalars().all()
    return prefix_matches[0] if len(prefix_matches) == 1 else None


@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
    """Exchange credentials for an access token.

    The OAuth2 ``username`` form field accepts a user id, a username, an
    email address, or a unique email local part (see
    ``_find_user_by_identifier``).

    Raises:
        HTTPException: 401 for an unknown identifier or a wrong password
            (same message, so the two cases are not distinguished); 403 when
            the credentials are valid but the account is inactive.
    """
    identifier = form_data.username.strip()
    user = await _find_user_by_identifier(db, identifier)

    if user is None or not verify_password(form_data.password, user.hashed_password):
        raise _unauthorized("用户名、邮箱或密码错误")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")

    access_token = create_access_token(data={"sub": user.id})
    return Token(access_token=access_token)


@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the user authenticated by the request's bearer token."""
    return current_user