From a3aa15d3393151938a8707917b4c3f24fa072086 Mon Sep 17 00:00:00 2001 From: caoxiaozhu Date: Tue, 24 Mar 2026 15:07:19 +0800 Subject: [PATCH] feat(auth): add admin bootstrap and username login Initialize admin bootstrap settings during startup, persist username support in auth flows, and align frontend auth requests with local API behavior. --- backend/.env.example | 8 + backend/app/config.py | 19 +- backend/app/database.py | 93 +++++++ backend/app/main.py | 36 ++- backend/app/models/user.py | 1 + backend/app/routers/auth.py | 41 ++- backend/app/schemas/auth.py | 2 + .../app/services/admin_bootstrap_service.py | 60 ++++ .../services/test_admin_bootstrap_service.py | 183 +++++++++++++ backend/tests/backend/app/test_auth_router.py | 259 ++++++++++++++++++ backend/tests/backend/app/test_main.py | 104 +++++++ frontend/src/api/index.ts | 2 +- frontend/src/stores/auth.ts | 6 +- 13 files changed, 787 insertions(+), 27 deletions(-) create mode 100644 backend/app/services/admin_bootstrap_service.py create mode 100644 backend/tests/backend/app/services/test_admin_bootstrap_service.py create mode 100644 backend/tests/backend/app/test_auth_router.py create mode 100644 backend/tests/backend/app/test_main.py diff --git a/backend/.env.example b/backend/.env.example index 1018f92..b39eb3f 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -16,10 +16,18 @@ DATA_DIR=./data CHROMA_PERSIST_DIR=./data/chroma UPLOAD_DIR=./data/uploads MAX_UPLOAD_SIZE=52428800 +# Supported values: ch | en +MINERU_LANGUAGE=ch # === JWT === ACCESS_TOKEN_EXPIRE_MINUTES=1440 +# === 管理员账号 Bootstrap === +ADMIN=admin +ADMIN_EMAIL=admin@example.com +ADMIN_PASSWORD= +ADMIN_FULL_NAME=Administrator + # === 定时任务 === SCHEDULER_ENABLED=true DAILY_PLAN_TIME=00:00 diff --git a/backend/app/config.py b/backend/app/config.py index d9cee31..5ca53a9 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -3,15 +3,15 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from typing import Literal -BASE_DIR = Path(__file__).resolve().parent.parent -ENV_FILE = BASE_DIR / ".env" +REPO_ROOT = Path(__file__).resolve().parents[2] +ENV_FILE = REPO_ROOT / ".env" def _resolve_path(value: str) -> str: path = Path(value) if path.is_absolute(): return str(path) - return str((BASE_DIR / path).resolve()) + return str((REPO_ROOT / path).resolve()) class Settings(BaseSettings): @@ -31,10 +31,10 @@ class Settings(BaseSettings): # === 数据库 === DATABASE_URL: str = "sqlite+aiosqlite:///./data/jarvis.db" - DATA_DIR: str = "./data" + DATA_DIR: str = "data" # === ChromaDB === - CHROMA_PERSIST_DIR: str = "./data/chroma" + CHROMA_PERSIST_DIR: str = "data/chroma" # === LLM 配置 === # 支持: openai / claude / ollama / deepseek / custom @@ -63,8 +63,15 @@ class Settings(BaseSettings): CORS_ORIGINS: list[str] = ["http://localhost:5173", "http://localhost:3000"] # === 文件上传 === - UPLOAD_DIR: str = "./data/uploads" + UPLOAD_DIR: str = "data/uploads" MAX_UPLOAD_SIZE: int = 50 * 1024 * 1024 + MINERU_LANGUAGE: Literal["ch", "en"] = "ch" + + # === 管理员 bootstrap === + ADMIN: str = "" + ADMIN_EMAIL: str = "" + ADMIN_PASSWORD: str = "" + ADMIN_FULL_NAME: str = "Administrator" # === 向量化 === EMBEDDING_MODEL: str = "text-embedding-3-small" diff --git a/backend/app/database.py b/backend/app/database.py index fb93aed..5242634 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sess from sqlalchemy.orm import DeclarativeBase from app.config import settings import os +import re os.makedirs(settings.DATA_DIR, exist_ok=True) @@ -37,6 +38,8 @@ async def init_db(): await ensure_log_columns(conn) await ensure_message_columns(conn) await ensure_document_columns(conn) + await ensure_user_columns(conn) + await ensure_forum_columns(conn) async def ensure_log_columns(conn): @@ -93,3 +96,93 @@ async def ensure_document_columns(conn): for column, ddl in required_columns.items(): if column not in columns: await conn.execute(text(ddl)) + + +async def ensure_user_columns(conn): + rows = await _get_table_info(conn, 'users') + if not rows: + return + + columns = {row[1] for row in rows} + if 'username' not in columns: + await conn.execute(text("ALTER TABLE users ADD COLUMN username VARCHAR(255)")) + rows = await _get_table_info(conn, 'users') + + await _backfill_usernames(conn) + + username_row = next(row for row in rows if row[1] == 'username') + indexes = await _get_index_info(conn, 'users') + has_username_index = any(row[1] == 'ix_users_username' and row[2] == 1 for row in indexes) + has_email_index = any(row[1] == 'ix_users_email' and row[2] == 1 for row in indexes) + + if username_row[3] != 1 or not has_username_index or not has_email_index: + await _rebuild_users_table(conn) + + +async def ensure_forum_columns(conn): + rows = await _get_table_info(conn, 'forum_posts') + if not rows: + return + + columns = {row[1] for row in rows} + required_columns = { + "board": "ALTER TABLE forum_posts ADD COLUMN board VARCHAR(100) DEFAULT 'general' NOT NULL", + "is_pinned": "ALTER TABLE forum_posts ADD COLUMN is_pinned BOOLEAN DEFAULT 0 NOT NULL", + } + for column, ddl in required_columns.items(): + if column not in columns: + await conn.execute(text(ddl)) + + indexes = await _get_index_info(conn, 'forum_posts') + index_names = {row[1] for row in indexes} + if 'ix_forum_posts_board' not in index_names: + await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_forum_posts_board ON forum_posts (board)")) + + +async def _backfill_usernames(conn): + result = await conn.execute(text("SELECT id, email, username FROM users ORDER BY created_at, id")) + users = result.fetchall() + seen_usernames: set[str] = set() + + for user_id, email, username in users: + if username: + seen_usernames.add(username) + continue + + base_username = _slugify_username((email or '').split('@', 1)[0]) + candidate = base_username + suffix = 2 + while candidate in seen_usernames: + candidate = f"{base_username}_{suffix}" + suffix += 1 + + await conn.execute( + text("UPDATE users SET username = :username WHERE id = :user_id AND username IS NULL"), + {"username": candidate, "user_id": user_id}, + ) + seen_usernames.add(candidate) + + +async def _rebuild_users_table(conn): + await conn.execute(text("CREATE TABLE users__new (id VARCHAR(36) PRIMARY KEY, username VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, hashed_password VARCHAR(255) NOT NULL, full_name VARCHAR(255), is_active BOOLEAN NOT NULL DEFAULT 1, is_superuser BOOLEAN NOT NULL DEFAULT 0, llm_config JSON, scheduler_config JSON, created_at DATETIME NOT NULL, updated_at DATETIME NOT NULL)")) + await conn.execute(text("INSERT INTO users__new (id, username, email, hashed_password, full_name, is_active, is_superuser, llm_config, scheduler_config, created_at, updated_at) SELECT id, username, email, hashed_password, full_name, COALESCE(is_active, 1), COALESCE(is_superuser, 0), llm_config, scheduler_config, COALESCE(created_at, CURRENT_TIMESTAMP), COALESCE(updated_at, CURRENT_TIMESTAMP) FROM users")) + await conn.execute(text("DROP TABLE users")) + await conn.execute(text("ALTER TABLE users__new RENAME TO users")) + await conn.execute(text("CREATE UNIQUE INDEX IF NOT EXISTS ix_users_username ON users (username)")) + await conn.execute(text("CREATE UNIQUE INDEX IF NOT EXISTS ix_users_email ON users (email)")) + + +async def _get_table_info(conn, table_name: str): + result = await conn.execute(text(f"PRAGMA table_info({table_name})")) + return result.fetchall() + + +async def _get_index_info(conn, table_name: str): + result = await conn.execute(text(f"PRAGMA index_list({table_name})")) + return result.fetchall() + + +def _slugify_username(value: str) -> str: + normalized = re.sub(r'[^a-z0-9_]+', '_', value.strip().lower()) + normalized = re.sub(r'_+', '_', normalized).strip('_') + return normalized or 'user' diff --git a/backend/app/main.py b/backend/app/main.py index 6a391c8..86aa412 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -3,7 +3,7 @@ from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException -from app.database import init_db +from app.database import init_db, async_session import app.models # noqa: F401 - 注册所有模型 from app.routers import ( auth_router, @@ -23,6 +23,7 @@ from app.routers import ( ) from app.routers.scheduler import router as scheduler_router from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status +from app.services.admin_bootstrap_service import ensure_admin_user from app.config import settings from app.logging_utils import ( setup_logging, @@ -35,14 +36,23 @@ from app.logging_utils import ( import os -@asynccontextmanager -async def lifespan(app: FastAPI): - # 启动 - setup_logging(settings.DEBUG) - os.makedirs(settings.DATA_DIR, exist_ok=True) - os.makedirs(settings.UPLOAD_DIR, exist_ok=True) - os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True) +INSECURE_SECRET_KEYS = { + 'change-me-in-production', + 'change-me-to-a-random-secret-key', + 'jarvis-secret-key-change-in-production', +} + + +def validate_startup_security() -> None: + if not settings.DEBUG and settings.SECRET_KEY in INSECURE_SECRET_KEYS: + raise RuntimeError('SECRET_KEY must be changed before running with DEBUG disabled') + + +async def run_startup() -> None: + validate_startup_security() await init_db() + async with async_session() as session: + await ensure_admin_user(session, settings) await persist_system_log( message="application_started", source="app", @@ -50,6 +60,16 @@ async def lifespan(app: FastAPI): details={"version": settings.APP_VERSION}, ) start_scheduler() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 启动 + setup_logging(settings.DEBUG) + os.makedirs(settings.DATA_DIR, exist_ok=True) + os.makedirs(settings.UPLOAD_DIR, exist_ok=True) + os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True) + await run_startup() yield # 关闭 stop_scheduler() diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 5cb26a4..3b2c1d4 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -5,6 +5,7 @@ from app.models.base import BaseModel class User(BaseModel): __tablename__ = "users" + username = Column(String(255), unique=True, nullable=False, index=True) email = Column(String(255), unique=True, nullable=False, index=True) hashed_password = Column(String(255), nullable=False) full_name = Column(String(255), nullable=True) diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index 3f6b649..0529d28 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from app.database import get_db from app.models.user import User from app.schemas.auth import UserCreate, UserOut, Token @@ -32,18 +33,30 @@ async def get_current_user( @router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED) async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)): - # 检查邮箱是否已存在 + username = user_data.username.strip() + if not username: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名不能为空") + + result = await db.execute(select(User).where(User.username == username)) + if result.scalar_one_or_none(): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已被注册") + result = await db.execute(select(User).where(User.email == user_data.email)) if result.scalar_one_or_none(): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册") - # 创建用户 + user = User( + username=username, email=user_data.email, hashed_password=get_password_hash(user_data.password), full_name=user_data.full_name, ) db.add(user) - await db.commit() + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册") await db.refresh(user) return user @@ -51,24 +64,34 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)): @router.post("/login", response_model=Token) async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)): identifier = form_data.username.strip() - # 支持:邮箱 / UUID / 用户名前缀 user = None - # 1. 尝试 UUID import re if re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', identifier, re.I): result = await db.execute(select(User).where(User.id == identifier)) user = result.scalar_one_or_none() - # 2. 尝试邮箱 + if not user: + result = await db.execute(select(User).where(User.username == identifier)) + user = result.scalar_one_or_none() + if not user: result = await db.execute(select(User).where(User.email == identifier)) user = result.scalar_one_or_none() - # 3. 尝试用户名前缀(email@ 前面的部分) if not user and '@' not in identifier: - result = await db.execute(select(User).where(User.email.like(f"{identifier}@%"))) - user = result.scalar_one_or_none() + escaped_identifier = ( + identifier + .replace('\\', '\\\\') + .replace('%', '\\%') + .replace('_', '\\_') + ) + result = await db.execute( + select(User).where(User.email.like(f"{escaped_identifier}@%", escape='\\')) + ) + prefix_matches = result.scalars().all() + if len(prefix_matches) == 1: + user = prefix_matches[0] if not user or not verify_password(form_data.password, user.hashed_password): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误") diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py index 2fe7622..faf66d5 100644 --- a/backend/app/schemas/auth.py +++ b/backend/app/schemas/auth.py @@ -2,6 +2,7 @@ from pydantic import BaseModel, EmailStr class UserCreate(BaseModel): + username: str email: EmailStr password: str full_name: str | None = None @@ -9,6 +10,7 @@ class UserCreate(BaseModel): class UserOut(BaseModel): id: str + username: str email: str full_name: str | None is_active: bool diff --git a/backend/app/services/admin_bootstrap_service.py b/backend/app/services/admin_bootstrap_service.py new file mode 100644 index 0000000..84dbaa0 --- /dev/null +++ b/backend/app/services/admin_bootstrap_service.py @@ -0,0 +1,60 @@ +from sqlalchemy import or_, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.user import User +from app.services.auth_service import get_password_hash + + +def _is_bootstrap_enabled(settings) -> bool: + return bool(settings.ADMIN.strip() and settings.ADMIN_EMAIL.strip() and settings.ADMIN_PASSWORD.strip()) + + +async def ensure_admin_user(db: AsyncSession, settings) -> None: + if not _is_bootstrap_enabled(settings): + return + + result = await db.execute( + select(User).where( + or_(User.username == settings.ADMIN.strip(), User.email == settings.ADMIN_EMAIL.strip()) + ) + ) + existing_user = result.scalar_one_or_none() + + if existing_user: + if ( + existing_user.username == settings.ADMIN.strip() + and existing_user.email == settings.ADMIN_EMAIL.strip() + and existing_user.is_superuser + ): + return + raise RuntimeError('admin bootstrap identity conflict') + + admin_user = User( + username=settings.ADMIN.strip(), + email=settings.ADMIN_EMAIL.strip(), + hashed_password=get_password_hash(settings.ADMIN_PASSWORD), + full_name=settings.ADMIN_FULL_NAME or None, + is_active=True, + is_superuser=True, + ) + db.add(admin_user) + try: + await db.commit() + except IntegrityError: + await db.rollback() + result = await db.execute( + select(User).where( + or_(User.username == settings.ADMIN.strip(), User.email == settings.ADMIN_EMAIL.strip()) + ) + ) + existing_user = result.scalar_one_or_none() + if ( + existing_user + and existing_user.username == settings.ADMIN.strip() + and existing_user.email == settings.ADMIN_EMAIL.strip() + and existing_user.is_superuser + ): + return + raise + await db.refresh(admin_user) diff --git a/backend/tests/backend/app/services/test_admin_bootstrap_service.py b/backend/tests/backend/app/services/test_admin_bootstrap_service.py new file mode 100644 index 0000000..d922170 --- /dev/null +++ b/backend/tests/backend/app/services/test_admin_bootstrap_service.py @@ -0,0 +1,183 @@ +from types import SimpleNamespace + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +import app.models # noqa: F401 +from app.database import Base +from app.models.user import User +from app.services.auth_service import verify_password +from app.services.admin_bootstrap_service import ensure_admin_user + + +@pytest.mark.asyncio +async def test_ensure_admin_user_creates_missing_admin(tmp_path): + db_path = tmp_path / 'test_admin_bootstrap.db' + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + settings = SimpleNamespace( + ADMIN='admin', + ADMIN_EMAIL='admin@example.com', + ADMIN_PASSWORD='secret123', + ADMIN_FULL_NAME='Administrator', + ) + + async with session_factory() as session: + await ensure_admin_user(session, settings) + result = await session.execute(select(User).where(User.username == 'admin')) + admin = result.scalar_one() + + assert admin.email == 'admin@example.com' + assert admin.full_name == 'Administrator' + assert admin.is_active is True + assert admin.is_superuser is True + assert verify_password('secret123', admin.hashed_password) + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_ensure_admin_user_skips_when_target_admin_already_exists(tmp_path): + db_path = tmp_path / 'test_admin_bootstrap_existing.db' + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + settings = SimpleNamespace( + ADMIN='admin', + ADMIN_EMAIL='admin@example.com', + ADMIN_PASSWORD='newsecret123', + ADMIN_FULL_NAME='Administrator', + ) + + async with session_factory() as session: + existing_admin = User( + username='admin', + email='admin@example.com', + hashed_password='existing-hash', + full_name='Existing Admin', + is_active=True, + is_superuser=True, + ) + session.add(existing_admin) + await session.commit() + + await ensure_admin_user(session, settings) + result = await session.execute(select(User).where(User.username == 'admin')) + admins = result.scalars().all() + + assert len(admins) == 1 + assert admins[0].hashed_password == 'existing-hash' + assert admins[0].full_name == 'Existing Admin' + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_ensure_admin_user_skips_when_bootstrap_not_enabled(tmp_path): + db_path = tmp_path / 'test_admin_bootstrap_disabled.db' + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + settings = SimpleNamespace( + ADMIN='', + ADMIN_EMAIL='admin@example.com', + ADMIN_PASSWORD='', + ADMIN_FULL_NAME='Administrator', + ) + + async with session_factory() as session: + await ensure_admin_user(session, settings) + result = await session.execute(select(User)) + users = result.scalars().all() + + assert users == [] + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_ensure_admin_user_raises_for_conflicting_non_admin_user(tmp_path): + db_path = tmp_path / 'test_admin_bootstrap_conflict.db' + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + settings = SimpleNamespace( + ADMIN='admin', + ADMIN_EMAIL='admin@example.com', + ADMIN_PASSWORD='secret123', + ADMIN_FULL_NAME='Administrator', + ) + + async with session_factory() as session: + session.add(User( + username='admin', + email='someone@example.com', + hashed_password='hash', + full_name='Existing User', + is_active=True, + is_superuser=False, + )) + await session.commit() + + with pytest.raises(RuntimeError): + await ensure_admin_user(session, settings) + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_ensure_admin_user_succeeds_when_duplicate_insert_was_created_concurrently(tmp_path): + db_path = tmp_path / 'test_admin_bootstrap_duplicate.db' + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + settings = SimpleNamespace( + ADMIN='admin', + ADMIN_EMAIL='admin@example.com', + ADMIN_PASSWORD='secret123', + ADMIN_FULL_NAME='Administrator', + ) + + async with session_factory() as session: + duplicate_admin = User( + username='admin', + email='admin@example.com', + hashed_password='existing-hash', + full_name='Existing Admin', + is_active=True, + is_superuser=True, + ) + session.add(duplicate_admin) + await session.flush() + + original_commit = session.commit + + async def fake_commit(): + await session.rollback() + raise IntegrityError('insert', {}, Exception('duplicate')) + + session.commit = fake_commit + try: + await ensure_admin_user(session, settings) + finally: + session.commit = original_commit + + await engine.dispose() diff --git a/backend/tests/backend/app/test_auth_router.py b/backend/tests/backend/app/test_auth_router.py new file mode 100644 index 0000000..f46ecbe --- /dev/null +++ b/backend/tests/backend/app/test_auth_router.py @@ -0,0 +1,259 @@ +import pytest +from httpx import ASGITransport, AsyncClient +from jose import jwt +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +import app.models # noqa: F401 +from app.config import settings +from app.database import Base, get_db +from app.main import app +from app.models.user import User +from app.services.auth_service import get_password_hash + + +@pytest.fixture +async def auth_test_env(tmp_path): + db_path = tmp_path / 'test_auth.db' + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async with session_factory() as session: + username_user = User( + username='jarvis_admin', + email='admin@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Admin User', + is_active=True, + ) + prefix_user = User( + username='other_user', + email='jarvis_admin@example.com', + hashed_password=get_password_hash('othersecret123'), + full_name='Prefix User', + is_active=True, + ) + inactive_user = User( + username='disabled_user', + email='disabled@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Disabled User', + is_active=False, + ) + ambiguous_user_one = User( + username='alpha_user', + email='alice@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Alice One', + is_active=True, + ) + ambiguous_user_two = User( + username='beta_user', + email='alice@another.com', + hashed_password=get_password_hash('secret123'), + full_name='Alice Two', + is_active=True, + ) + fallback_user = User( + username='fallback_target', + email='fallback@example.com', + hashed_password=get_password_hash('fallbacksecret123'), + full_name='Fallback User', + is_active=True, + ) + wildcard_user = User( + username='wildcard_target', + email='alice1@example.com', + hashed_password=get_password_hash('wildsecret123'), + full_name='Wildcard User', + is_active=True, + ) + session.add_all([ + username_user, + prefix_user, + inactive_user, + ambiguous_user_one, + ambiguous_user_two, + fallback_user, + wildcard_user, + ]) + await session.commit() + await session.refresh(username_user) + await session.refresh(inactive_user) + await session.refresh(fallback_user) + await session.refresh(wildcard_user) + + async def override_get_db(): + async with session_factory() as session: + yield session + + app.dependency_overrides[get_db] = override_get_db + + try: + yield { + 'username_user': username_user, + 'inactive_user': inactive_user, + 'prefix_user': prefix_user, + 'fallback_user': fallback_user, + 'wildcard_user': wildcard_user, + } + finally: + app.dependency_overrides.clear() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_login_accepts_username_and_returns_token(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': 'jarvis_admin', 'password': 'secret123'}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload['token_type'] == 'bearer' + assert payload['access_token'] + claims = jwt.decode(payload['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + assert claims['sub'] == auth_test_env['username_user'].id + + +@pytest.mark.asyncio +async def test_login_prefers_username_over_email_prefix(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': 'jarvis_admin', 'password': 'secret123'}, + ) + + assert response.status_code == 200 + claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + assert claims['sub'] == auth_test_env['username_user'].id + + +@pytest.mark.asyncio +async def test_login_rejects_wrong_password_for_username(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': 'jarvis_admin', 'password': 'wrong-password'}, + ) + + assert response.status_code == 401 + assert response.json()['detail'] == '用户名、邮箱或密码错误' + + +@pytest.mark.asyncio +async def test_login_rejects_inactive_username_user(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': 'disabled_user', 'password': 'secret123'}, + ) + + assert response.status_code == 403 + assert response.json()['detail'] == '用户已被禁用' + + +@pytest.mark.asyncio +async def test_register_requires_username_and_returns_it(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/register', + json={ + 'username': 'new_user', + 'email': 'new@example.com', + 'password': 'secret123', + 'full_name': 'New User', + }, + ) + + assert response.status_code == 201 + payload = response.json() + assert payload['username'] == 'new_user' + assert payload['email'] == 'new@example.com' + + +@pytest.mark.asyncio +async def test_login_accepts_exact_email(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': 'admin@example.com', 'password': 'secret123'}, + ) + + assert response.status_code == 200 + claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + assert claims['sub'] == auth_test_env['username_user'].id + + +@pytest.mark.asyncio +async def test_login_accepts_uuid_before_other_identifier_strategies(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': auth_test_env['username_user'].id, 'password': 'secret123'}, + ) + + assert response.status_code == 200 + claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + assert claims['sub'] == auth_test_env['username_user'].id + + +@pytest.mark.asyncio +async def test_login_accepts_unique_email_prefix(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': 'fallback', 'password': 'fallbacksecret123'}, + ) + + assert response.status_code == 200 + claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + assert claims['sub'] == auth_test_env['fallback_user'].id + + +@pytest.mark.asyncio +async def test_login_rejects_ambiguous_email_prefix(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': 'alice', 'password': 'secret123'}, + ) + + assert response.status_code == 401 + assert response.json()['detail'] == '用户名、邮箱或密码错误' + + +@pytest.mark.asyncio +async def test_login_does_not_treat_like_wildcards_as_email_prefix_patterns(auth_test_env): + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url='http://testserver') as client: + response = await client.post( + '/api/auth/login', + data={'username': 'alice_', 'password': 'wildsecret123'}, + ) + + assert response.status_code == 401 + assert response.json()['detail'] == '用户名、邮箱或密码错误' diff --git a/backend/tests/backend/app/test_main.py b/backend/tests/backend/app/test_main.py new file mode 100644 index 0000000..3431e6a --- /dev/null +++ b/backend/tests/backend/app/test_main.py @@ -0,0 +1,104 @@ +import pytest + +from app import main as main_module + + +@pytest.mark.asyncio +async def test_run_startup_calls_admin_bootstrap_before_logging_and_scheduler(monkeypatch): + calls = [] + + async def fake_init_db(): + calls.append('init_db') + + class DummySessionContext: + async def __aenter__(self): + calls.append('open_session') + return object() + + async def __aexit__(self, exc_type, exc, tb): + calls.append('close_session') + + async def fake_ensure_admin_user(session, settings): + calls.append('ensure_admin_user') + + async def fake_persist_system_log(**kwargs): + calls.append('persist_system_log') + + def fake_start_scheduler(): + calls.append('start_scheduler') + + monkeypatch.setattr(main_module, 'init_db', fake_init_db) + monkeypatch.setattr(main_module, 'async_session', lambda: DummySessionContext()) + monkeypatch.setattr(main_module, 'ensure_admin_user', fake_ensure_admin_user) + monkeypatch.setattr(main_module, 'persist_system_log', fake_persist_system_log) + monkeypatch.setattr(main_module, 'start_scheduler', fake_start_scheduler) + + await main_module.run_startup() + + assert calls == [ + 'init_db', + 'open_session', + 'ensure_admin_user', + 'close_session', + 'persist_system_log', + 'start_scheduler', + ] + + +@pytest.mark.asyncio +async def test_run_startup_stops_before_logging_and_scheduler_when_admin_bootstrap_fails(monkeypatch): + calls = [] + + async def fake_init_db(): + calls.append('init_db') + + class DummySessionContext: + async def __aenter__(self): + calls.append('open_session') + return object() + + async def __aexit__(self, exc_type, exc, tb): + calls.append('close_session') + + async def fake_ensure_admin_user(session, settings): + calls.append('ensure_admin_user') + raise RuntimeError('bootstrap failed') + + async def fake_persist_system_log(**kwargs): + calls.append('persist_system_log') + + def fake_start_scheduler(): + calls.append('start_scheduler') + + monkeypatch.setattr(main_module, 'init_db', fake_init_db) + monkeypatch.setattr(main_module, 'async_session', lambda: DummySessionContext()) + monkeypatch.setattr(main_module, 'ensure_admin_user', fake_ensure_admin_user) + monkeypatch.setattr(main_module, 'persist_system_log', fake_persist_system_log) + monkeypatch.setattr(main_module, 'start_scheduler', fake_start_scheduler) + + with pytest.raises(RuntimeError, match='bootstrap failed'): + await main_module.run_startup() + + assert calls == [ + 'init_db', + 'open_session', + 'ensure_admin_user', + 'close_session', + ] + + +@pytest.mark.asyncio +async def test_run_startup_rejects_default_secret_key_when_debug_is_disabled(monkeypatch): + calls = [] + + async def fake_init_db(): + calls.append('init_db') + + monkeypatch.setattr(main_module, 'init_db', fake_init_db) + monkeypatch.setattr(main_module.settings, 'DEBUG', False) + monkeypatch.setattr(main_module.settings, 'SECRET_KEY', 'change-me-in-production') + + with pytest.raises(RuntimeError, match='SECRET_KEY'): + await main_module.run_startup() + + assert calls == [] diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 0068f1c..3d2e0ba 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -3,7 +3,7 @@ import axios from 'axios' let redirectingToLogin = false const api = axios.create({ - baseURL: import.meta.env.VITE_API_URL, + baseURL: import.meta.env.DEV ? '' : import.meta.env.VITE_API_URL, timeout: 30000, }) diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index a57c3b9..94882d6 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -6,16 +6,16 @@ let unauthorizedListenerRegistered = false export const useAuthStore = defineStore('auth', () => { const token = ref(localStorage.getItem('access_token')) - const user = ref<{ id: string; email: string; full_name?: string } | null>(null) + const user = ref<{ id: string; username: string; email: string; full_name?: string } | null>(null) const isAuthReady = ref(false) const isFetchingUser = ref(false) let authReadyPromise: Promise | null = null const isAuthenticated = computed(() => !!token.value) - async function login(email: string, password: string) { + async function login(identifier: string, password: string) { const formData = new FormData() - formData.append('username', email) + formData.append('username', identifier) formData.append('password', password) const response = await api.post('/api/auth/login', formData, { headers: { 'Content-Type': 'multipart/form-data' },