feat(auth): add admin bootstrap and username login

Initialize admin bootstrap settings during startup, persist username support in auth flows, and align frontend auth requests with local API behavior.
This commit is contained in:
2026-03-24 15:07:19 +08:00
parent 6f594631e9
commit a3aa15d339
13 changed files with 787 additions and 27 deletions

View File

@@ -3,15 +3,15 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Literal
BASE_DIR = Path(__file__).resolve().parent.parent
ENV_FILE = BASE_DIR / ".env"
REPO_ROOT = Path(__file__).resolve().parents[2]
ENV_FILE = REPO_ROOT / ".env"
def _resolve_path(value: str) -> str:
path = Path(value)
if path.is_absolute():
return str(path)
return str((BASE_DIR / path).resolve())
return str((REPO_ROOT / path).resolve())
class Settings(BaseSettings):
@@ -31,10 +31,10 @@ class Settings(BaseSettings):
# === 数据库 ===
DATABASE_URL: str = "sqlite+aiosqlite:///./data/jarvis.db"
DATA_DIR: str = "./data"
DATA_DIR: str = "data"
# === ChromaDB ===
CHROMA_PERSIST_DIR: str = "./data/chroma"
CHROMA_PERSIST_DIR: str = "data/chroma"
# === LLM 配置 ===
# 支持: openai / claude / ollama / deepseek / custom
@@ -63,8 +63,15 @@ class Settings(BaseSettings):
CORS_ORIGINS: list[str] = ["http://localhost:5173", "http://localhost:3000"]
# === 文件上传 ===
UPLOAD_DIR: str = "./data/uploads"
UPLOAD_DIR: str = "data/uploads"
MAX_UPLOAD_SIZE: int = 50 * 1024 * 1024
MINERU_LANGUAGE: Literal["ch", "en"] = "ch"
# === 管理员 bootstrap ===
ADMIN: str = ""
ADMIN_EMAIL: str = ""
ADMIN_PASSWORD: str = ""
ADMIN_FULL_NAME: str = "Administrator"
# === 向量化 ===
EMBEDDING_MODEL: str = "text-embedding-3-small"

View File

