Files
JARVIS/backend/app/routers/auth.py

111 lines
4.4 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from app.database import get_db
from app.models.user import User
from app.schemas.auth import UserCreate, UserOut, Token
from app.services.admin_bootstrap_service import ensure_builtin_skills
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
from app.config import settings
# Router collecting the authentication endpoints; every route declared on it
# is served under the /api/auth prefix.
router = APIRouter(
    prefix="/api/auth",
    tags=["认证"],
)

# Dependency that extracts the bearer token from incoming requests. tokenUrl
# is the path of this router's login endpoint.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/auth/login",
)
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the request's bearer token to an active ``User`` row.

    Args:
        token: Bearer token supplied by ``oauth2_scheme``.
        db: Async database session supplied by ``get_db``.

    Returns:
        The ``User`` whose ``id`` equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when ``decode_token`` returns ``None``, when the
            payload has no ``sub`` claim, or when no matching user exists or
            the user is not active. Every 401 carries a
            ``WWW-Authenticate: Bearer`` header.
    """
    # RFC 6750 section 3: a resource server that rejects a bearer token must
    # advertise the expected scheme in a WWW-Authenticate response header.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    payload = decode_token(token)
    # An undecodable token and a token without a subject are reported the same
    # way, so both cases share a single check.
    user_id = payload.get("sub") if payload is not None else None
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的认证令牌",
            headers=bearer_challenge,
        )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在或已禁用",
            headers=bearer_challenge,
        )
    return user
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account and return the stored row.

    The username is stripped of surrounding whitespace before any check. The
    request is rejected with HTTP 400 when the stripped username is empty,
    when a user with the same username already exists, or when a user with
    the same email already exists.

    Args:
        user_data: Registration payload (username, email, password, full_name).
        db: Async database session supplied by ``get_db``.

    Returns:
        The newly created ``User`` after it has been refreshed from the
        database.

    Raises:
        HTTPException: 400 for an empty username, a taken username or email,
            or an ``IntegrityError`` raised by the commit.
    """

    def bad_request(detail: str) -> HTTPException:
        # All validation failures in this endpoint use the same status code.
        return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)

    clean_name = user_data.username.strip()
    if not clean_name:
        raise bad_request("用户名不能为空")

    # Checked in order: username first, then email.
    uniqueness_checks = (
        (User.username == clean_name, "用户名已被注册"),
        (User.email == user_data.email, "邮箱已被注册"),
    )
    for criterion, taken_message in uniqueness_checks:
        existing = await db.execute(select(User).where(criterion))
        if existing.scalar_one_or_none():
            raise bad_request(taken_message)

    new_user = User(
        username=clean_name,
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
        full_name=user_data.full_name,
    )
    db.add(new_user)
    try:
        await db.commit()
    except IntegrityError:
        # The pre-checks above do not rule out a constraint violation at
        # commit time (presumably a concurrent registration; confirm against
        # the table's unique constraints).
        await db.rollback()
        raise bad_request("用户名或邮箱已被注册")

    await db.refresh(new_user)
    # NOTE(review): presumably seeds the built-in skills if missing; see
    # admin_bootstrap_service for the exact behaviour.
    await ensure_builtin_skills(db)
    return new_user
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
    """Authenticate a user and issue an access token.

    The submitted ``username`` field is stripped and then resolved to a user
    by trying, in order and stopping at the first hit:

    1. ``User.id``, only when the identifier looks like a UUID;
    2. ``User.username``;
    3. ``User.email``;
    4. when the identifier contains no ``@``: emails that start with
       ``<identifier>@``, accepted only if exactly one user matches.

    Args:
        form_data: OAuth2 password form carrying ``username`` and ``password``.
        db: Async database session supplied by ``get_db``.

    Returns:
        A ``Token`` whose access token carries the user's id as ``sub``.

    Raises:
        HTTPException: 401 when no user is found or the password does not
            verify; 403 when the user is not active.
    """
    import re

    login_name = form_data.username.strip()

    # Columns compared for exact equality, in lookup order. The id column is
    # only consulted for identifiers shaped like a UUID.
    exact_columns = [User.username, User.email]
    if re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', login_name, re.I):
        exact_columns.insert(0, User.id)

    account = None
    for column in exact_columns:
        rows = await db.execute(select(User).where(column == login_name))
        account = rows.scalar_one_or_none()
        if account:
            break

    if not account and '@' not in login_name:
        # Escape the LIKE metacharacters so the identifier is matched
        # literally; the backslash must be handled first because it is the
        # escape character itself.
        like_safe = login_name
        for special in ('\\', '%', '_'):
            like_safe = like_safe.replace(special, '\\' + special)

        rows = await db.execute(
            select(User).where(User.email.like(f"{like_safe}@%", escape='\\'))
        )
        candidates = rows.scalars().all()
        # An ambiguous prefix (several matching emails) is treated as no match.
        if len(candidates) == 1:
            account = candidates[0]

    if not (account and verify_password(form_data.password, account.hashed_password)):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
    if not account.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")

    # NOTE(review): presumably seeds the built-in skills if missing; see
    # admin_bootstrap_service for the exact behaviour.
    await ensure_builtin_skills(db)
    return Token(access_token=create_access_token(data={"sub": account.id}))
@router.get("/me", response_model=UserOut)
async def get_me(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the user resolved from the request's bearer token.

    Args:
        current_user: User supplied by the ``get_current_user`` dependency.
        db: Async database session supplied by ``get_db``.

    Returns:
        ``current_user`` unchanged, serialised through ``UserOut``.
    """
    # ensure_builtin_skills is awaited on every call before responding.
    await ensure_builtin_skills(db)
    return current_user