feat(auth): add admin bootstrap and username login
Initialize admin bootstrap settings during startup, persist username support in auth flows, and align frontend auth requests with local API behavior.
This commit is contained in:
@@ -16,10 +16,18 @@ DATA_DIR=./data
|
||||
CHROMA_PERSIST_DIR=./data/chroma
|
||||
UPLOAD_DIR=./data/uploads
|
||||
MAX_UPLOAD_SIZE=52428800
|
||||
# Supported values: ch | en
|
||||
MINERU_LANGUAGE=ch
|
||||
|
||||
# === JWT ===
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=1440
|
||||
|
||||
# === 管理员账号 Bootstrap ===
|
||||
ADMIN=admin
|
||||
ADMIN_EMAIL=admin@example.com
|
||||
ADMIN_PASSWORD=
|
||||
ADMIN_FULL_NAME=Administrator
|
||||
|
||||
# === 定时任务 ===
|
||||
SCHEDULER_ENABLED=true
|
||||
DAILY_PLAN_TIME=00:00
|
||||
|
||||
@@ -3,15 +3,15 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
ENV_FILE = BASE_DIR / ".env"
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
ENV_FILE = REPO_ROOT / ".env"
|
||||
|
||||
|
||||
def _resolve_path(value: str) -> str:
|
||||
path = Path(value)
|
||||
if path.is_absolute():
|
||||
return str(path)
|
||||
return str((BASE_DIR / path).resolve())
|
||||
return str((REPO_ROOT / path).resolve())
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -31,10 +31,10 @@ class Settings(BaseSettings):
|
||||
|
||||
# === 数据库 ===
|
||||
DATABASE_URL: str = "sqlite+aiosqlite:///./data/jarvis.db"
|
||||
DATA_DIR: str = "./data"
|
||||
DATA_DIR: str = "data"
|
||||
|
||||
# === ChromaDB ===
|
||||
CHROMA_PERSIST_DIR: str = "./data/chroma"
|
||||
CHROMA_PERSIST_DIR: str = "data/chroma"
|
||||
|
||||
# === LLM 配置 ===
|
||||
# 支持: openai / claude / ollama / deepseek / custom
|
||||
@@ -63,8 +63,15 @@ class Settings(BaseSettings):
|
||||
CORS_ORIGINS: list[str] = ["http://localhost:5173", "http://localhost:3000"]
|
||||
|
||||
# === 文件上传 ===
|
||||
UPLOAD_DIR: str = "./data/uploads"
|
||||
UPLOAD_DIR: str = "data/uploads"
|
||||
MAX_UPLOAD_SIZE: int = 50 * 1024 * 1024
|
||||
MINERU_LANGUAGE: Literal["ch", "en"] = "ch"
|
||||
|
||||
# === 管理员 bootstrap ===
|
||||
ADMIN: str = ""
|
||||
ADMIN_EMAIL: str = ""
|
||||
ADMIN_PASSWORD: str = ""
|
||||
ADMIN_FULL_NAME: str = "Administrator"
|
||||
|
||||
# === 向量化 ===
|
||||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sess
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from app.config import settings
|
||||
import os
|
||||
import re
|
||||
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
|
||||
@@ -37,6 +38,8 @@ async def init_db():
|
||||
await ensure_log_columns(conn)
|
||||
await ensure_message_columns(conn)
|
||||
await ensure_document_columns(conn)
|
||||
await ensure_user_columns(conn)
|
||||
await ensure_forum_columns(conn)
|
||||
|
||||
|
||||
async def ensure_log_columns(conn):
|
||||
@@ -93,3 +96,93 @@ async def ensure_document_columns(conn):
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
|
||||
async def ensure_user_columns(conn):
|
||||
rows = await _get_table_info(conn, 'users')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
if 'username' not in columns:
|
||||
await conn.execute(text("ALTER TABLE users ADD COLUMN username VARCHAR(255)"))
|
||||
rows = await _get_table_info(conn, 'users')
|
||||
|
||||
await _backfill_usernames(conn)
|
||||
|
||||
username_row = next(row for row in rows if row[1] == 'username')
|
||||
indexes = await _get_index_info(conn, 'users')
|
||||
has_username_index = any(row[1] == 'ix_users_username' and row[2] == 1 for row in indexes)
|
||||
has_email_index = any(row[1] == 'ix_users_email' and row[2] == 1 for row in indexes)
|
||||
|
||||
if username_row[3] != 1 or not has_username_index or not has_email_index:
|
||||
await _rebuild_users_table(conn)
|
||||
|
||||
|
||||
async def ensure_forum_columns(conn):
|
||||
rows = await _get_table_info(conn, 'forum_posts')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
"board": "ALTER TABLE forum_posts ADD COLUMN board VARCHAR(100) DEFAULT 'general' NOT NULL",
|
||||
"is_pinned": "ALTER TABLE forum_posts ADD COLUMN is_pinned BOOLEAN DEFAULT 0 NOT NULL",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
indexes = await _get_index_info(conn, 'forum_posts')
|
||||
index_names = {row[1] for row in indexes}
|
||||
if 'ix_forum_posts_board' not in index_names:
|
||||
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_forum_posts_board ON forum_posts (board)"))
|
||||
|
||||
|
||||
async def _backfill_usernames(conn):
|
||||
result = await conn.execute(text("SELECT id, email, username FROM users ORDER BY created_at, id"))
|
||||
users = result.fetchall()
|
||||
seen_usernames: set[str] = set()
|
||||
|
||||
for user_id, email, username in users:
|
||||
if username:
|
||||
seen_usernames.add(username)
|
||||
continue
|
||||
|
||||
base_username = _slugify_username((email or '').split('@', 1)[0])
|
||||
candidate = base_username
|
||||
suffix = 2
|
||||
while candidate in seen_usernames:
|
||||
candidate = f"{base_username}_{suffix}"
|
||||
suffix += 1
|
||||
|
||||
await conn.execute(
|
||||
text("UPDATE users SET username = :username WHERE id = :user_id AND username IS NULL"),
|
||||
{"username": candidate, "user_id": user_id},
|
||||
)
|
||||
seen_usernames.add(candidate)
|
||||
|
||||
|
||||
async def _rebuild_users_table(conn):
|
||||
await conn.execute(text("CREATE TABLE users__new (id VARCHAR(36) PRIMARY KEY, username VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, hashed_password VARCHAR(255) NOT NULL, full_name VARCHAR(255), is_active BOOLEAN NOT NULL DEFAULT 1, is_superuser BOOLEAN NOT NULL DEFAULT 0, llm_config JSON, scheduler_config JSON, created_at DATETIME NOT NULL, updated_at DATETIME NOT NULL)"))
|
||||
await conn.execute(text("INSERT INTO users__new (id, username, email, hashed_password, full_name, is_active, is_superuser, llm_config, scheduler_config, created_at, updated_at) SELECT id, username, email, hashed_password, full_name, COALESCE(is_active, 1), COALESCE(is_superuser, 0), llm_config, scheduler_config, COALESCE(created_at, CURRENT_TIMESTAMP), COALESCE(updated_at, CURRENT_TIMESTAMP) FROM users"))
|
||||
await conn.execute(text("DROP TABLE users"))
|
||||
await conn.execute(text("ALTER TABLE users__new RENAME TO users"))
|
||||
await conn.execute(text("CREATE UNIQUE INDEX IF NOT EXISTS ix_users_username ON users (username)"))
|
||||
await conn.execute(text("CREATE UNIQUE INDEX IF NOT EXISTS ix_users_email ON users (email)"))
|
||||
|
||||
|
||||
async def _get_table_info(conn, table_name: str):
|
||||
result = await conn.execute(text(f"PRAGMA table_info({table_name})"))
|
||||
return result.fetchall()
|
||||
|
||||
|
||||
async def _get_index_info(conn, table_name: str):
|
||||
result = await conn.execute(text(f"PRAGMA index_list({table_name})"))
|
||||
return result.fetchall()
|
||||
|
||||
|
||||
def _slugify_username(value: str) -> str:
|
||||
normalized = re.sub(r'[^a-z0-9_]+', '_', value.strip().lower())
|
||||
normalized = re.sub(r'_+', '_', normalized).strip('_')
|
||||
return normalized or 'user'
|
||||
|
||||
@@ -3,7 +3,7 @@ from fastapi import FastAPI
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from app.database import init_db
|
||||
from app.database import init_db, async_session
|
||||
import app.models # noqa: F401 - 注册所有模型
|
||||
from app.routers import (
|
||||
auth_router,
|
||||
@@ -23,6 +23,7 @@ from app.routers import (
|
||||
)
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user
|
||||
from app.config import settings
|
||||
from app.logging_utils import (
|
||||
setup_logging,
|
||||
@@ -35,14 +36,23 @@ from app.logging_utils import (
|
||||
import os
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# 启动
|
||||
setup_logging(settings.DEBUG)
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
||||
INSECURE_SECRET_KEYS = {
|
||||
'change-me-in-production',
|
||||
'change-me-to-a-random-secret-key',
|
||||
'jarvis-secret-key-change-in-production',
|
||||
}
|
||||
|
||||
|
||||
def validate_startup_security() -> None:
|
||||
if not settings.DEBUG and settings.SECRET_KEY in INSECURE_SECRET_KEYS:
|
||||
raise RuntimeError('SECRET_KEY must be changed before running with DEBUG disabled')
|
||||
|
||||
|
||||
async def run_startup() -> None:
|
||||
validate_startup_security()
|
||||
await init_db()
|
||||
async with async_session() as session:
|
||||
await ensure_admin_user(session, settings)
|
||||
await persist_system_log(
|
||||
message="application_started",
|
||||
source="app",
|
||||
@@ -50,6 +60,16 @@ async def lifespan(app: FastAPI):
|
||||
details={"version": settings.APP_VERSION},
|
||||
)
|
||||
start_scheduler()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# 启动
|
||||
setup_logging(settings.DEBUG)
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
||||
await run_startup()
|
||||
yield
|
||||
# 关闭
|
||||
stop_scheduler()
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.models.base import BaseModel
|
||||
class User(BaseModel):
|
||||
__tablename__ = "users"
|
||||
|
||||
username = Column(String(255), unique=True, nullable=False, index=True)
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
hashed_password = Column(String(255), nullable=False)
|
||||
full_name = Column(String(255), nullable=True)
|
||||
|
||||
@@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import UserCreate, UserOut, Token
|
||||
@@ -32,18 +33,30 @@ async def get_current_user(
|
||||
|
||||
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
# 检查邮箱是否已存在
|
||||
username = user_data.username.strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名不能为空")
|
||||
|
||||
result = await db.execute(select(User).where(User.username == username))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已被注册")
|
||||
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册")
|
||||
# 创建用户
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
email=user_data.email,
|
||||
hashed_password=get_password_hash(user_data.password),
|
||||
full_name=user_data.full_name,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
try:
|
||||
await db.commit()
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
@@ -51,24 +64,34 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
|
||||
identifier = form_data.username.strip()
|
||||
# 支持:邮箱 / UUID / 用户名前缀
|
||||
user = None
|
||||
|
||||
# 1. 尝试 UUID
|
||||
import re
|
||||
if re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', identifier, re.I):
|
||||
result = await db.execute(select(User).where(User.id == identifier))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
# 2. 尝试邮箱
|
||||
if not user:
|
||||
result = await db.execute(select(User).where(User.username == identifier))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
result = await db.execute(select(User).where(User.email == identifier))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
# 3. 尝试用户名前缀(email@ 前面的部分)
|
||||
if not user and '@' not in identifier:
|
||||
result = await db.execute(select(User).where(User.email.like(f"{identifier}@%")))
|
||||
user = result.scalar_one_or_none()
|
||||
escaped_identifier = (
|
||||
identifier
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
.replace('_', '\\_')
|
||||
)
|
||||
result = await db.execute(
|
||||
select(User).where(User.email.like(f"{escaped_identifier}@%", escape='\\'))
|
||||
)
|
||||
prefix_matches = result.scalars().all()
|
||||
if len(prefix_matches) == 1:
|
||||
user = prefix_matches[0]
|
||||
|
||||
if not user or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
|
||||
|
||||
@@ -2,6 +2,7 @@ from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
email: EmailStr
|
||||
password: str
|
||||
full_name: str | None = None
|
||||
@@ -9,6 +10,7 @@ class UserCreate(BaseModel):
|
||||
|
||||
class UserOut(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
email: str
|
||||
full_name: str | None
|
||||
is_active: bool
|
||||
|
||||
60
backend/app/services/admin_bootstrap_service.py
Normal file
60
backend/app/services/admin_bootstrap_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
def _is_bootstrap_enabled(settings) -> bool:
|
||||
return bool(settings.ADMIN.strip() and settings.ADMIN_EMAIL.strip() and settings.ADMIN_PASSWORD.strip())
|
||||
|
||||
|
||||
async def ensure_admin_user(db: AsyncSession, settings) -> None:
|
||||
if not _is_bootstrap_enabled(settings):
|
||||
return
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
or_(User.username == settings.ADMIN.strip(), User.email == settings.ADMIN_EMAIL.strip())
|
||||
)
|
||||
)
|
||||
existing_user = result.scalar_one_or_none()
|
||||
|
||||
if existing_user:
|
||||
if (
|
||||
existing_user.username == settings.ADMIN.strip()
|
||||
and existing_user.email == settings.ADMIN_EMAIL.strip()
|
||||
and existing_user.is_superuser
|
||||
):
|
||||
return
|
||||
raise RuntimeError('admin bootstrap identity conflict')
|
||||
|
||||
admin_user = User(
|
||||
username=settings.ADMIN.strip(),
|
||||
email=settings.ADMIN_EMAIL.strip(),
|
||||
hashed_password=get_password_hash(settings.ADMIN_PASSWORD),
|
||||
full_name=settings.ADMIN_FULL_NAME or None,
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
db.add(admin_user)
|
||||
try:
|
||||
await db.commit()
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
or_(User.username == settings.ADMIN.strip(), User.email == settings.ADMIN_EMAIL.strip())
|
||||
)
|
||||
)
|
||||
existing_user = result.scalar_one_or_none()
|
||||
if (
|
||||
existing_user
|
||||
and existing_user.username == settings.ADMIN.strip()
|
||||
and existing_user.email == settings.ADMIN_EMAIL.strip()
|
||||
and existing_user.is_superuser
|
||||
):
|
||||
return
|
||||
raise
|
||||
await db.refresh(admin_user)
|
||||
@@ -0,0 +1,183 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import verify_password
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_creates_missing_admin(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='admin',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='secret123',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
await ensure_admin_user(session, settings)
|
||||
result = await session.execute(select(User).where(User.username == 'admin'))
|
||||
admin = result.scalar_one()
|
||||
|
||||
assert admin.email == 'admin@example.com'
|
||||
assert admin.full_name == 'Administrator'
|
||||
assert admin.is_active is True
|
||||
assert admin.is_superuser is True
|
||||
assert verify_password('secret123', admin.hashed_password)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_skips_when_target_admin_already_exists(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap_existing.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='admin',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='newsecret123',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
existing_admin = User(
|
||||
username='admin',
|
||||
email='admin@example.com',
|
||||
hashed_password='existing-hash',
|
||||
full_name='Existing Admin',
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
session.add(existing_admin)
|
||||
await session.commit()
|
||||
|
||||
await ensure_admin_user(session, settings)
|
||||
result = await session.execute(select(User).where(User.username == 'admin'))
|
||||
admins = result.scalars().all()
|
||||
|
||||
assert len(admins) == 1
|
||||
assert admins[0].hashed_password == 'existing-hash'
|
||||
assert admins[0].full_name == 'Existing Admin'
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_skips_when_bootstrap_not_enabled(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap_disabled.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
await ensure_admin_user(session, settings)
|
||||
result = await session.execute(select(User))
|
||||
users = result.scalars().all()
|
||||
|
||||
assert users == []
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_raises_for_conflicting_non_admin_user(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap_conflict.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='admin',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='secret123',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
session.add(User(
|
||||
username='admin',
|
||||
email='someone@example.com',
|
||||
hashed_password='hash',
|
||||
full_name='Existing User',
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await ensure_admin_user(session, settings)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_succeeds_when_duplicate_insert_was_created_concurrently(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap_duplicate.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='admin',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='secret123',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
duplicate_admin = User(
|
||||
username='admin',
|
||||
email='admin@example.com',
|
||||
hashed_password='existing-hash',
|
||||
full_name='Existing Admin',
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
session.add(duplicate_admin)
|
||||
await session.flush()
|
||||
|
||||
original_commit = session.commit
|
||||
|
||||
async def fake_commit():
|
||||
await session.rollback()
|
||||
raise IntegrityError('insert', {}, Exception('duplicate'))
|
||||
|
||||
session.commit = fake_commit
|
||||
try:
|
||||
await ensure_admin_user(session, settings)
|
||||
finally:
|
||||
session.commit = original_commit
|
||||
|
||||
await engine.dispose()
|
||||
259
backend/tests/backend/app/test_auth_router.py
Normal file
259
backend/tests/backend/app/test_auth_router.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from jose import jwt
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.config import settings
|
||||
from app.database import Base, get_db
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_test_env(tmp_path):
|
||||
db_path = tmp_path / 'test_auth.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
username_user = User(
|
||||
username='jarvis_admin',
|
||||
email='admin@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Admin User',
|
||||
is_active=True,
|
||||
)
|
||||
prefix_user = User(
|
||||
username='other_user',
|
||||
email='jarvis_admin@example.com',
|
||||
hashed_password=get_password_hash('othersecret123'),
|
||||
full_name='Prefix User',
|
||||
is_active=True,
|
||||
)
|
||||
inactive_user = User(
|
||||
username='disabled_user',
|
||||
email='disabled@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Disabled User',
|
||||
is_active=False,
|
||||
)
|
||||
ambiguous_user_one = User(
|
||||
username='alpha_user',
|
||||
email='alice@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Alice One',
|
||||
is_active=True,
|
||||
)
|
||||
ambiguous_user_two = User(
|
||||
username='beta_user',
|
||||
email='alice@another.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Alice Two',
|
||||
is_active=True,
|
||||
)
|
||||
fallback_user = User(
|
||||
username='fallback_target',
|
||||
email='fallback@example.com',
|
||||
hashed_password=get_password_hash('fallbacksecret123'),
|
||||
full_name='Fallback User',
|
||||
is_active=True,
|
||||
)
|
||||
wildcard_user = User(
|
||||
username='wildcard_target',
|
||||
email='alice1@example.com',
|
||||
hashed_password=get_password_hash('wildsecret123'),
|
||||
full_name='Wildcard User',
|
||||
is_active=True,
|
||||
)
|
||||
session.add_all([
|
||||
username_user,
|
||||
prefix_user,
|
||||
inactive_user,
|
||||
ambiguous_user_one,
|
||||
ambiguous_user_two,
|
||||
fallback_user,
|
||||
wildcard_user,
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(username_user)
|
||||
await session.refresh(inactive_user)
|
||||
await session.refresh(fallback_user)
|
||||
await session.refresh(wildcard_user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
try:
|
||||
yield {
|
||||
'username_user': username_user,
|
||||
'inactive_user': inactive_user,
|
||||
'prefix_user': prefix_user,
|
||||
'fallback_user': fallback_user,
|
||||
'wildcard_user': wildcard_user,
|
||||
}
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_accepts_username_and_returns_token(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'jarvis_admin', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['token_type'] == 'bearer'
|
||||
assert payload['access_token']
|
||||
claims = jwt.decode(payload['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['username_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_prefers_username_over_email_prefix(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'jarvis_admin', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['username_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_rejects_wrong_password_for_username(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'jarvis_admin', 'password': 'wrong-password'},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()['detail'] == '用户名、邮箱或密码错误'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_rejects_inactive_username_user(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'disabled_user', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()['detail'] == '用户已被禁用'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_requires_username_and_returns_it(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/register',
|
||||
json={
|
||||
'username': 'new_user',
|
||||
'email': 'new@example.com',
|
||||
'password': 'secret123',
|
||||
'full_name': 'New User',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
assert payload['username'] == 'new_user'
|
||||
assert payload['email'] == 'new@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_accepts_exact_email(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'admin@example.com', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['username_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_accepts_uuid_before_other_identifier_strategies(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': auth_test_env['username_user'].id, 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['username_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_accepts_unique_email_prefix(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'fallback', 'password': 'fallbacksecret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['fallback_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_rejects_ambiguous_email_prefix(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'alice', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()['detail'] == '用户名、邮箱或密码错误'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_does_not_treat_like_wildcards_as_email_prefix_patterns(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'alice_', 'password': 'wildsecret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()['detail'] == '用户名、邮箱或密码错误'
|
||||
104
backend/tests/backend/app/test_main.py
Normal file
104
backend/tests/backend/app/test_main.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
|
||||
from app import main as main_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_startup_calls_admin_bootstrap_before_logging_and_scheduler(monkeypatch):
|
||||
calls = []
|
||||
|
||||
async def fake_init_db():
|
||||
calls.append('init_db')
|
||||
|
||||
class DummySessionContext:
|
||||
async def __aenter__(self):
|
||||
calls.append('open_session')
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
calls.append('close_session')
|
||||
|
||||
async def fake_ensure_admin_user(session, settings):
|
||||
calls.append('ensure_admin_user')
|
||||
|
||||
async def fake_persist_system_log(**kwargs):
|
||||
calls.append('persist_system_log')
|
||||
|
||||
def fake_start_scheduler():
|
||||
calls.append('start_scheduler')
|
||||
|
||||
monkeypatch.setattr(main_module, 'init_db', fake_init_db)
|
||||
monkeypatch.setattr(main_module, 'async_session', lambda: DummySessionContext())
|
||||
monkeypatch.setattr(main_module, 'ensure_admin_user', fake_ensure_admin_user)
|
||||
monkeypatch.setattr(main_module, 'persist_system_log', fake_persist_system_log)
|
||||
monkeypatch.setattr(main_module, 'start_scheduler', fake_start_scheduler)
|
||||
|
||||
await main_module.run_startup()
|
||||
|
||||
assert calls == [
|
||||
'init_db',
|
||||
'open_session',
|
||||
'ensure_admin_user',
|
||||
'close_session',
|
||||
'persist_system_log',
|
||||
'start_scheduler',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_startup_stops_before_logging_and_scheduler_when_admin_bootstrap_fails(monkeypatch):
|
||||
calls = []
|
||||
|
||||
async def fake_init_db():
|
||||
calls.append('init_db')
|
||||
|
||||
class DummySessionContext:
|
||||
async def __aenter__(self):
|
||||
calls.append('open_session')
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
calls.append('close_session')
|
||||
|
||||
async def fake_ensure_admin_user(session, settings):
|
||||
calls.append('ensure_admin_user')
|
||||
raise RuntimeError('bootstrap failed')
|
||||
|
||||
async def fake_persist_system_log(**kwargs):
|
||||
calls.append('persist_system_log')
|
||||
|
||||
def fake_start_scheduler():
|
||||
calls.append('start_scheduler')
|
||||
|
||||
monkeypatch.setattr(main_module, 'init_db', fake_init_db)
|
||||
monkeypatch.setattr(main_module, 'async_session', lambda: DummySessionContext())
|
||||
monkeypatch.setattr(main_module, 'ensure_admin_user', fake_ensure_admin_user)
|
||||
monkeypatch.setattr(main_module, 'persist_system_log', fake_persist_system_log)
|
||||
monkeypatch.setattr(main_module, 'start_scheduler', fake_start_scheduler)
|
||||
|
||||
with pytest.raises(RuntimeError, match='bootstrap failed'):
|
||||
await main_module.run_startup()
|
||||
|
||||
assert calls == [
|
||||
'init_db',
|
||||
'open_session',
|
||||
'ensure_admin_user',
|
||||
'close_session',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_startup_rejects_default_secret_key_when_debug_is_disabled(monkeypatch):
|
||||
calls = []
|
||||
|
||||
async def fake_init_db():
|
||||
calls.append('init_db')
|
||||
|
||||
monkeypatch.setattr(main_module, 'init_db', fake_init_db)
|
||||
monkeypatch.setattr(main_module.settings, 'DEBUG', False)
|
||||
monkeypatch.setattr(main_module.settings, 'SECRET_KEY', 'change-me-in-production')
|
||||
|
||||
with pytest.raises(RuntimeError, match='SECRET_KEY'):
|
||||
await main_module.run_startup()
|
||||
|
||||
assert calls == []
|
||||
@@ -3,7 +3,7 @@ import axios from 'axios'
|
||||
let redirectingToLogin = false
|
||||
|
||||
const api = axios.create({
|
||||
baseURL: import.meta.env.VITE_API_URL,
|
||||
baseURL: import.meta.env.DEV ? '' : import.meta.env.VITE_API_URL,
|
||||
timeout: 30000,
|
||||
})
|
||||
|
||||
|
||||
@@ -6,16 +6,16 @@ let unauthorizedListenerRegistered = false
|
||||
|
||||
export const useAuthStore = defineStore('auth', () => {
|
||||
const token = ref<string | null>(localStorage.getItem('access_token'))
|
||||
const user = ref<{ id: string; email: string; full_name?: string } | null>(null)
|
||||
const user = ref<{ id: string; username: string; email: string; full_name?: string } | null>(null)
|
||||
const isAuthReady = ref(false)
|
||||
const isFetchingUser = ref(false)
|
||||
let authReadyPromise: Promise<void> | null = null
|
||||
|
||||
const isAuthenticated = computed(() => !!token.value)
|
||||
|
||||
async function login(email: string, password: string) {
|
||||
async function login(identifier: string, password: string) {
|
||||
const formData = new FormData()
|
||||
formData.append('username', email)
|
||||
formData.append('username', identifier)
|
||||
formData.append('password', password)
|
||||
const response = await api.post('/api/auth/login', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
|
||||
Reference in New Issue
Block a user