@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sess
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
import os
import re
os.makedirs(settings.DATA_DIR, exist_ok=True)
@@ -37,6 +38,8 @@ async def init_db():
await ensure_log_columns(conn)
await ensure_message_columns(conn)
await ensure_document_columns(conn)
await ensure_user_columns(conn)
await ensure_forum_columns(conn)
async def ensure_log_columns(conn):
@@ -93,3 +96,93 @@ async def ensure_document_columns(conn):
for column, ddl in required_columns.items():
if column not in columns:
await conn.execute(text(ddl))
async def ensure_user_columns(conn):
rows = await _get_table_info(conn, 'users')
if not rows:
return
columns = {row[1] for row in rows}
if 'username' not in columns:
await conn.execute(text("ALTER TABLE users ADD COLUMN username VARCHAR(255)"))
rows = await _get_table_info(conn, 'users')
await _backfill_usernames(conn)
username_row = next(row for row in rows if row[1] == 'username')
indexes = await _get_index_info(conn, 'users')
has_username_index = any(row[1] == 'ix_users_username' and row[2] == 1 for row in indexes)
has_email_index = any(row[1] == 'ix_users_email' and row[2] == 1 for row in indexes)
if username_row[3] != 1 or not has_username_index or not has_email_index:
await _rebuild_users_table(conn)
async def ensure_forum_columns(conn):
rows = await _get_table_info(conn, 'forum_posts')
if not rows:
return
columns = {row[1] for row in rows}
required_columns = {
"board": "ALTER TABLE forum_posts ADD COLUMN board VARCHAR(100) DEFAULT 'general' NOT NULL",
"is_pinned": "ALTER TABLE forum_posts ADD COLUMN is_pinned BOOLEAN DEFAULT 0 NOT NULL",
}
for column, ddl in required_columns.items():
if column not in columns:
await conn.execute(text(ddl))
indexes = await _get_index_info(conn, 'forum_posts')
index_names = {row[1] for row in indexes}
if 'ix_forum_posts_board' not in index_names:
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_forum_posts_board ON forum_posts (board)"))
async def _backfill_usernames(conn):
result = await conn.execute(text("SELECT id, email, username FROM users ORDER BY created_at, id"))
users = result.fetchall()
seen_usernames: set[str] = set()
for user_id, email, username in users:
if username:
seen_usernames.add(username)
continue
base_username = _slugify_username((email or '').split('@', 1)[0])
candidate = base_username
suffix = 2
while candidate in seen_usernames:
candidate = f"{base_username}_{suffix}"
suffix += 1
await conn.execute(
text("UPDATE users SET username = :username WHERE id = :user_id AND username IS NULL"),
{"username": candidate, "user_id": user_id},
)
seen_usernames.add(candidate)
async def _rebuild_users_table(conn):
await conn.execute(text("CREATE TABLE users__new (id VARCHAR(36) PRIMARY KEY, username VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, hashed_password VARCHAR(255) NOT NULL, full_name VARCHAR(255), is_active BOOLEAN NOT NULL DEFAULT 1, is_superuser BOOLEAN NOT NULL DEFAULT 0, llm_config JSON, scheduler_config JSON, created_at DATETIME NOT NULL, updated_at DATETIME NOT NULL)"))
await conn.execute(text("INSERT INTO users__new (id, username, email, hashed_password, full_name, is_active, is_superuser, llm_config, scheduler_config, created_at, updated_at) SELECT id, username, email, hashed_password, full_name, COALESCE(is_active, 1), COALESCE(is_superuser, 0), llm_config, scheduler_config, COALESCE(created_at, CURRENT_TIMESTAMP), COALESCE(updated_at, CURRENT_TIMESTAMP) FROM users"))
await conn.execute(text("DROP TABLE users"))
await conn.execute(text("ALTER TABLE users__new RENAME TO users"))
await conn.execute(text("CREATE UNIQUE INDEX IF NOT EXISTS ix_users_username ON users (username)"))
await conn.execute(text("CREATE UNIQUE INDEX IF NOT EXISTS ix_users_email ON users (email)"))
async def _get_table_info(conn, table_name: str):
result = await conn.execute(text(f"PRAGMA table_info({table_name})"))
return result.fetchall()
async def _get_index_info(conn, table_name: str):
result = await conn.execute(text(f"PRAGMA index_list({table_name})"))
return result.fetchall()
def _slugify_username(value: str) -> str:
normalized = re.sub(r'[^a-z0-9_]+', '_', value.strip().lower())
normalized = re.sub(r'_+', '_', normalized).strip('_')
return normalized or 'user'

View File

