Files
JARVIS/backend/app/routers/auth.py

111 lines
4.4 KiB
Python
Raw Normal View History

2026-03-21 10:13:29 +08:00
import re

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import get_db
from app.models.user import User
from app.schemas.auth import Token, UserCreate, UserOut
from app.services.admin_bootstrap_service import ensure_builtin_skills
from app.services.auth_service import (
    create_access_token,
    decode_token,
    get_password_hash,
    verify_password,
)

# Every endpoint in this module is mounted beneath this path prefix.
_AUTH_PREFIX = "/api/auth"

router = APIRouter(prefix=_AUTH_PREFIX, tags=["认证"])

# Extracts the bearer token from incoming requests; ``tokenUrl`` tells clients
# (and the generated OpenAPI docs) to obtain tokens from this router's
# ``/login`` endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{_AUTH_PREFIX}/login")
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the request's bearer token to an active ``User``.

    Intended to be used as a FastAPI dependency.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.
        db: Async database session provided by ``get_db``.

    Returns:
        The user whose id is carried in the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if ``decode_token`` yields ``None``, the payload has
            no ``sub`` claim, or the referenced user is missing or inactive.
    """
    # Bearer-auth 401 responses should advertise the expected scheme so that
    # clients know how to authenticate (RFC 6750; FastAPI security tutorial).
    auth_headers = {"WWW-Authenticate": "Bearer"}

    payload = decode_token(token)
    user_id = payload.get("sub") if payload is not None else None
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的认证令牌",
            headers=auth_headers,
        )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在或已禁用",
            headers=auth_headers,
        )
    return user
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account and return it.

    The username is stripped of surrounding whitespace before validation and
    storage; the password is stored only as a hash.

    Raises:
        HTTPException: 400 if the stripped username is empty, or if the
            username or email is already registered. This includes the case
            where a concurrent insert trips a database integrity constraint.
    """
    username = user_data.username.strip()
    if not username:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名不能为空")

    # Friendly pre-checks; the IntegrityError handler below is the real guard
    # against races between these checks and the insert.
    result = await db.execute(select(User).where(User.username == username))
    if result.scalar_one_or_none() is not None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已被注册")

    result = await db.execute(select(User).where(User.email == user_data.email))
    if result.scalar_one_or_none() is not None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册")

    user = User(
        username=username,
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
        full_name=user_data.full_name,
    )
    db.add(user)
    try:
        await db.commit()
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")

    await ensure_builtin_skills(db)
    # Refresh last. If ensure_builtin_skills commits the session, ``user`` may be
    # expired, and response serialization would then try to lazy-load it
    # outside the async context. NOTE(review): confirm whether that helper
    # commits.
    await db.refresh(user)
    return user
# Canonical hyphenated UUID (case-insensitive); such identifiers are first
# tried as a user id.
_UUID_PATTERN = re.compile(
    r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$',
    re.IGNORECASE,
)

# Escape character for LIKE patterns built from user input.
_LIKE_ESCAPE = '\\'


def _escape_like(value: str) -> str:
    """Escape ``value`` so it matches literally inside a SQL LIKE pattern."""
    return (
        value
        .replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace('%', _LIKE_ESCAPE + '%')
        .replace('_', _LIKE_ESCAPE + '_')
    )


async def _find_login_candidates(db: AsyncSession, identifier: str) -> list:
    """Return the distinct users ``identifier`` may refer to, in priority order.

    Priority order:
        1. user id (only when ``identifier`` looks like a UUID);
        2. exact username;
        3. exact email;
        4. the single user whose email local part equals ``identifier``
           (only when ``identifier`` has no ``@``).
    """
    candidates = []
    seen_ids = set()

    def add(user):
        if user is not None and user.id not in seen_ids:
            seen_ids.add(user.id)
            candidates.append(user)

    if _UUID_PATTERN.match(identifier):
        result = await db.execute(select(User).where(User.id == identifier))
        add(result.scalar_one_or_none())

    result = await db.execute(select(User).where(User.username == identifier))
    add(result.scalar_one_or_none())

    result = await db.execute(select(User).where(User.email == identifier))
    add(result.scalar_one_or_none())

    if '@' not in identifier:
        result = await db.execute(
            select(User)
            .where(User.email.like(f"{_escape_like(identifier)}@%", escape=_LIKE_ESCAPE))
            # Only uniqueness matters, so two rows are enough to detect ambiguity.
            .limit(2)
        )
        prefix_matches = result.scalars().all()
        # An ambiguous local part is ignored rather than guessed.
        if len(prefix_matches) == 1:
            add(prefix_matches[0])

    return candidates


@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
    """Authenticate by user id, username, email, or email local part.

    The password is checked against each candidate account in priority order,
    and the first match authenticates. This keeps one account's identifier from
    shadowing another's. For example, a username chosen to equal someone
    else's email address can no longer lock that person out of email login.

    Returns:
        A ``Token`` whose access token carries the user id in ``sub``.

    Raises:
        HTTPException: 401 if no candidate's password matches; 403 if the
            matching user is inactive.
    """
    identifier = form_data.username.strip()

    user = None
    # A blank identifier cannot name an account (register rejects blank
    # usernames), so skip the lookups entirely.
    if identifier:
        for candidate in await _find_login_candidates(db, identifier):
            if verify_password(form_data.password, candidate.hashed_password):
                user = candidate
                break

    if user is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")

    # Read the id before ensure_builtin_skills. If that helper commits, ORM
    # attributes may be expired and cannot be lazy-loaded in async code.
    # NOTE(review): confirm whether it commits.
    user_id = user.id
    await ensure_builtin_skills(db)

    access_token = create_access_token(data={"sub": user_id})
    return Token(access_token=access_token)
@router.get("/me", response_model=UserOut)
async def get_me(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the authenticated caller's own user record.

    ``ensure_builtin_skills`` is awaited on every call before responding, as in
    the register and login endpoints.
    """
    authenticated_user = current_user
    await ensure_builtin_skills(db)
    return authenticated_user