feat(auth): add admin bootstrap and username login
Initialize admin bootstrap settings during startup, persist username support in auth flows, and align frontend auth requests with local API behavior.
This commit is contained in:
@@ -0,0 +1,183 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import verify_password
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_creates_missing_admin(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='admin',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='secret123',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
await ensure_admin_user(session, settings)
|
||||
result = await session.execute(select(User).where(User.username == 'admin'))
|
||||
admin = result.scalar_one()
|
||||
|
||||
assert admin.email == 'admin@example.com'
|
||||
assert admin.full_name == 'Administrator'
|
||||
assert admin.is_active is True
|
||||
assert admin.is_superuser is True
|
||||
assert verify_password('secret123', admin.hashed_password)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_skips_when_target_admin_already_exists(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap_existing.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='admin',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='newsecret123',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
existing_admin = User(
|
||||
username='admin',
|
||||
email='admin@example.com',
|
||||
hashed_password='existing-hash',
|
||||
full_name='Existing Admin',
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
session.add(existing_admin)
|
||||
await session.commit()
|
||||
|
||||
await ensure_admin_user(session, settings)
|
||||
result = await session.execute(select(User).where(User.username == 'admin'))
|
||||
admins = result.scalars().all()
|
||||
|
||||
assert len(admins) == 1
|
||||
assert admins[0].hashed_password == 'existing-hash'
|
||||
assert admins[0].full_name == 'Existing Admin'
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_skips_when_bootstrap_not_enabled(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap_disabled.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
await ensure_admin_user(session, settings)
|
||||
result = await session.execute(select(User))
|
||||
users = result.scalars().all()
|
||||
|
||||
assert users == []
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_raises_for_conflicting_non_admin_user(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap_conflict.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='admin',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='secret123',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
session.add(User(
|
||||
username='admin',
|
||||
email='someone@example.com',
|
||||
hashed_password='hash',
|
||||
full_name='Existing User',
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await ensure_admin_user(session, settings)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_admin_user_succeeds_when_duplicate_insert_was_created_concurrently(tmp_path):
|
||||
db_path = tmp_path / 'test_admin_bootstrap_duplicate.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
settings = SimpleNamespace(
|
||||
ADMIN='admin',
|
||||
ADMIN_EMAIL='admin@example.com',
|
||||
ADMIN_PASSWORD='secret123',
|
||||
ADMIN_FULL_NAME='Administrator',
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
duplicate_admin = User(
|
||||
username='admin',
|
||||
email='admin@example.com',
|
||||
hashed_password='existing-hash',
|
||||
full_name='Existing Admin',
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
session.add(duplicate_admin)
|
||||
await session.flush()
|
||||
|
||||
original_commit = session.commit
|
||||
|
||||
async def fake_commit():
|
||||
await session.rollback()
|
||||
raise IntegrityError('insert', {}, Exception('duplicate'))
|
||||
|
||||
session.commit = fake_commit
|
||||
try:
|
||||
await ensure_admin_user(session, settings)
|
||||
finally:
|
||||
session.commit = original_commit
|
||||
|
||||
await engine.dispose()
|
||||
259
backend/tests/backend/app/test_auth_router.py
Normal file
259
backend/tests/backend/app/test_auth_router.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from jose import jwt
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.config import settings
|
||||
from app.database import Base, get_db
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_test_env(tmp_path):
|
||||
db_path = tmp_path / 'test_auth.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
username_user = User(
|
||||
username='jarvis_admin',
|
||||
email='admin@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Admin User',
|
||||
is_active=True,
|
||||
)
|
||||
prefix_user = User(
|
||||
username='other_user',
|
||||
email='jarvis_admin@example.com',
|
||||
hashed_password=get_password_hash('othersecret123'),
|
||||
full_name='Prefix User',
|
||||
is_active=True,
|
||||
)
|
||||
inactive_user = User(
|
||||
username='disabled_user',
|
||||
email='disabled@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Disabled User',
|
||||
is_active=False,
|
||||
)
|
||||
ambiguous_user_one = User(
|
||||
username='alpha_user',
|
||||
email='alice@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Alice One',
|
||||
is_active=True,
|
||||
)
|
||||
ambiguous_user_two = User(
|
||||
username='beta_user',
|
||||
email='alice@another.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Alice Two',
|
||||
is_active=True,
|
||||
)
|
||||
fallback_user = User(
|
||||
username='fallback_target',
|
||||
email='fallback@example.com',
|
||||
hashed_password=get_password_hash('fallbacksecret123'),
|
||||
full_name='Fallback User',
|
||||
is_active=True,
|
||||
)
|
||||
wildcard_user = User(
|
||||
username='wildcard_target',
|
||||
email='alice1@example.com',
|
||||
hashed_password=get_password_hash('wildsecret123'),
|
||||
full_name='Wildcard User',
|
||||
is_active=True,
|
||||
)
|
||||
session.add_all([
|
||||
username_user,
|
||||
prefix_user,
|
||||
inactive_user,
|
||||
ambiguous_user_one,
|
||||
ambiguous_user_two,
|
||||
fallback_user,
|
||||
wildcard_user,
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(username_user)
|
||||
await session.refresh(inactive_user)
|
||||
await session.refresh(fallback_user)
|
||||
await session.refresh(wildcard_user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
try:
|
||||
yield {
|
||||
'username_user': username_user,
|
||||
'inactive_user': inactive_user,
|
||||
'prefix_user': prefix_user,
|
||||
'fallback_user': fallback_user,
|
||||
'wildcard_user': wildcard_user,
|
||||
}
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_accepts_username_and_returns_token(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'jarvis_admin', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['token_type'] == 'bearer'
|
||||
assert payload['access_token']
|
||||
claims = jwt.decode(payload['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['username_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_prefers_username_over_email_prefix(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'jarvis_admin', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['username_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_rejects_wrong_password_for_username(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'jarvis_admin', 'password': 'wrong-password'},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()['detail'] == '用户名、邮箱或密码错误'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_rejects_inactive_username_user(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'disabled_user', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()['detail'] == '用户已被禁用'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_requires_username_and_returns_it(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/register',
|
||||
json={
|
||||
'username': 'new_user',
|
||||
'email': 'new@example.com',
|
||||
'password': 'secret123',
|
||||
'full_name': 'New User',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
assert payload['username'] == 'new_user'
|
||||
assert payload['email'] == 'new@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_accepts_exact_email(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'admin@example.com', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['username_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_accepts_uuid_before_other_identifier_strategies(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': auth_test_env['username_user'].id, 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['username_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_accepts_unique_email_prefix(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'fallback', 'password': 'fallbacksecret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
claims = jwt.decode(response.json()['access_token'], settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert claims['sub'] == auth_test_env['fallback_user'].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_rejects_ambiguous_email_prefix(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'alice', 'password': 'secret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()['detail'] == '用户名、邮箱或密码错误'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_does_not_treat_like_wildcards_as_email_prefix_patterns(auth_test_env):
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'alice_', 'password': 'wildsecret123'},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()['detail'] == '用户名、邮箱或密码错误'
|
||||
104
backend/tests/backend/app/test_main.py
Normal file
104
backend/tests/backend/app/test_main.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
|
||||
from app import main as main_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_startup_calls_admin_bootstrap_before_logging_and_scheduler(monkeypatch):
|
||||
calls = []
|
||||
|
||||
async def fake_init_db():
|
||||
calls.append('init_db')
|
||||
|
||||
class DummySessionContext:
|
||||
async def __aenter__(self):
|
||||
calls.append('open_session')
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
calls.append('close_session')
|
||||
|
||||
async def fake_ensure_admin_user(session, settings):
|
||||
calls.append('ensure_admin_user')
|
||||
|
||||
async def fake_persist_system_log(**kwargs):
|
||||
calls.append('persist_system_log')
|
||||
|
||||
def fake_start_scheduler():
|
||||
calls.append('start_scheduler')
|
||||
|
||||
monkeypatch.setattr(main_module, 'init_db', fake_init_db)
|
||||
monkeypatch.setattr(main_module, 'async_session', lambda: DummySessionContext())
|
||||
monkeypatch.setattr(main_module, 'ensure_admin_user', fake_ensure_admin_user)
|
||||
monkeypatch.setattr(main_module, 'persist_system_log', fake_persist_system_log)
|
||||
monkeypatch.setattr(main_module, 'start_scheduler', fake_start_scheduler)
|
||||
|
||||
await main_module.run_startup()
|
||||
|
||||
assert calls == [
|
||||
'init_db',
|
||||
'open_session',
|
||||
'ensure_admin_user',
|
||||
'close_session',
|
||||
'persist_system_log',
|
||||
'start_scheduler',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_startup_stops_before_logging_and_scheduler_when_admin_bootstrap_fails(monkeypatch):
|
||||
calls = []
|
||||
|
||||
async def fake_init_db():
|
||||
calls.append('init_db')
|
||||
|
||||
class DummySessionContext:
|
||||
async def __aenter__(self):
|
||||
calls.append('open_session')
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
calls.append('close_session')
|
||||
|
||||
async def fake_ensure_admin_user(session, settings):
|
||||
calls.append('ensure_admin_user')
|
||||
raise RuntimeError('bootstrap failed')
|
||||
|
||||
async def fake_persist_system_log(**kwargs):
|
||||
calls.append('persist_system_log')
|
||||
|
||||
def fake_start_scheduler():
|
||||
calls.append('start_scheduler')
|
||||
|
||||
monkeypatch.setattr(main_module, 'init_db', fake_init_db)
|
||||
monkeypatch.setattr(main_module, 'async_session', lambda: DummySessionContext())
|
||||
monkeypatch.setattr(main_module, 'ensure_admin_user', fake_ensure_admin_user)
|
||||
monkeypatch.setattr(main_module, 'persist_system_log', fake_persist_system_log)
|
||||
monkeypatch.setattr(main_module, 'start_scheduler', fake_start_scheduler)
|
||||
|
||||
with pytest.raises(RuntimeError, match='bootstrap failed'):
|
||||
await main_module.run_startup()
|
||||
|
||||
assert calls == [
|
||||
'init_db',
|
||||
'open_session',
|
||||
'ensure_admin_user',
|
||||
'close_session',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_startup_rejects_default_secret_key_when_debug_is_disabled(monkeypatch):
|
||||
calls = []
|
||||
|
||||
async def fake_init_db():
|
||||
calls.append('init_db')
|
||||
|
||||
monkeypatch.setattr(main_module, 'init_db', fake_init_db)
|
||||
monkeypatch.setattr(main_module.settings, 'DEBUG', False)
|
||||
monkeypatch.setattr(main_module.settings, 'SECRET_KEY', 'change-me-in-production')
|
||||
|
||||
with pytest.raises(RuntimeError, match='SECRET_KEY'):
|
||||
await main_module.run_startup()
|
||||
|
||||
assert calls == []
|
||||
Reference in New Issue
Block a user