@@ -3,7 +3,7 @@ from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.database import init_db
from app.database import init_db, async_session
import app.models # noqa: F401 - 注册所有模型
from app.routers import (
auth_router,
@@ -23,6 +23,7 @@ from app.routers import (
)
from app.routers.scheduler import router as scheduler_router
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
from app.services.admin_bootstrap_service import ensure_admin_user
from app.config import settings
from app.logging_utils import (
setup_logging,
@@ -35,14 +36,23 @@ from app.logging_utils import (
import os
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动
setup_logging(settings.DEBUG)
os.makedirs(settings.DATA_DIR, exist_ok=True)
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
INSECURE_SECRET_KEYS = {
'change-me-in-production',
'change-me-to-a-random-secret-key',
'jarvis-secret-key-change-in-production',
}
def validate_startup_security() -> None:
if not settings.DEBUG and settings.SECRET_KEY in INSECURE_SECRET_KEYS:
raise RuntimeError('SECRET_KEY must be changed before running with DEBUG disabled')
async def run_startup() -> None:
validate_startup_security()
await init_db()
async with async_session() as session:
await ensure_admin_user(session, settings)
await persist_system_log(
message="application_started",
source="app",
@@ -50,6 +60,16 @@ async def lifespan(app: FastAPI):
details={"version": settings.APP_VERSION},
)
start_scheduler()
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动
setup_logging(settings.DEBUG)
os.makedirs(settings.DATA_DIR, exist_ok=True)
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
await run_startup()
yield
# 关闭
stop_scheduler()

View File

@@ -5,6 +5,7 @@ from app.models.base import BaseModel
class User(BaseModel):
__tablename__ = "users"
username = Column(String(255), unique=True, nullable=False, index=True)
email = Column(String(255), unique=True, nullable=False, index=True)
hashed_password = Column(String(255), nullable=False)
full_name = Column(String(255), nullable=True)

View File

@@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from app.database import get_db
from app.models.user import User
from app.schemas.auth import UserCreate, UserOut, Token
@@ -32,18 +33,30 @@ async def get_current_user(
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
# 检查邮箱是否已存在
username = user_data.username.strip()
if not username:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名不能为空")
result = await db.execute(select(User).where(User.username == username))
if result.scalar_one_or_none():
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已被注册")
result = await db.execute(select(User).where(User.email == user_data.email))
if result.scalar_one_or_none():
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册")
# 创建用户
user = User(
username=username,
email=user_data.email,
hashed_password=get_password_hash(user_data.password),
full_name=user_data.full_name,
)
db.add(user)
await db.commit()
try:
await db.commit()
except IntegrityError:
await db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")
await db.refresh(user)
return user
@@ -51,24 +64,34 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
identifier = form_data.username.strip()
# 支持:邮箱 / UUID / 用户名前缀
user = None
# 1. 尝试 UUID
import re
if re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', identifier, re.I):
result = await db.execute(select(User).where(User.id == identifier))
user = result.scalar_one_or_none()
# 2. 尝试邮箱
if not user:
result = await db.execute(select(User).where(User.username == identifier))
user = result.scalar_one_or_none()
if not user:
result = await db.execute(select(User).where(User.email == identifier))
user = result.scalar_one_or_none()
# 3. 尝试用户名前缀email@ 前面的部分)
if not user and '@' not in identifier:
result = await db.execute(select(User).where(User.email.like(f"{identifier}@%")))
user = result.scalar_one_or_none()
escaped_identifier = (
identifier
.replace('\\', '\\\\')
.replace('%', '\\%')
.replace('_', '\\_')
)
result = await db.execute(
select(User).where(User.email.like(f"{escaped_identifier}@%", escape='\\'))
)
prefix_matches = result.scalars().all()
if len(prefix_matches) == 1:
user = prefix_matches[0]
if not user or not verify_password(form_data.password, user.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")

View File

@@ -2,6 +2,7 @@ from pydantic import BaseModel, EmailStr
class UserCreate(BaseModel):
username: str
email: EmailStr
password: str
full_name: str | None = None
@@ -9,6 +10,7 @@ class UserCreate(BaseModel):
class UserOut(BaseModel):
id: str
username: str
email: str
full_name: str | None
is_active: bool

View File

@@ -0,0 +1,60 @@
from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.user import User
from app.services.auth_service import get_password_hash
def _is_bootstrap_enabled(settings) -> bool:
return bool(settings.ADMIN.strip() and settings.ADMIN_EMAIL.strip() and settings.ADMIN_PASSWORD.strip())
async def ensure_admin_user(db: AsyncSession, settings) -> None:
if not _is_bootstrap_enabled(settings):
return
result = await db.execute(
select(User).where(
or_(User.username == settings.ADMIN.strip(), User.email == settings.ADMIN_EMAIL.strip())
)
)
existing_user = result.scalar_one_or_none()
if existing_user:
if (
existing_user.username == settings.ADMIN.strip()
and existing_user.email == settings.ADMIN_EMAIL.strip()
and existing_user.is_superuser
):
return
raise RuntimeError('admin bootstrap identity conflict')
admin_user = User(
username=settings.ADMIN.strip(),
email=settings.ADMIN_EMAIL.strip(),
hashed_password=get_password_hash(settings.ADMIN_PASSWORD),
full_name=settings.ADMIN_FULL_NAME or None,
is_active=True,
is_superuser=True,
)
db.add(admin_user)
try:
await db.commit()
except IntegrityError:
await db.rollback()
result = await db.execute(
select(User).where(
or_(User.username == settings.ADMIN.strip(), User.email == settings.ADMIN_EMAIL.strip())
)
)
existing_user = result.scalar_one_or_none()
if (
existing_user
and existing_user.username == settings.ADMIN.strip()
and existing_user.email == settings.ADMIN_EMAIL.strip()
and existing_user.is_superuser
):
return
raise
await db.refresh(admin_user)